Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
62e98144
Commit
62e98144
authored
Jul 02, 2023
by
Tri Dao
Browse files
[Rotary] Make sure frequency calculation is in fp32
parent
9818f85f
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
393 additions
and
42 deletions
+393
-42
csrc/ft_attention/decoder_masked_multihead_attention.cu
csrc/ft_attention/decoder_masked_multihead_attention.cu
+1
-1
csrc/ft_attention/decoder_masked_multihead_attention.h
csrc/ft_attention/decoder_masked_multihead_attention.h
+6
-0
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
...attention/decoder_masked_multihead_attention_template.hpp
+26
-7
csrc/ft_attention/decoder_masked_multihead_attention_utils.h
csrc/ft_attention/decoder_masked_multihead_attention_utils.h
+236
-5
csrc/ft_attention/ft_attention.cpp
csrc/ft_attention/ft_attention.cpp
+59
-4
flash_attn/layers/rotary.py
flash_attn/layers/rotary.py
+48
-19
flash_attn/modules/mha.py
flash_attn/modules/mha.py
+14
-4
tests/models/test_llama.py
tests/models/test_llama.py
+3
-2
No files found.
csrc/ft_attention/decoder_masked_multihead_attention.cu
View file @
62e98144
...
@@ -34,7 +34,7 @@
...
@@ -34,7 +34,7 @@
if (smem_sz >= 48 * 1024) { \
if (smem_sz >= 48 * 1024) { \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
} \
} \
dim3 grid(params.n
um_heads, params.batch_size);
\
dim3 grid(params.n
nz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size);
\
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
csrc/ft_attention/decoder_masked_multihead_attention.h
View file @
62e98144
...
@@ -113,6 +113,12 @@ struct Multihead_attention_params_base {
...
@@ -113,6 +113,12 @@ struct Multihead_attention_params_base {
const
float
*
qkv_scale_out
=
nullptr
;
const
float
*
qkv_scale_out
=
nullptr
;
const
float
*
attention_out_scale
=
nullptr
;
const
float
*
attention_out_scale
=
nullptr
;
int
int8_mode
=
0
;
int
int8_mode
=
0
;
const
T
*
rotary_cos
=
nullptr
;
const
T
*
rotary_sin
=
nullptr
;
const
int
*
nnz_head_idx
=
nullptr
;
int
nnz_heads
=
0
;
};
};
template
<
typename
T
,
bool
CROSS_ATTENTION
>
template
<
typename
T
,
bool
CROSS_ATTENTION
>
...
...
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
View file @
62e98144
...
@@ -941,7 +941,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
...
@@ -941,7 +941,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
// The "beam-aware" batch idx
// The "beam-aware" batch idx
const
int
bbi
=
bi
/
params
.
beam_width
;
const
int
bbi
=
bi
/
params
.
beam_width
;
// The head.
// The head.
const
int
hi
=
blockIdx
.
x
;
// const int hi = blockIdx.x;
const
int
hi
=
params
.
nnz_head_idx
==
nullptr
?
blockIdx
.
x
:
params
.
nnz_head_idx
[
blockIdx
.
x
];
// Combine the batch and the head indices.
// Combine the batch and the head indices.
const
int
bhi
=
bi
*
params
.
num_heads
+
hi
;
const
int
bhi
=
bi
*
params
.
num_heads
+
hi
;
// Combine the "beam-aware" batch idx and the head indices.
// Combine the "beam-aware" batch idx and the head indices.
...
@@ -1061,10 +1062,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
...
@@ -1061,10 +1062,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
const
int
padd_len
=
(
params
.
total_padding_tokens
==
nullptr
)
?
0
:
params
.
total_padding_tokens
[
bi
];
const
int
padd_len
=
(
params
.
total_padding_tokens
==
nullptr
)
?
0
:
params
.
total_padding_tokens
[
bi
];
if
(
params
.
rotary_embedding_dim
>
0
&&
!
params
.
neox_rotary_style
)
{
if
(
params
.
rotary_embedding_dim
>
0
&&
!
params
.
neox_rotary_style
)
{
if
(
handle_kv
)
{
if
(
handle_kv
)
{
if
(
params
.
rotary_cos
==
nullptr
)
{
apply_rotary_embedding
(
q
,
k
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
apply_rotary_embedding
(
q
,
k
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
}
else
{
apply_rotary_embedding
(
q
,
k
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_cos
,
params
.
rotary_sin
);
}
}
}
else
{
else
{
if
(
params
.
rotary_cos
==
nullptr
)
{
apply_rotary_embedding
(
q
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
apply_rotary_embedding
(
q
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
}
else
{
apply_rotary_embedding
(
q
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_cos
,
params
.
rotary_sin
);
}
}
}
}
}
else
if
(
params
.
rotary_embedding_dim
>
0
&&
params
.
neox_rotary_style
)
{
else
if
(
params
.
rotary_embedding_dim
>
0
&&
params
.
neox_rotary_style
)
{
...
@@ -1098,14 +1107,24 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
...
@@ -1098,14 +1107,24 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
if
(
handle_kv
)
{
if
(
handle_kv
)
{
mmha
::
vec_from_smem_transpose
(
k
,
k_smem
,
transpose_idx
,
smem_pitch
);
mmha
::
vec_from_smem_transpose
(
k
,
k_smem
,
transpose_idx
,
smem_pitch
);
if
(
params
.
rotary_cos
==
nullptr
)
{
mmha
::
apply_rotary_embedding
(
mmha
::
apply_rotary_embedding
(
q
,
k
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
q
,
k
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
}
else
{
mmha
::
apply_rotary_embedding
(
q
,
k
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_cos
,
params
.
rotary_sin
);
}
mmha
::
write_smem_transpose
(
k
,
k_smem
,
transpose_idx
,
smem_pitch
);
mmha
::
write_smem_transpose
(
k
,
k_smem
,
transpose_idx
,
smem_pitch
);
}
}
else
{
else
{
if
(
params
.
rotary_cos
==
nullptr
)
{
mmha
::
apply_rotary_embedding
(
mmha
::
apply_rotary_embedding
(
q
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
,
params
.
rotary_base
);
q
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
,
params
.
rotary_base
);
}
else
{
mmha
::
apply_rotary_embedding
(
q
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
,
params
.
rotary_cos
,
params
.
rotary_sin
);
}
}
}
mmha
::
write_smem_transpose
(
q
,
q_smem
,
transpose_idx
,
smem_pitch
);
mmha
::
write_smem_transpose
(
q
,
q_smem
,
transpose_idx
,
smem_pitch
);
}
}
...
...
csrc/ft_attention/decoder_masked_multihead_attention_utils.h
View file @
62e98144
...
@@ -1272,10 +1272,10 @@ inline __device__ void zero(T& dst)
...
@@ -1272,10 +1272,10 @@ inline __device__ void zero(T& dst)
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
rotary_embedding_coefficient
(
const
int
zid
,
const
int
rot_embed_dim
,
const
floa
t
t_step
,
const
float
base
)
inline
__device__
float2
rotary_embedding_coefficient
(
const
int
zid
,
const
int
rot_embed_dim
,
const
in
t
t_step
,
const
float
base
)
{
{
const
float
inv_freq
=
t_step
/
pow
(
base
,
zid
/
(
float
)
rot_embed_dim
);
const
float
pos_idx_
inv_freq
=
t_step
/
pow
(
base
,
zid
/
(
float
)
rot_embed_dim
);
return
{
cos
(
inv_freq
),
sin
(
inv_freq
)};
return
{
cos
(
pos_idx_
inv_freq
),
sin
(
pos_idx_
inv_freq
)};
}
}
inline
__device__
float2
rotary_embedding_transform
(
const
float2
v
,
const
float2
coef
)
inline
__device__
float2
rotary_embedding_transform
(
const
float2
v
,
const
float2
coef
)
...
@@ -1447,8 +1447,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int ro
...
@@ -1447,8 +1447,7 @@ inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int ro
q
=
rotary_embedding_transform
(
q
,
coef
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
}
inline
__device__
void
inline
__device__
void
apply_rotary_embedding
(
__nv_bfloat162
&
q
,
__nv_bfloat162
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
apply_rotary_embedding
(
__nv_bfloat162
&
q
,
__nv_bfloat162
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
return
;
...
@@ -1517,6 +1516,238 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid,
...
@@ -1517,6 +1516,238 @@ inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid,
}
}
#endif // ENABLE_BF16
#endif // ENABLE_BF16
template
<
typename
T
>
inline
__device__
float2
rotary_embedding_coefficient
(
const
int
zid
,
const
int
t_step
,
const
T
*
rotary_cos
,
const
T
*
rotary_sin
)
{
// zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
// rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
return
{
float
(
rotary_cos
[
zid
/
2
]),
float
(
rotary_sin
[
zid
/
2
])};
}
// fp16 is special because we use uint16_t for reading the data, for backward compatibility.
template
<
>
inline
__device__
float2
rotary_embedding_coefficient
<
uint16_t
>
(
const
int
zid
,
const
int
t_step
,
const
uint16_t
*
rotary_cos
,
const
uint16_t
*
rotary_sin
)
{
// zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
// rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
return
{
float
(
reinterpret_cast
<
const
__half
*>
(
rotary_cos
)[
zid
/
2
]),
float
(
reinterpret_cast
<
const
__half
*>
(
rotary_sin
)[
zid
/
2
])};
}
inline
__device__
void
apply_rotary_embedding
(
float
&
q
,
int
zid
,
int
rot_embed_dim
,
int
t_step
,
const
float
*
rotary_cos
,
const
float
*
rotary_sin
)
{
return
;
}
inline
__device__
void
apply_rotary_embedding
(
float
&
q
,
float
&
k
,
int
zid
,
int
rot_embed_dim
,
int
t_step
,
const
float
*
rotary_cos
,
const
float
*
rotary_sin
)
{
return
;
}
inline
__device__
void
apply_rotary_embedding
(
float2
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
*
rotary_cos
,
const
float
*
rotary_sin
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
float2
&
q
,
float2
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
*
rotary_cos
,
const
float
*
rotary_sin
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
float4
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
*
rotary_cos
,
const
float
*
rotary_sin
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
Float4_
&
q_
=
*
reinterpret_cast
<
Float4_
*>
(
&
q
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q_
.
x
=
rotary_embedding_transform
(
q_
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q_
.
y
=
rotary_embedding_transform
(
q_
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
float4
&
q
,
float4
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
*
rotary_cos
,
const
float
*
rotary_sin
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
Float4_
&
q_
=
*
reinterpret_cast
<
Float4_
*>
(
&
q
);
Float4_
&
k_
=
*
reinterpret_cast
<
Float4_
*>
(
&
k
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q_
.
x
=
rotary_embedding_transform
(
q_
.
x
,
coef0
);
k_
.
x
=
rotary_embedding_transform
(
k_
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q_
.
y
=
rotary_embedding_transform
(
q_
.
y
,
coef1
);
k_
.
y
=
rotary_embedding_transform
(
k_
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
uint32_t
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
uint16_t
*
rotary_cos
,
const
uint16_t
*
rotary_sin
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
uint32_t
&
q
,
uint32_t
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
uint16_t
*
rotary_cos
,
const
uint16_t
*
rotary_sin
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
uint2
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
uint16_t
*
rotary_cos
,
const
uint16_t
*
rotary_sin
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
uint2
&
q
,
uint2
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
uint16_t
*
rotary_cos
,
const
uint16_t
*
rotary_sin
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
uint4
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
uint16_t
*
rotary_cos
,
const
uint16_t
*
rotary_sin
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
}
inline
__device__
void
apply_rotary_embedding
(
uint4
&
q
,
uint4
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
uint16_t
*
rotary_cos
,
const
uint16_t
*
rotary_sin
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
k
.
z
=
rotary_embedding_transform
(
k
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
k
.
w
=
rotary_embedding_transform
(
k
.
w
,
coef3
);
}
#ifdef ENABLE_BF16
inline
__device__
void
apply_rotary_embedding
(
__nv_bfloat162
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
__nv_bfloat16
*
rotary_cos
,
const
__nv_bfloat16
*
rotary_sin
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
__nv_bfloat162
&
q
,
__nv_bfloat162
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
__nv_bfloat16
*
rotary_cos
,
const
__nv_bfloat16
*
rotary_sin
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_4_t
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
__nv_bfloat16
*
rotary_cos
,
const
__nv_bfloat16
*
rotary_sin
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_4_t
&
q
,
bf16_4_t
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
__nv_bfloat16
*
rotary_cos
,
const
__nv_bfloat16
*
rotary_sin
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_8_t
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
__nv_bfloat16
*
rotary_cos
,
const
__nv_bfloat16
*
rotary_sin
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_8_t
&
q
,
bf16_8_t
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
__nv_bfloat16
*
rotary_cos
,
const
__nv_bfloat16
*
rotary_sin
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
k
.
z
=
rotary_embedding_transform
(
k
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
t_step
,
rotary_cos
+
t_step
*
rot_embed_dim
/
2
,
rotary_sin
+
t_step
*
rot_embed_dim
/
2
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
k
.
w
=
rotary_embedding_transform
(
k
.
w
,
coef3
);
}
#endif // ENABLE_BF16
template
<
typename
Vec_T
,
typename
T
>
template
<
typename
Vec_T
,
typename
T
>
__device__
__inline__
void
vec_from_smem_transpose
(
Vec_T
&
vec
,
T
*
smem
,
int
transpose_idx
,
int
smem_pitch
);
__device__
__inline__
void
vec_from_smem_transpose
(
Vec_T
&
vec
,
T
*
smem
,
int
transpose_idx
,
int
smem_pitch
);
...
...
csrc/ft_attention/ft_attention.cpp
View file @
62e98144
...
@@ -57,13 +57,17 @@ void set_params(Masked_multihead_attention_params<T> ¶ms,
...
@@ -57,13 +57,17 @@ void set_params(Masked_multihead_attention_params<T> ¶ms,
const
float
rotary_base
,
const
float
rotary_base
,
const
bool
neox_rotary_style
,
const
bool
neox_rotary_style
,
const
int
qkv_batch_stride
,
const
int
qkv_batch_stride
,
const
int
nnz_heads
,
T
*
q_ptr
,
T
*
q_ptr
,
T
*
k_ptr
,
T
*
k_ptr
,
T
*
v_ptr
,
T
*
v_ptr
,
T
*
k_cache_ptr
,
T
*
k_cache_ptr
,
T
*
v_cache_ptr
,
T
*
v_cache_ptr
,
int
*
length_per_sample
,
int
*
length_per_sample
,
T
*
out_ptr
)
{
T
*
rotary_cos
,
T
*
rotary_sin
,
T
*
out_ptr
,
int
*
nnz_head_idx
)
{
// Reset the parameters
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
));
memset
(
&
params
,
0
,
sizeof
(
params
));
params
.
q
=
q_ptr
;
params
.
q
=
q_ptr
;
...
@@ -81,6 +85,7 @@ void set_params(Masked_multihead_attention_params<T> ¶ms,
...
@@ -81,6 +85,7 @@ void set_params(Masked_multihead_attention_params<T> ¶ms,
params
.
beam_width
=
1
;
params
.
beam_width
=
1
;
params
.
memory_max_len
=
memory_max_seqlen
;
params
.
memory_max_len
=
memory_max_seqlen
;
params
.
num_heads
=
nheads
;
params
.
num_heads
=
nheads
;
params
.
nnz_heads
=
nnz_heads
;
params
.
hidden_size_per_head
=
headdim
;
params
.
hidden_size_per_head
=
headdim
;
params
.
rotary_embedding_dim
=
rotary_embedding_dim
;
params
.
rotary_embedding_dim
=
rotary_embedding_dim
;
params
.
rotary_base
=
rotary_base
;
params
.
rotary_base
=
rotary_base
;
...
@@ -99,6 +104,9 @@ void set_params(Masked_multihead_attention_params<T> ¶ms,
...
@@ -99,6 +104,9 @@ void set_params(Masked_multihead_attention_params<T> ¶ms,
params
.
finished
=
nullptr
;
params
.
finished
=
nullptr
;
params
.
memory_length_per_sample
=
nullptr
;
params
.
memory_length_per_sample
=
nullptr
;
params
.
length_per_sample
=
length_per_sample
;
params
.
length_per_sample
=
length_per_sample
;
params
.
rotary_cos
=
rotary_cos
;
params
.
rotary_sin
=
rotary_sin
;
params
.
nnz_head_idx
=
nnz_head_idx
;
}
}
torch
::
Tensor
single_query_attention
(
const
torch
::
Tensor
q
,
torch
::
Tensor
single_query_attention
(
const
torch
::
Tensor
q
,
...
@@ -107,8 +115,11 @@ torch::Tensor single_query_attention(const torch::Tensor q,
...
@@ -107,8 +115,11 @@ torch::Tensor single_query_attention(const torch::Tensor q,
torch
::
Tensor
k_cache
,
torch
::
Tensor
k_cache
,
torch
::
Tensor
v_cache
,
torch
::
Tensor
v_cache
,
c10
::
optional
<
const
torch
::
Tensor
>
length_per_sample_
,
c10
::
optional
<
const
torch
::
Tensor
>
length_per_sample_
,
c10
::
optional
<
const
torch
::
Tensor
>
rotary_cos_
,
c10
::
optional
<
const
torch
::
Tensor
>
rotary_sin_
,
c10
::
optional
<
const
torch
::
Tensor
>
nnz_head_idx_
,
const
int
timestep
,
const
int
timestep
,
const
int
rotary_embedding_dim
=
0
,
int
rotary_embedding_dim
=
0
,
const
float
rotary_base
=
10000.0
f
,
const
float
rotary_base
=
10000.0
f
,
const
bool
neox_rotary_style
=
true
)
{
const
bool
neox_rotary_style
=
true
)
{
CHECK_DEVICE
(
q
);
CHECK_DEVICE
(
k
);
CHECK_DEVICE
(
v
);
CHECK_DEVICE
(
k_cache
);
CHECK_DEVICE
(
v_cache
);
CHECK_DEVICE
(
q
);
CHECK_DEVICE
(
k
);
CHECK_DEVICE
(
v
);
CHECK_DEVICE
(
k_cache
);
CHECK_DEVICE
(
v_cache
);
...
@@ -116,6 +127,9 @@ torch::Tensor single_query_attention(const torch::Tensor q,
...
@@ -116,6 +127,9 @@ torch::Tensor single_query_attention(const torch::Tensor q,
int
nheads
=
v_cache
.
size
(
1
);
int
nheads
=
v_cache
.
size
(
1
);
int
memory_max_seqlen
=
v_cache
.
size
(
2
);
int
memory_max_seqlen
=
v_cache
.
size
(
2
);
int
headdim
=
v_cache
.
size
(
3
);
int
headdim
=
v_cache
.
size
(
3
);
auto
input_type
=
q
.
scalar_type
();
TORCH_CHECK
(
input_type
==
at
::
ScalarType
::
Float
||
input_type
==
at
::
ScalarType
::
Half
||
input_type
==
at
::
ScalarType
::
BFloat16
);
CHECK_SHAPE
(
q
,
batch_size
,
nheads
,
headdim
);
CHECK_SHAPE
(
q
,
batch_size
,
nheads
,
headdim
);
CHECK_SHAPE
(
k
,
batch_size
,
nheads
,
headdim
);
CHECK_SHAPE
(
k
,
batch_size
,
nheads
,
headdim
);
CHECK_SHAPE
(
v
,
batch_size
,
nheads
,
headdim
);
CHECK_SHAPE
(
v
,
batch_size
,
nheads
,
headdim
);
...
@@ -129,6 +143,12 @@ torch::Tensor single_query_attention(const torch::Tensor q,
...
@@ -129,6 +143,12 @@ torch::Tensor single_query_attention(const torch::Tensor q,
TORCH_CHECK
(
q
.
stride
(
0
)
==
k
.
stride
(
0
)
&&
q
.
stride
(
0
)
==
v
.
stride
(
0
));
TORCH_CHECK
(
q
.
stride
(
0
)
==
k
.
stride
(
0
)
&&
q
.
stride
(
0
)
==
v
.
stride
(
0
));
CHECK_CONTIGUOUS
(
v_cache
);
CHECK_CONTIGUOUS
(
k_cache
);
CHECK_CONTIGUOUS
(
v_cache
);
CHECK_CONTIGUOUS
(
k_cache
);
TORCH_CHECK
(
q
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
k
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
v
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
k_cache
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
v_cache
.
scalar_type
()
==
input_type
);
if
(
length_per_sample_
.
has_value
())
{
if
(
length_per_sample_
.
has_value
())
{
auto
length_per_sample
=
length_per_sample_
.
value
();
auto
length_per_sample
=
length_per_sample_
.
value
();
CHECK_DEVICE
(
length_per_sample
);
CHECK_DEVICE
(
length_per_sample
);
...
@@ -137,6 +157,32 @@ torch::Tensor single_query_attention(const torch::Tensor q,
...
@@ -137,6 +157,32 @@ torch::Tensor single_query_attention(const torch::Tensor q,
TORCH_CHECK
(
length_per_sample
.
dtype
()
==
torch
::
kInt32
);
TORCH_CHECK
(
length_per_sample
.
dtype
()
==
torch
::
kInt32
);
}
}
if
(
rotary_cos_
.
has_value
())
{
auto
rotary_cos
=
rotary_cos_
.
value
();
CHECK_DEVICE
(
rotary_cos
);
int
rotary_seqlen
=
rotary_cos
.
size
(
0
);
rotary_embedding_dim
=
rotary_cos
.
size
(
1
)
*
2
;
CHECK_SHAPE
(
rotary_cos
,
rotary_seqlen
,
rotary_embedding_dim
/
2
);
CHECK_CONTIGUOUS
(
rotary_cos
);
TORCH_CHECK
(
rotary_cos
.
scalar_type
()
==
input_type
);
TORCH_CHECK
(
rotary_sin_
.
has_value
());
auto
rotary_sin
=
rotary_sin_
.
value
();
CHECK_DEVICE
(
rotary_sin
);
CHECK_SHAPE
(
rotary_cos
,
rotary_seqlen
,
rotary_embedding_dim
/
2
);
CHECK_CONTIGUOUS
(
rotary_sin
);
TORCH_CHECK
(
rotary_sin
.
scalar_type
()
==
input_type
);
}
if
(
nnz_head_idx_
.
has_value
())
{
auto
nnz_head_idx
=
nnz_head_idx_
.
value
();
CHECK_DEVICE
(
nnz_head_idx
);
int
nnz_heads
=
nnz_head_idx
.
size
(
0
);
CHECK_SHAPE
(
nnz_head_idx
,
nnz_heads
);
CHECK_CONTIGUOUS
(
nnz_head_idx
);
TORCH_CHECK
(
nnz_head_idx
.
dtype
()
==
torch
::
kInt32
);
}
// Otherwise the kernel will be launched from cuda:0 device
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
// Cast to char to avoid compiler warning about narrowing
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
q
.
get_device
()};
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
q
.
get_device
()};
...
@@ -148,6 +194,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
...
@@ -148,6 +194,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
Masked_multihead_attention_params
<
DataType
>
params
;
Masked_multihead_attention_params
<
DataType
>
params
;
set_params
(
params
,
batch_size
,
nheads
,
memory_max_seqlen
,
headdim
,
timestep
,
set_params
(
params
,
batch_size
,
nheads
,
memory_max_seqlen
,
headdim
,
timestep
,
rotary_embedding_dim
,
rotary_base
,
neox_rotary_style
,
q
.
stride
(
0
),
rotary_embedding_dim
,
rotary_base
,
neox_rotary_style
,
q
.
stride
(
0
),
nnz_head_idx_
.
has_value
()
?
nnz_head_idx_
.
value
().
size
(
0
)
:
0
,
reinterpret_cast
<
DataType
*>
(
q
.
data_ptr
()),
reinterpret_cast
<
DataType
*>
(
q
.
data_ptr
()),
reinterpret_cast
<
DataType
*>
(
k
.
data_ptr
()),
reinterpret_cast
<
DataType
*>
(
k
.
data_ptr
()),
reinterpret_cast
<
DataType
*>
(
v
.
data_ptr
()),
reinterpret_cast
<
DataType
*>
(
v
.
data_ptr
()),
...
@@ -155,7 +202,13 @@ torch::Tensor single_query_attention(const torch::Tensor q,
...
@@ -155,7 +202,13 @@ torch::Tensor single_query_attention(const torch::Tensor q,
reinterpret_cast
<
DataType
*>
(
v_cache
.
data_ptr
()),
reinterpret_cast
<
DataType
*>
(
v_cache
.
data_ptr
()),
length_per_sample_
.
has_value
()
length_per_sample_
.
has_value
()
?
length_per_sample_
.
value
().
data_ptr
<
int
>
()
:
nullptr
,
?
length_per_sample_
.
value
().
data_ptr
<
int
>
()
:
nullptr
,
reinterpret_cast
<
DataType
*>
(
out
.
data_ptr
()));
rotary_cos_
.
has_value
()
?
reinterpret_cast
<
DataType
*>
(
rotary_cos_
.
value
().
data_ptr
())
:
nullptr
,
rotary_sin_
.
has_value
()
?
reinterpret_cast
<
DataType
*>
(
rotary_sin_
.
value
().
data_ptr
())
:
nullptr
,
reinterpret_cast
<
DataType
*>
(
out
.
data_ptr
()),
nnz_head_idx_
.
has_value
()
?
nnz_head_idx_
.
value
().
data_ptr
<
int
>
()
:
nullptr
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
masked_multihead_attention
(
params
,
stream
);
masked_multihead_attention
(
params
,
stream
);
});
});
...
@@ -165,6 +218,8 @@ torch::Tensor single_query_attention(const torch::Tensor q,
...
@@ -165,6 +218,8 @@ torch::Tensor single_query_attention(const torch::Tensor q,
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"single_query_attention"
,
&
single_query_attention
,
"Attention with a single query"
,
m
.
def
(
"single_query_attention"
,
&
single_query_attention
,
"Attention with a single query"
,
py
::
arg
(
"q"
),
py
::
arg
(
"k"
),
py
::
arg
(
"v"
),
py
::
arg
(
"k_cache"
),
py
::
arg
(
"v_cache"
),
py
::
arg
(
"q"
),
py
::
arg
(
"k"
),
py
::
arg
(
"v"
),
py
::
arg
(
"k_cache"
),
py
::
arg
(
"v_cache"
),
py
::
arg
(
"length_per_sample_"
),
py
::
arg
(
"timestep"
),
py
::
arg
(
"rotary_embedding_dim"
)
=
0
,
py
::
arg
(
"length_per_sample_"
),
py
::
arg
(
"rotary_cos_"
),
py
::
arg
(
"rotary_sin_"
),
py
::
arg
(
"nnz_head_idx_"
),
py
::
arg
(
"timestep"
),
py
::
arg
(
"rotary_embedding_dim"
)
=
0
,
py
::
arg
(
"rotary_base"
)
=
10000.0
f
,
py
::
arg
(
"neox_rotary_style"
)
=
true
);
py
::
arg
(
"rotary_base"
)
=
10000.0
f
,
py
::
arg
(
"neox_rotary_style"
)
=
true
);
}
}
flash_attn/layers/rotary.py
View file @
62e98144
...
@@ -169,16 +169,28 @@ class RotaryEmbedding(torch.nn.Module):
...
@@ -169,16 +169,28 @@ class RotaryEmbedding(torch.nn.Module):
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
"""
"""
def
__init__
(
self
,
dim
:
int
,
base
=
10000.0
,
interleaved
=
False
,
scale_base
=
None
,
device
=
None
):
def
__init__
(
self
,
dim
:
int
,
base
=
10000.0
,
interleaved
=
False
,
scale_base
=
None
,
pos_idx_in_fp32
=
True
,
device
=
None
):
"""
"""
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
of 1st half and 2nd half (GPT-NeoX style).
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
otherwise they might be in lower precision.
This option was added because previously (before 2023-07-02), when we construct
the position indices, we use the dtype of self.inv_freq. In most cases this would
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
self.inv_freq would be bf16, and the position indices are also in bf16.
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
embeddings for some positions will coincide.
To maintain compatibility with models previously trained in pure bf16,
we add this option.
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
dim
=
dim
self
.
base
=
float
(
base
)
self
.
base
=
float
(
base
)
self
.
pos_idx_in_fp32
=
pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
# Generate and save the inverse frequency buffer (non trainable)
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
dim
,
2
,
device
=
device
,
inv_freq
=
self
.
_compute_inv_freq
(
device
)
dtype
=
torch
.
float32
)
/
dim
))
self
.
register_buffer
(
"inv_freq"
,
inv_freq
)
self
.
register_buffer
(
"inv_freq"
,
inv_freq
)
self
.
interleaved
=
interleaved
self
.
interleaved
=
interleaved
self
.
scale_base
=
scale_base
self
.
scale_base
=
scale_base
...
@@ -192,31 +204,48 @@ class RotaryEmbedding(torch.nn.Module):
...
@@ -192,31 +204,48 @@ class RotaryEmbedding(torch.nn.Module):
self
.
_cos_k_cached
=
None
self
.
_cos_k_cached
=
None
self
.
_sin_k_cached
=
None
self
.
_sin_k_cached
=
None
def
_update_cos_sin_cache
(
self
,
x
,
seqlen_offset
=
0
):
def
_compute_inv_freq
(
self
,
device
=
None
):
"""x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
return
1.0
/
(
self
.
base
**
(
torch
.
arange
(
0
,
self
.
dim
,
2
,
device
=
device
,
"""
dtype
=
torch
.
float32
)
/
self
.
dim
))
seqlen
=
x
.
shape
[
1
]
+
seqlen_offset
def
_update_cos_sin_cache
(
self
,
seqlen
,
device
=
None
,
dtype
=
None
):
# Reset the tables if the sequence length has changed,
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
# or if we're on a new device (possibly due to tracing for instance)
if
(
seqlen
>
self
.
_seq_len_cached
or
self
.
_cos_cached
.
device
!=
x
.
device
if
(
seqlen
>
self
.
_seq_len_cached
or
self
.
_cos_cached
.
device
!=
device
or
self
.
_cos_cached
.
dtype
!=
x
.
dtype
):
or
self
.
_cos_cached
.
dtype
!=
dtype
):
self
.
_seq_len_cached
=
seqlen
self
.
_seq_len_cached
=
seqlen
t
=
torch
.
arange
(
seqlen
,
device
=
x
.
device
,
dtype
=
self
.
inv_freq
.
dtype
)
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# Don't do einsum, it converts fp32 to fp16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
if
self
.
pos_idx_in_fp32
:
t
=
torch
.
arange
(
seqlen
,
device
=
device
,
dtype
=
torch
.
float32
)
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
# will be large. Having it in bf16 will lose a lot of precision and cause the
# cos & sin output to change significantly.
# We want to recompute self.inv_freq if it was not loaded in fp32
if
self
.
inv_freq
.
dtype
!=
torch
.
float32
:
inv_freq
=
self
.
_compute_inv_freq
(
device
=
device
)
else
:
inv_freq
=
self
.
inv_freq
else
:
t
=
torch
.
arange
(
seqlen
,
device
=
device
,
dtype
=
self
.
inv_freq
.
dtype
)
inv_freq
=
self
.
inv_freq
# Don't do einsum, it converts fp32 to fp16 under AMP
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs
=
torch
.
outer
(
t
,
self
.
inv_freq
.
to
(
device
=
t
.
device
)
)
freqs
=
torch
.
outer
(
t
,
inv_freq
)
if
self
.
scale
is
None
:
if
self
.
scale
is
None
:
self
.
_cos_cached
=
torch
.
cos
(
freqs
).
to
(
x
.
dtype
)
self
.
_cos_cached
=
torch
.
cos
(
freqs
).
to
(
dtype
)
self
.
_sin_cached
=
torch
.
sin
(
freqs
).
to
(
x
.
dtype
)
self
.
_sin_cached
=
torch
.
sin
(
freqs
).
to
(
dtype
)
else
:
else
:
power
=
((
torch
.
arange
(
seqlen
,
dtype
=
self
.
scale
.
dtype
,
device
=
self
.
scale
.
device
)
power
=
((
torch
.
arange
(
seqlen
,
dtype
=
self
.
scale
.
dtype
,
device
=
self
.
scale
.
device
)
-
seqlen
//
2
)
/
self
.
scale_base
)
-
seqlen
//
2
)
/
self
.
scale_base
)
scale
=
self
.
scale
.
to
(
device
=
power
.
device
)
**
rearrange
(
power
,
's -> s 1'
)
scale
=
self
.
scale
.
to
(
device
=
power
.
device
)
**
rearrange
(
power
,
's -> s 1'
)
# We want the multiplication by scale to happen in fp32
# We want the multiplication by scale to happen in fp32
self
.
_cos_cached
=
(
torch
.
cos
(
freqs
)
*
scale
).
to
(
x
.
dtype
)
self
.
_cos_cached
=
(
torch
.
cos
(
freqs
)
*
scale
).
to
(
dtype
)
self
.
_sin_cached
=
(
torch
.
sin
(
freqs
)
*
scale
).
to
(
x
.
dtype
)
self
.
_sin_cached
=
(
torch
.
sin
(
freqs
)
*
scale
).
to
(
dtype
)
self
.
_cos_k_cached
=
(
torch
.
cos
(
freqs
)
/
scale
).
to
(
x
.
dtype
)
self
.
_cos_k_cached
=
(
torch
.
cos
(
freqs
)
/
scale
).
to
(
dtype
)
self
.
_sin_k_cached
=
(
torch
.
sin
(
freqs
)
/
scale
).
to
(
x
.
dtype
)
self
.
_sin_k_cached
=
(
torch
.
sin
(
freqs
)
/
scale
).
to
(
dtype
)
def
forward
(
self
,
qkv
:
torch
.
Tensor
,
seqlen_offset
:
int
=
0
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
forward
(
self
,
qkv
:
torch
.
Tensor
,
seqlen_offset
:
int
=
0
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
...
@@ -224,7 +253,7 @@ class RotaryEmbedding(torch.nn.Module):
...
@@ -224,7 +253,7 @@ class RotaryEmbedding(torch.nn.Module):
seqlen_offset: can be used in generation where the qkv being passed in is only the last
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.
token in the batch.
"""
"""
self
.
_update_cos_sin_cache
(
qkv
,
seqlen_offset
)
self
.
_update_cos_sin_cache
(
qkv
.
shape
[
1
]
+
seqlen_offset
,
device
=
qkv
.
device
,
dtype
=
qkv
.
dtype
)
if
self
.
scale
is
None
:
if
self
.
scale
is
None
:
return
apply_rotary_emb_qkv_
(
return
apply_rotary_emb_qkv_
(
qkv
,
self
.
_cos_cached
[
seqlen_offset
:],
self
.
_sin_cached
[
seqlen_offset
:],
qkv
,
self
.
_cos_cached
[
seqlen_offset
:],
self
.
_sin_cached
[
seqlen_offset
:],
...
...
flash_attn/modules/mha.py
View file @
62e98144
...
@@ -515,8 +515,13 @@ class MHA(nn.Module):
...
@@ -515,8 +515,13 @@ class MHA(nn.Module):
rotary_emb_base
=
self
.
rotary_emb
.
base
if
self
.
rotary_emb_dim
>
0
else
0
rotary_emb_base
=
self
.
rotary_emb
.
base
if
self
.
rotary_emb_dim
>
0
else
0
context
=
ft_attention
.
single_query_attention
(
context
=
ft_attention
.
single_query_attention
(
*
rearrange
(
qkv
,
'b 1 three h d -> b three h d'
).
unbind
(
dim
=
1
),
*
rearrange
(
qkv
,
'b 1 three h d -> b three h d'
).
unbind
(
dim
=
1
),
k_cache
[
batch_start
:
batch_end
],
v_cache
[
batch_start
:
batch_end
],
k_cache
[
batch_start
:
batch_end
],
lengths_per_sample
,
inference_params
.
sequence_len_offset
,
v_cache
[
batch_start
:
batch_end
],
lengths_per_sample
,
None
,
# rotary_cos_
None
,
# rotary_sin_
None
,
# nnz_head_idx
inference_params
.
sequence_len_offset
,
self
.
rotary_emb_dim
,
rotary_emb_base
,
self
.
rotary_emb_dim
,
rotary_emb_base
,
# neox_rotary_style
# neox_rotary_style
(
not
self
.
rotary_emb
.
interleaved
)
if
self
.
rotary_emb_dim
>
0
else
True
(
not
self
.
rotary_emb
.
interleaved
)
if
self
.
rotary_emb_dim
>
0
else
True
...
@@ -637,8 +642,13 @@ class ParallelMHA(nn.Module):
...
@@ -637,8 +642,13 @@ class ParallelMHA(nn.Module):
rotary_emb_base
=
self
.
rotary_emb
.
base
if
self
.
rotary_emb_dim
>
0
else
0
rotary_emb_base
=
self
.
rotary_emb
.
base
if
self
.
rotary_emb_dim
>
0
else
0
context
=
ft_attention
.
single_query_attention
(
context
=
ft_attention
.
single_query_attention
(
*
rearrange
(
qkv
,
'b 1 three h d -> b three h d'
).
unbind
(
dim
=
1
),
*
rearrange
(
qkv
,
'b 1 three h d -> b three h d'
).
unbind
(
dim
=
1
),
k_cache
[
batch_start
:
batch_end
],
v_cache
[
batch_start
:
batch_end
],
k_cache
[
batch_start
:
batch_end
],
lengths_per_sample
,
inference_params
.
sequence_len_offset
,
v_cache
[
batch_start
:
batch_end
],
lengths_per_sample
,
None
,
# rotary_cos_
None
,
# rotary_sin_
None
,
# nnz_head_idx
inference_params
.
sequence_len_offset
,
self
.
rotary_emb_dim
,
rotary_emb_base
,
self
.
rotary_emb_dim
,
rotary_emb_base
,
# neox_rotary_style
# neox_rotary_style
(
not
self
.
rotary_emb
.
interleaved
)
if
self
.
rotary_emb_dim
>
0
else
True
(
not
self
.
rotary_emb
.
interleaved
)
if
self
.
rotary_emb_dim
>
0
else
True
...
...
tests/models/test_llama.py
View file @
62e98144
...
@@ -267,10 +267,11 @@ def test_llama_generation(model_name):
...
@@ -267,10 +267,11 @@ def test_llama_generation(model_name):
del
model
del
model
hf_error
=
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
hf_error
=
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
assert
(
logits_parallel
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
print
(
f
'HF fp16 logits max diff:
{
hf_error
}
'
)
print
(
f
'HF fp16 logits max diff:
{
hf_error
}
'
)
print
(
f
'Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
'
)
print
(
f
'Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
'
)
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
print
(
f
'Logits CG max diff:
{
(
logits_cg
-
logits_ref
).
abs
().
max
().
item
()
}
'
)
print
(
f
'Logits CG max diff:
{
(
logits_cg
-
logits_ref
).
abs
().
max
().
item
()
}
'
)
assert
(
logits_parallel
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
assert
torch
.
equal
(
logits_cg
,
logits
)
assert
torch
.
equal
(
logits_cg
,
logits
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment