Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
3a9bfd07
Commit
3a9bfd07
authored
Jul 03, 2023
by
Tri Dao
Browse files
[FT] rotary_cos/sin should have shape (dim) instead of (seqlen, dim)
parent
e8a0b4ac
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
38 deletions
+37
-38
csrc/ft_attention/decoder_masked_multihead_attention_utils.h
csrc/ft_attention/decoder_masked_multihead_attention_utils.h
+34
-34
csrc/ft_attention/ft_attention.cpp
csrc/ft_attention/ft_attention.cpp
+3
-4
No files found.
csrc/ft_attention/decoder_masked_multihead_attention_utils.h
View file @
3a9bfd07
...
...
@@ -1549,7 +1549,7 @@ inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
...
...
@@ -1558,7 +1558,7 @@ inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
...
...
@@ -1570,9 +1570,9 @@ inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_
}
Float4_
&
q_
=
*
reinterpret_cast
<
Float4_
*>
(
&
q
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q_
.
x
=
rotary_embedding_transform
(
q_
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q_
.
y
=
rotary_embedding_transform
(
q_
.
y
,
coef1
);
}
...
...
@@ -1584,10 +1584,10 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
Float4_
&
q_
=
*
reinterpret_cast
<
Float4_
*>
(
&
q
);
Float4_
&
k_
=
*
reinterpret_cast
<
Float4_
*>
(
&
k
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q_
.
x
=
rotary_embedding_transform
(
q_
.
x
,
coef0
);
k_
.
x
=
rotary_embedding_transform
(
k_
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q_
.
y
=
rotary_embedding_transform
(
q_
.
y
,
coef1
);
k_
.
y
=
rotary_embedding_transform
(
k_
.
y
,
coef1
);
}
...
...
@@ -1597,7 +1597,7 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embe
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
...
...
@@ -1606,7 +1606,7 @@ inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid,
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
...
...
@@ -1616,9 +1616,9 @@ inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_d
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
}
...
...
@@ -1627,10 +1627,10 @@ inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int r
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
}
...
...
@@ -1640,13 +1640,13 @@ inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_d
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
}
...
...
@@ -1655,16 +1655,16 @@ inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int r
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
k
.
z
=
rotary_embedding_transform
(
k
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
k
.
w
=
rotary_embedding_transform
(
k
.
w
,
coef3
);
}
...
...
@@ -1675,7 +1675,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int ro
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
...
...
@@ -1684,7 +1684,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162&
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
...
...
@@ -1694,9 +1694,9 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embe
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
}
...
...
@@ -1705,10 +1705,10 @@ inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid,
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
}
...
...
@@ -1718,13 +1718,13 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embe
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
}
...
...
@@ -1733,16 +1733,16 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid,
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
k
.
z
=
rotary_embedding_transform
(
k
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
,
rotary_sin
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
k
.
w
=
rotary_embedding_transform
(
k
.
w
,
coef3
);
}
...
...
csrc/ft_attention/ft_attention.cpp
View file @
3a9bfd07
...
...
@@ -160,16 +160,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
if
(
rotary_cos_
.
has_value
())
{
auto
rotary_cos
=
rotary_cos_
.
value
();
CHECK_DEVICE
(
rotary_cos
);
int
rotary_seqlen
=
rotary_cos
.
size
(
0
);
rotary_embedding_dim
=
rotary_cos
.
size
(
1
)
*
2
;
CHECK_SHAPE
(
rotary_cos
,
rotary_seqlen
,
rotary_embedding_dim
/
2
);
rotary_embedding_dim
=
rotary_cos
.
size
(
0
)
*
2
;
CHECK_SHAPE
(
rotary_cos
,
rotary_embedding_dim
/
2
);
CHECK_CONTIGUOUS
(
rotary_cos
);
TORCH_CHECK
(
rotary_cos
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
rotary_sin_
.
has_value
());
auto
rotary_sin
=
rotary_sin_
.
value
();
CHECK_DEVICE
(
rotary_sin
);
CHECK_SHAPE
(
rotary_cos
,
rotary_seqlen
,
rotary_embedding_dim
/
2
);
CHECK_SHAPE
(
rotary_cos
,
rotary_embedding_dim
/
2
);
CHECK_CONTIGUOUS
(
rotary_sin
);
TORCH_CHECK
(
rotary_sin
.
scalar_type
()
==
input_type
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment