gaoqiong / flash-attention · Commit 48bc6eac

[Gen] Add rotary base as an argument to FT attention kernel

Authored May 30, 2023 by Tri Dao
Parent: 7c766b1b

Showing 8 changed files with 84 additions and 72 deletions:

  csrc/ft_attention/decoder_masked_multihead_attention.h              +1   -0
  csrc/ft_attention/decoder_masked_multihead_attention_template.hpp   +4   -4
  csrc/ft_attention/decoder_masked_multihead_attention_utils.h        +54  -54
  csrc/ft_attention/ft_attention.cpp                                   +5   -2
  flash_attn/layers/rotary.py                                          +2   -1
  flash_attn/models/gpt.py                                             +3   -1
  flash_attn/modules/mha.py                                           +13   -9
  tests/models/test_gpt_generation.py                                  +2   -1
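In short, the commit turns the rotary-embedding base, previously hard-coded as 10000.0f in the FasterTransformer-style decoding kernel, into a parameter that is threaded from the Python configuration down to the CUDA code. For orientation (this helper is not from the repository, only the formula is taken from the kernel below), the base controls the per-channel rotation frequency; a larger base yields a smaller rotation angle for the same position:

    import math

    def rotary_coefficient(zid, rot_embed_dim, t_step, base=10000.0):
        # Same formula as rotary_embedding_coefficient in the kernel:
        # angle = t_step / base ** (zid / rot_embed_dim); returns (cos, sin).
        angle = t_step / math.pow(base, zid / float(rot_embed_dim))
        return math.cos(angle), math.sin(angle)

    print(rotary_coefficient(8, 64, t_step=100))              # default base 10000
    print(rotary_coefficient(8, 64, t_step=100, base=24000))  # larger base -> smaller angle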
csrc/ft_attention/decoder_masked_multihead_attention.h  (+1, -0)

@@ -84,6 +84,7 @@ struct Multihead_attention_params_base {
     // The per-head latent space reserved for rotary embeddings.
     int   rotary_embedding_dim = 0;
     bool  neox_rotary_style    = false;
+    float rotary_base          = 0.0f;
     // The maximum length of input sentences.
     int max_input_length = 0;
     // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp  (+4, -4)

@@ -1061,10 +1061,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
     if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
         if (handle_kv) {
-            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
         else {
-            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
     }
     else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {

@@ -1099,13 +1099,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);

             mmha::apply_rotary_embedding(
-                q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len);
+                q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);

             mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
         }
         else {
             mmha::apply_rotary_embedding(
-                q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength);
+                q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
         }

         mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
     }
csrc/ft_attention/decoder_masked_multihead_attention_utils.h  (+54, -54)

@@ -1272,9 +1272,9 @@ inline __device__ void zero(T& dst)
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step)
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
 {
-    const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim);
+    const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
     return {cos(inv_freq), sin(inv_freq)};
 }

@@ -1302,49 +1302,49 @@ inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162
 }
 #endif

-inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     return;
 }

-inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     return;
 }

-inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }

-inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }

-inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;

@@ -1352,166 +1352,166 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
     Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
     k_.x = rotary_embedding_transform(k_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
     k_.y = rotary_embedding_transform(k_.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }

-inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }

-inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }

-inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }

 #ifdef ENABLE_BF16
-inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }

 inline __device__ void
-apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step)
+apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }

-inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }

-inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }
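For reference, a minimal Python sketch (not part of the repository) of what the two-element overloads above compute once the new base argument is in place; each channel pair (2*tid, 2*tid+1) of q and k is rotated by an angle that depends on the timestep and the base:

    import math

    def rotary_embedding_transform(pair, coef):
        # 2-D rotation of one channel pair, as in the CUDA rotary_embedding_transform.
        c, s = coef
        x, y = pair
        return (c * x - s * y, s * x + c * y)

    def apply_rotary_embedding(q_pair, k_pair, tid, rot_embed_dim, t_step, base=10000.0):
        # Mirrors the float2 (q, k) overload: pairs past rot_embed_dim are left untouched.
        if 2 * tid >= rot_embed_dim:
            return q_pair, k_pair
        angle = t_step / math.pow(base, (2 * tid) / float(rot_embed_dim))
        coef = (math.cos(angle), math.sin(angle))
        return rotary_embedding_transform(q_pair, coef), rotary_embedding_transform(k_pair, coef)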
csrc/ft_attention/ft_attention.cpp  (+5, -2)

@@ -54,6 +54,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const size_t headdim,
                 const int timestep,
                 const int rotary_embedding_dim,
+                const float rotary_base,
                 const bool neox_rotary_style,
                 const int qkv_batch_stride,
                 T *q_ptr,

@@ -82,6 +83,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.num_heads = nheads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
+    params.rotary_base = rotary_base;
     params.neox_rotary_style = neox_rotary_style;
     params.timestep = timestep;
     params.inv_sqrt_dh = 1.f / sqrt(float(headdim));

@@ -107,6 +109,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      c10::optional<const torch::Tensor> length_per_sample_,
                                      const int timestep,
                                      const int rotary_embedding_dim = 0,
+                                     const float rotary_base = 10000.0f,
                                      const bool neox_rotary_style = true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
     int batch_size = v_cache.size(0);

@@ -144,7 +147,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
        using DataType = typename SATypeConverter<scalar_t>::Type;
        Masked_multihead_attention_params<DataType> params;
        set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                  rotary_embedding_dim, neox_rotary_style, q.stride(0),
+                  rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                   reinterpret_cast<DataType*>(q.data_ptr()),
                   reinterpret_cast<DataType*>(k.data_ptr()),
                   reinterpret_cast<DataType*>(v.data_ptr()),

@@ -163,5 +166,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("single_query_attention", &single_query_attention, "Attention with a single query",
           py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
           py::arg("length_per_sample_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
-          py::arg("neox_rotary_style")=true);
+          py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
 }
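From Python, the rebuilt extension now accepts the base between rotary_embedding_dim and neox_rotary_style. A hedged call sketch (the wrapper and its default values are illustrative; the caller is expected to prepare CUDA tensors the same way flash_attn/modules/mha.py does below):

    import ft_attention  # PyTorch extension built from csrc/ft_attention

    def decode_step(q, k, v, k_cache, v_cache, lengths_per_sample, timestep,
                    rotary_dim=0, rotary_base=10000.0, neox=True):
        # Thin wrapper showing the updated argument order and the pybind keyword names.
        return ft_attention.single_query_attention(
            q, k, v, k_cache, v_cache, lengths_per_sample, timestep,
            rotary_embedding_dim=rotary_dim,
            rotary_base=rotary_base,       # new in this commit, default 10000.0
            neox_rotary_style=neox,
        )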
flash_attn/layers/rotary.py  (+2, -1)

@@ -169,12 +169,13 @@ class RotaryEmbedding(torch.nn.Module):
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """

-    def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None):
+    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, device=None):
         """
             interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                 of 1st half and 2nd half (GPT-NeoX style).
         """
         super().__init__()
+        self.base = float(base)
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
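Because the module now retains self.base, downstream MHA code can read the same value and forward it to the fused kernel. A hedged usage sketch (dim and base values are arbitrary; reading inv_freq assumes it is registered as a buffer elsewhere in the class, which this hunk does not show):

    from flash_attn.layers.rotary import RotaryEmbedding

    rotary = RotaryEmbedding(dim=64, base=24000.0, interleaved=False)
    print(rotary.base)        # 24000.0, later forwarded to ft_attention.single_query_attention
    # inv_freq[i] = 1 / base ** (2*i / dim), so a larger base compresses the frequency range.
    print(rotary.inv_freq[:4])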
flash_attn/models/gpt.py  (+3, -1)

@@ -75,6 +75,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
     out_proj_bias = getattr(config, 'out_proj_bias', True)
     rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
+    rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
     rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
     rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
     use_flash_attn = getattr(config, 'use_flash_attn', False)

@@ -91,7 +92,8 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
                             qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                             dropout=config.attn_pdrop, softmax_scale=softmax_scale, causal=True,
                             layer_idx=layer_idx,
-                            rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
+                            rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
+                            rotary_emb_scale_base=rotary_emb_scale_base,
                             rotary_emb_interleaved=rotary_emb_interleaved,
                             use_flash_attn=use_flash_attn,
                             **serial_kwargs, **parallel_kwargs, **factory_kwargs)
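So the rotary dimension and base can now both be chosen per model through config attributes. A hedged sketch of the relevant settings (values mirror the updated test below; GPT2Config comes from transformers, as in the test file):

    from transformers import GPT2Config

    config = GPT2Config.from_pretrained('gpt2')
    config.n_positions = 0               # disable learned absolute positions
    config.rotary_emb_fraction = 0.5     # rotary_emb_dim = int(0.5 * head_dim) in create_mixer_cls
    config.rotary_emb_base = 24000       # picked up via getattr(config, 'rotary_emb_base', 10000.0)
    # create_mixer_cls then builds MHA(..., rotary_emb_dim=rotary_emb_dim,
    #                                  rotary_emb_base=rotary_emb_base, ...)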
flash_attn/modules/mha.py  (+13, -9)

@@ -350,9 +350,9 @@ class MHA(nn.Module):
     def __init__(self, embed_dim, num_heads, cross_attn=False,
                  qkv_proj_bias=True, out_proj_bias=True, dropout=0.0,
                  softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 fused_bias_fc=False, use_flash_attn=False,
-                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
+                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
         """
             return_residual: whether to return the input x along with the output. This is for
                 performance reason: for post-norm architecture, returning the input allows us

@@ -377,7 +377,8 @@ class MHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)

         if fused_bias_fc and FusedDense is None:

@@ -511,11 +512,12 @@ class MHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
                 lengths_per_sample, inference_params.sequence_len_offset,
-                self.rotary_emb_dim,
+                self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )

@@ -555,8 +557,8 @@ class ParallelMHA(nn.Module):
     def __init__(self, embed_dim, num_heads, process_group,
                  qkv_proj_bias=True, out_proj_bias=True, dropout=0.0,
                  softmax_scale=None, causal=False, layer_idx=None,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 use_flash_attn=False, checkpointing=False,
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
                  sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()

@@ -573,7 +575,8 @@ class ParallelMHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)

         if ColumnParallelLinear is None or RowParallelLinear is None:

@@ -631,11 +634,12 @@ class ParallelMHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
-                lengths_per_sample, inference_params.sequence_len_offset, self.rotary_emb_dim,
+                lengths_per_sample,
+                inference_params.sequence_len_offset, self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
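A hedged construction sketch for the updated MHA signature; argument values are illustrative, and the non-rotary keywords keep their defaults from the signature above:

    from flash_attn.modules.mha import MHA

    mha = MHA(embed_dim=1024, num_heads=16, causal=True,
              rotary_emb_dim=32,         # rotary applied to the first 32 of 64 head dims
              rotary_emb_base=24000.0)   # new keyword, forwarded as RotaryEmbedding(base=...)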
tests/models/test_gpt_generation.py  (+2, -1)

@@ -36,7 +36,8 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     config = GPT2Config.from_pretrained(model_name)
     if rotary:
         config.n_positions = 0
-        config.rotary_emb_dim = 64
+        config.rotary_emb_fraction = 0.5
+        config.rotary_emb_base = 24000
     config.residual_in_fp32 = True
     if optimized:
         config.use_flash_attn = True