gaoqiong / flash-attention · Commits

Commit bec5b3d3, authored Aug 16, 2023 by Tri Dao
[MHA] Run black on mha.py
Parent: cb0daccc

Changes: 1 changed file, flash_attn/modules/mha.py, with 374 additions and 175 deletions (+374, -175). View file @ bec5b3d3.

The hunks below show flash_attn/modules/mha.py after running black (line length 100): single-quoted strings become double-quoted, long signatures and calls are wrapped one argument per line with trailing commas, and multi-name imports are parenthesized and sorted. Hunk headers (@@ -old,+new @@) give the line ranges in the old and new file; "..." marks collapsed unchanged lines.
# Copyright (c) 2023, Tri Dao.   (copyright year bumped from 2022)

import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
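The try/except imports let the module degrade gracefully when the optional CUDA extensions are missing: callers can test the names for None before picking the flash path. A minimal sketch of that check (not part of this commit; the flag name is illustrative, and it uses classes defined further down in the file):

# Sketch: choose the attention implementation based on what imported successfully.
use_flash_attn = flash_attn_qkvpacked_func is not None
attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention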
...
@@ -42,10 +45,11 @@ class FlashSelfAttention(nn.Module):
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

...
@@ -76,12 +80,20 @@ class FlashSelfAttention(nn.Module):
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
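A minimal usage sketch for the padded (non-varlen) path above (not part of this commit): it assumes a CUDA device, fp16 inputs, and that flash-attn is installed.

import torch

# (batch, seqlen, 3, nheads, head_dim) packed QKV
flash_self_attn = FlashSelfAttention(causal=True, attention_dropout=0.1).cuda()
qkv = torch.randn(2, 128, 3, 12, 64, dtype=torch.float16, device="cuda")
out = flash_self_attn(qkv)  # (batch, seqlen, nheads, head_dim)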
class FlashCrossAttention(nn.Module):
...
@@ -94,16 +106,25 @@ class FlashCrossAttention(nn.Module):
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
...
@@ -130,16 +151,27 @@ class FlashCrossAttention(nn.Module):
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
            )
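The shape assert in the dense path above encodes the cross-attention contract: q is (batch, seqlen_q, nheads, head_dim) and kv is (batch, seqlen_k, 2, nheads_k, head_dim), with matching batch and head_dim. A hedged sketch of a call (not part of this commit; assumes CUDA, fp16, flash-attn installed):

import torch

q = torch.randn(2, 16, 12, 64, dtype=torch.float16, device="cuda")      # (b, seqlen_q, h, d)
kv = torch.randn(2, 128, 2, 12, 64, dtype=torch.float16, device="cuda")  # (b, seqlen_k, 2, h_k, d)
cross_attn = FlashCrossAttention(causal=False).cuda()
out = cross_attn(q, kv)  # (b, seqlen_q, h, d)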
class SelfAttention(nn.Module):
...
@@ -152,6 +184,7 @@ class SelfAttention(nn.Module):
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
...
@@ -171,22 +204,25 @@ class SelfAttention(nn.Module):
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
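The reference path above masks by adding a large negative bias to the scores rather than calling masked_fill_ before the softmax. A tiny standalone sketch of that trick (not part of this commit):

import torch

scores = torch.randn(1, 1, 4, 4)                              # (b, h, t, s) attention scores
causal_mask = torch.triu(torch.full((4, 4), -10000.0), 1)     # 0 on/below diagonal, -10000 above
probs = torch.softmax(scores + causal_mask, dim=-1)
# probs is (numerically) lower-triangular: future positions receive ~0 weight.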
...
@@ -200,6 +236,7 @@ class CrossAttention(nn.Module):
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
...
@@ -224,43 +261,48 @@ class CrossAttention(nn.Module):
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
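The repeat at the top of the hunk is how the reference path handles MQA/GQA: each KV head is duplicated so that every group of g = nheads // nheads_k query heads attends to its own copy. A short sketch of that expansion (not part of this commit):

import torch
from einops import repeat

kv = torch.randn(2, 128, 2, 4, 64)                 # (batch, seqlen_k, 2, nheads_k=4, head_dim)
kv_expanded = repeat(kv, "... hkv d -> ... (hkv g) d", g=3)  # 12 query heads / 4 KV heads
print(kv_expanded.shape)                           # torch.Size([2, 128, 2, 12, 64])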
class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_sequence_len,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
...
@@ -292,22 +334,30 @@ def _update_kv_cache(kv, inference_params, layer_idx):
        packsize = 4 if kv.dtype == torch.float32 else 8
        if kv_cache is not None:
            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
            k_cache = rearrange(
                kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
            ).contiguous()
            v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
            inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
        else:
            k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
                kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
            )
            v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
                kv[:, :, 1], "b s h d -> b h s d"
            )
    return kv
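_update_kv_cache relies on an inference_params object that carries the cache bookkeeping; the real container lives in the repo's generation utilities, not in this file. A hedged sketch naming only the attributes the function reads above (field defaults are assumptions for illustration):

from dataclasses import dataclass, field

@dataclass
class _InferenceParamsSketch:
    # Illustrative stand-in for the repo's InferenceParams; not part of this commit.
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)   # layer_idx -> cache tensor(s)
    fused_ft_kernel: bool = False
    lengths_per_sample: object = None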
def _apply_rotary_single_query_attention(
    qkv,
    inference_params,
    layer_idx,
    rotary_emb_dim,
    rotary_emb_base,
    kv=None,
    rotary_emb_interleaved=False,
):
    """
    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
        q of shape (batch_size, 1, nheads, head_dim)
...
@@ -316,17 +366,22 @@ def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotar
    assert inference_params.fused_ft_kernel
    assert ft_attention is not None
    if kv is None:
        q, k, v = rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1)
    else:
        q = rearrange(qkv, "b 1 h d -> b h d")
        k, v = rearrange(kv, "b 1 two h d -> b two h d").unbind(dim=1)
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + q.shape[0]
    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
    lengths_per_sample = (
        inference_params.lengths_per_sample[batch_start:batch_end]
        if inference_params.lengths_per_sample is not None
        else None
    )
    context = ft_attention.single_query_attention(
        q,
        k,
        v,
        k_cache[batch_start:batch_end],
        v_cache[batch_start:batch_end],
        lengths_per_sample,
...
@@ -334,29 +389,47 @@ def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotar
        None,  # rotary_sin_
        None,  # nnz_head_idx
        inference_params.sequence_len_offset,
        rotary_emb_dim,
        rotary_emb_base,
        not rotary_emb_interleaved,  # neox_rotary_style
    )
    return rearrange(context, "b h d -> b 1 h d")
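The fused decode kernel reads K from a cache packed as (b, h, head_dim // packsize, seqlen, packsize), where packsize is 8 elements for fp16/bf16 and 4 for fp32, as seen in _update_kv_cache above. A small sketch of that layout transform (not part of this commit):

import torch
from einops import rearrange

kv_cache = torch.randn(2, 256, 2, 12, 64, dtype=torch.float16)  # (b, s, 2, h, d)
packsize = 4 if kv_cache.dtype == torch.float32 else 8
k_cache = rearrange(
    kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
).contiguous()
print(k_cache.shape)  # torch.Size([2, 12, 8, 256, 8])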
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
...
@@ -370,24 +443,31 @@ class MHA(nn.Module):
        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv
        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
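The qkv_dim/kv_dim arithmetic above is where MQA/GQA saves projection width: the QKV projection only produces 2 * num_heads_kv extra heads instead of 2 * num_heads. A worked example under a hypothetical GQA config (not part of this commit):

# Illustrative numbers only.
embed_dim, num_heads, num_heads_kv = 4096, 32, 8
head_dim = embed_dim // num_heads                     # 128
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)   # 128 * 48 = 6144 (vs 3 * 4096 = 12288 for full MHA)
kv_dim = 2 * head_dim * num_heads_kv                  # 2048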
...
@@ -398,40 +478,57 @@ class MHA(nn.Module):
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(
                batch_size,
                max_seqlen,
                2,
                self.num_heads_kv,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(
                batch_size,
                self.num_heads_kv,
                self.head_dim // packsize,
                max_seqlen,
                packsize,
                dtype=dtype,
                device=device,
            )
            v_cache = torch.empty(
                batch_size, self.num_heads_kv, max_seqlen, self.head_dim, dtype=dtype, device=device
            )
            return k_cache, v_cache
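For concreteness, the cache shapes produced by allocate_inference_cache under a hypothetical config (fp16, batch 4, max_seqlen 2048, num_heads_kv 8, head_dim 128; numbers are illustrative, not from the commit):

batch_size, max_seqlen, num_heads_kv, head_dim = 4, 2048, 8, 128
packsize = 8  # fp16/bf16
kv_shape = (batch_size, max_seqlen, 2, num_heads_kv, head_dim)                           # fused_ft_kernel=False
k_cache_shape = (batch_size, num_heads_kv, head_dim // packsize, max_seqlen, packsize)   # fused_ft_kernel=True
v_cache_shape = (batch_size, num_heads_kv, max_seqlen, head_dim)                         # fused_ft_kernel=True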
    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
...
@@ -442,12 +539,28 @@ class MHA(nn.Module):
        """
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv,
            inference_params,
            self.layer_idx,
            self.rotary_emb_dim,
            rotary_emb_base,
            kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
...
@@ -481,8 +594,11 @@ class MHA(nn.Module):
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
...
@@ -491,19 +607,22 @@ class MHA(nn.Module):
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    q = qkv[:, :, 0]
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
...
@@ -530,25 +649,31 @@ class MHA(nn.Module):
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            q = qkv[..., : self.num_heads * self.head_dim]
            kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
...
@@ -557,21 +682,36 @@ class MHA(nn.Module):
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
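An end-to-end usage sketch of the MHA block above (not part of this commit): it assumes a CUDA device in fp16 and that flash-attn and its rotary-embedding module are installed; the hyperparameters are illustrative.

import torch

mha = MHA(
    embed_dim=1024,
    num_heads=16,
    causal=True,
    rotary_emb_dim=64,
    use_flash_attn=True,
    layer_idx=0,
).to(device="cuda", dtype=torch.float16)
x = torch.randn(2, 512, 1024, dtype=torch.float16, device="cuda")
y = mha(x)  # (2, 512, 1024)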
class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
...
@@ -586,55 +726,93 @@ class ParallelMHA(nn.Module):
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        self.num_heads_per_rank = num_heads // self.world_size
        self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert (
            self.num_heads_kv % self.world_size == 0
        ), "num_heads_kv must be divisible by world_size"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv
        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(
                batch_size,
                max_seqlen,
                2,
                self.num_heads_kv_per_rank,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(
                batch_size,
                self.num_heads_kv_per_rank,
                self.head_dim // packsize,
                max_seqlen,
                packsize,
                dtype=dtype,
                device=device,
            )
            v_cache = torch.empty(
                batch_size,
                self.num_heads_kv_per_rank,
                max_seqlen,
                self.head_dim,
                dtype=dtype,
                device=device,
            )
            return k_cache, v_cache

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
...
@@ -645,8 +823,15 @@ class ParallelMHA(nn.Module):
        """
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv,
            inference_params,
            self.layer_idx,
            self.rotary_emb_dim,
            rotary_emb_base,
            kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
...
@@ -662,9 +847,12 @@ class ParallelMHA(nn.Module):
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
...
@@ -682,20 +870,31 @@ class ParallelMHA(nn.Module):
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
                or not inference_params.fused_ft_kernel
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
...
@@ -704,8 +903,8 @@ class ParallelMHA(nn.Module):
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out
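ParallelMHA splits heads across tensor-parallel ranks: Wqkv is a ColumnParallelLinear, so each rank projects only its slice of the heads, and out_proj is a RowParallelLinear that reduces across ranks. A small sketch of the per-rank head arithmetic enforced by the asserts above (numbers are illustrative, not from the commit):

num_heads, num_heads_kv, world_size = 32, 8, 4
num_heads_per_rank = num_heads // world_size         # 8 query heads per rank
num_heads_kv_per_rank = num_heads_kv // world_size   # 2 KV heads per rank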