Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
2800efc7
"vscode:/vscode.git/clone" did not exist on "070e0ca1081d45860d7b3483e7562eb36be6ae38"
Commit
2800efc7
authored
Jul 06, 2023
by
Tri Dao
Browse files
[FT] rotary_cos/sin should have batch_size dimension
parent
d2f4324f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
7 deletions
+15
-7
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
...attention/decoder_masked_multihead_attention_template.hpp
+12
-4
csrc/ft_attention/ft_attention.cpp
csrc/ft_attention/ft_attention.cpp
+3
-3
No files found.
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
View file @
2800efc7
...
...
@@ -1065,14 +1065,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
if
(
params
.
rotary_cos
==
nullptr
)
{
apply_rotary_embedding
(
q
,
k
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
}
else
{
apply_rotary_embedding
(
q
,
k
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_cos
,
params
.
rotary_sin
);
apply_rotary_embedding
(
q
,
k
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_cos
+
bi
*
params
.
rotary_embedding_dim
/
2
,
params
.
rotary_sin
+
bi
*
params
.
rotary_embedding_dim
/
2
);
}
}
else
{
if
(
params
.
rotary_cos
==
nullptr
)
{
apply_rotary_embedding
(
q
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
}
else
{
apply_rotary_embedding
(
q
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_cos
,
params
.
rotary_sin
);
apply_rotary_embedding
(
q
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_cos
+
bi
*
params
.
rotary_embedding_dim
/
2
,
params
.
rotary_sin
+
bi
*
params
.
rotary_embedding_dim
/
2
);
}
}
}
...
...
@@ -1112,7 +1116,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
q
,
k
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
}
else
{
mmha
::
apply_rotary_embedding
(
q
,
k
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_cos
,
params
.
rotary_sin
);
q
,
k
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_cos
+
bi
*
params
.
rotary_embedding_dim
/
2
,
params
.
rotary_sin
+
bi
*
params
.
rotary_embedding_dim
/
2
);
}
mmha
::
write_smem_transpose
(
k
,
k_smem
,
transpose_idx
,
smem_pitch
);
...
...
@@ -1123,7 +1129,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
q
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
,
params
.
rotary_base
);
}
else
{
mmha
::
apply_rotary_embedding
(
q
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
,
params
.
rotary_cos
,
params
.
rotary_sin
);
q
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
,
params
.
rotary_cos
+
bi
*
params
.
rotary_embedding_dim
/
2
,
params
.
rotary_sin
+
bi
*
params
.
rotary_embedding_dim
/
2
);
}
}
mmha
::
write_smem_transpose
(
q
,
q_smem
,
transpose_idx
,
smem_pitch
);
...
...
csrc/ft_attention/ft_attention.cpp
View file @
2800efc7
...
...
@@ -160,15 +160,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
if
(
rotary_cos_
.
has_value
())
{
auto
rotary_cos
=
rotary_cos_
.
value
();
CHECK_DEVICE
(
rotary_cos
);
rotary_embedding_dim
=
rotary_cos
.
size
(
0
)
*
2
;
CHECK_SHAPE
(
rotary_cos
,
rotary_embedding_dim
/
2
);
rotary_embedding_dim
=
rotary_cos
.
size
(
-
1
)
*
2
;
CHECK_SHAPE
(
rotary_cos
,
batch_size
,
rotary_embedding_dim
/
2
);
CHECK_CONTIGUOUS
(
rotary_cos
);
TORCH_CHECK
(
rotary_cos
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
rotary_sin_
.
has_value
());
auto
rotary_sin
=
rotary_sin_
.
value
();
CHECK_DEVICE
(
rotary_sin
);
CHECK_SHAPE
(
rotary_cos
,
rotary_embedding_dim
/
2
);
CHECK_SHAPE
(
rotary_cos
,
batch_size
,
rotary_embedding_dim
/
2
);
CHECK_CONTIGUOUS
(
rotary_sin
);
TORCH_CHECK
(
rotary_sin
.
scalar_type
()
==
input_type
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment