gaoqiong / flash-attention · Commits · 425dbcb6

Commit 425dbcb6, authored Jul 23, 2023 by Tri Dao
[MHA] Implement MQA/GQA
Parent: ec9f74ab

Showing 5 changed files with 340 additions and 153 deletions (+340, -153).
flash_attn/layers/rotary.py            +80   -13
flash_attn/models/gpt.py               +47   -14
flash_attn/modules/mha.py              +209  -121
tests/models/test_gpt_generation.py    +4    -4
tests/models/test_gptj.py              +0    -1
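The core idea of MQA/GQA, which the changes below implement, is that several query heads share one key/value head; at attention time the smaller set of KV heads is expanded to the query head count. A minimal self-contained sketch of that head expansion (the sizes are made up, 8 query heads and 2 KV heads), using the same einops.repeat pattern the new CrossAttention code uses:

import torch
from einops import repeat

batch, seqlen_q, seqlen_k, h_q, h_kv, d = 2, 5, 7, 8, 2, 16
q = torch.randn(batch, seqlen_q, h_q, d)
kv = torch.randn(batch, seqlen_k, 2, h_kv, d)

# MQA/GQA: expand the kv heads so every query head has a matching key/value head.
kv_expanded = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
k, v = kv_expanded.unbind(dim=2)
scores = torch.einsum('bthd,bshd->bhts', q, k / d ** 0.5)
print(kv_expanded.shape, scores.shape)  # (2, 7, 2, 8, 16) and (2, 8, 5, 7)

With this expansion the grouping is contiguous: query heads 0 to 3 attend against copies of KV head 0, and query heads 4 to 7 against copies of KV head 1.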
flash_attn/layers/rotary.py (+80, -13)

Optional is added to the typing import so that RotaryEmbedding.forward can take an optional kv tensor:

# Copyright (c) 2023, Tri Dao.

from typing import Tuple, Optional

import math

import torch

@@ -151,6 +151,51 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):

A new autograd function, ApplyRotaryEmbKV_, is added after apply_rotary_emb_qkv_. It applies the rotary embedding in place to the k part of a packed kv tensor, which is what the MQA/GQA path needs once q and kv are no longer packed together:

apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply


class ApplyRotaryEmbKV_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False):
        """
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of k.
        """
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        k_ro = kv[:, :, 0, :, :rotary_dim]
        k1, k2 = (k_ro.chunk(2, dim=-1) if not interleaved
                  else (k_ro[..., ::2], k_ro[..., 1::2]))
        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
                                False)  # conj=False since this is the forward pass
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, _, headdim = dkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dk_ro = dkv[:, :, 0, :, :rotary_dim]
        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
                                True)  # conj=True since this is the backward pass
        return dkv, None, None, None


apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply

@@ -249,21 +294,43 @@ class RotaryEmbedding(torch.nn.Module):

RotaryEmbedding.forward gains an optional kv argument. When kv is None the behavior is unchanged: qkv is the packed (batch, seqlen, 3, nheads, headdim) tensor and apply_rotary_emb_qkv_ is used. When kv is given, qkv is interpreted as just q, which is rotated with apply_rotary_emb_func, while kv is rotated with the new apply_rotary_emb_kv_. (The surrounding _cos_k_cached and _sin_k_cached cache lines shown as context in this hunk are unchanged.)

        self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
        self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
                seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.
        """
        seqlen = qkv.shape[1]
        self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    None, None, self.interleaved
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                    self.interleaved
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                    self.interleaved
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                    self.interleaved
                )
            return q, kv
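For intuition about what the new kv path computes, here is a small pure-PyTorch reference, not part of the commit and not the fused rotary_emb kernel the class above calls, that applies GPT-NeoX-style rotary embedding to a separate q and to the k half of a packed kv tensor; it assumes rotary_dim == headdim and no scaling, and the shapes are hypothetical:

import torch


def rotary_ref(q, kv, base=10000.0, seqlen_offset=0):
    # q: (batch, seqlen, nheads, headdim); kv: (batch, seqlen, 2, nheads_kv, headdim)
    _, seqlen, _, headdim = q.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, headdim, 2, dtype=torch.float32) / headdim))
    t = torch.arange(seqlen_offset, seqlen_offset + seqlen, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)              # (seqlen, headdim / 2)
    cos, sin = freqs.cos(), freqs.sin()

    def rotate(x):
        # Split into first half / second half (GPT-NeoX style) and rotate each pair.
        x1, x2 = x.chunk(2, dim=-1)
        c = cos[None, :, None, :]
        s = sin[None, :, None, :]
        return torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)

    q_rot = rotate(q)
    k_rot = rotate(kv[:, :, 0])                   # only k is rotated, v passes through
    return q_rot, torch.stack([k_rot, kv[:, :, 1]], dim=2)


q = torch.randn(2, 16, 8, 64)        # 8 query heads
kv = torch.randn(2, 16, 2, 2, 64)    # 2 key/value heads (GQA)
q_rot, kv_rot = rotary_ref(q, kv)
print(q_rot.shape, kv_rot.shape)     # torch.Size([2, 16, 8, 64]) torch.Size([2, 16, 2, 2, 64])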
flash_attn/models/gpt.py (+47, -14)

@@ -88,7 +88,9 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt

create_mixer_cls reads the optional n_head_kv field from the config and forwards it to the attention class as num_heads_kv:

    parallel_kwargs = ({'process_group': process_group,
                        'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                       if process_group is not None else {})
    num_heads_kv = getattr(config, "n_head_kv", None)
    mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, num_heads_kv=num_heads_kv,
                        qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                        dropout=config.attn_pdrop,
                        softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
    ...

@@ -503,20 +505,37 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):

The helpers in shard_state_dict_tp now skip keys that are missing from the state dict, and shard_qkv_headdim learns the MQA/GQA weight layout, in which the packed projection holds n_head query heads followed by n_head_kv key heads and n_head_kv value heads:

    assert inner_dim % world_size == 0

    def shard_first_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim:(rank + 1) * dim]

    def shard_last_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[-1] // world_size
            state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head = config.n_head
            n_head_kv = getattr(config, 'n_head_kv', n_head)
            assert n_head % world_size == 0 and n_head_kv % world_size == 0
            if n_head_kv == n_head:
                x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
                dim = x.shape[1] // world_size
                state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
                                            'three d ... -> (three d) ...')
            else:
                n_head_per_rank = n_head // world_size
                n_head_kv_per_rank = n_head_kv // world_size
                x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                              nheadqkv=n_head + 2 * n_head_kv)
                state_dict[key] = rearrange(torch.cat([
                    x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],
                    x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank],
                    x[n_head + n_head_kv + rank * n_head_kv_per_rank
                      :n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],
                ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")

    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:

@@ -528,12 +547,12 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):

The pops of per-layer biases on non-zero ranks become tolerant of missing keys:

        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None)
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None)
    return state_dict

@@ -561,9 +580,23 @@ def combine_state_dicts_tp(state_dicts, config):

combine_qkv_headdim is the inverse operation and gets the matching MQA/GQA branch:

        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        n_head = config.n_head
        n_head_kv = getattr(config, 'n_head_kv', n_head)
        assert n_head % world_size == 0 and n_head_kv % world_size == 0
        n_head_per_rank = n_head // world_size
        n_head_kv_per_rank = n_head_kv // world_size
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3)
                      for s in state_dicts]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
            else:
                xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                                nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts]
                state_dict[key] = rearrange(torch.cat([
                    torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
                    torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank]
                               for x in xs], dim=0),
                    torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
                ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
    ...
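To make the packed-weight layout concrete, here is a tiny standalone example (the sizes are made up, not from the commit) that slices a toy Wqkv the same way shard_qkv_headdim does when n_head_kv != n_head: rank r takes its slice of the query-head rows, then its slice of the key-head rows, then its slice of the value-head rows.

import torch
from einops import rearrange

n_head, n_head_kv, headdim, world_size = 4, 2, 3, 2
n_head_per_rank = n_head // world_size
n_head_kv_per_rank = n_head_kv // world_size

# Packed Wqkv rows: [q heads | k heads | v heads] -> (n_head + 2 * n_head_kv) * headdim rows.
w = torch.arange((n_head + 2 * n_head_kv) * headdim).unsqueeze(-1)

for rank in range(world_size):
    x = rearrange(w, '(nheadqkv headdim) ... -> nheadqkv headdim ...',
                  nheadqkv=n_head + 2 * n_head_kv)
    shard = rearrange(torch.cat([
        x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],                          # this rank's q heads
        x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank],  # its k heads
        x[n_head + n_head_kv + rank * n_head_kv_per_rank
          :n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],                        # its v heads
    ], dim=0), 'nheadqkv headdim ... -> (nheadqkv headdim) ...')
    print(f"rank {rank}: {shard.squeeze(-1).tolist()}")

Rank 0 ends up with rows 0-5 (query heads 0 and 1), 12-14 (key head 0) and 18-20 (value head 0); rank 1 gets the remaining heads.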
flash_attn/modules/mha.py (+209, -121)

@@ -7,7 +7,7 @@ import torch

repeat is imported from einops alongside rearrange:

import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

try:
    from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func

@@ -211,7 +211,7 @@ class CrossAttention(nn.Module):

The CrossAttention docstring now lets kv carry its own head count:

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)

@@ -219,7 +219,9 @@ class CrossAttention(nn.Module):

CrossAttention.forward no longer requires kv to have as many heads as q; when the head counts differ (MQA/GQA), each key/value head is repeated so that the einsum sees matching head counts:

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)

@@ -304,17 +306,52 @@ def _update_kv_cache(kv, inference_params, layer_idx):

A module-level helper, _apply_rotary_single_query_attention, is added after _update_kv_cache. It wraps the fused ft_attention.single_query_attention decoding kernel and accepts either a packed qkv or a separate q plus kv (the MQA/GQA case):

    return kv


def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotary_emb_dim,
                                         rotary_emb_base, kv=None, rotary_emb_interleaved=False):
    """
    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
        q of shape (batch_size, 1, nheads, head_dim)
    kv: (batch_size, 1, 2, nheads_kv, head_dim)
    """
    assert inference_params.fused_ft_kernel
    assert ft_attention is not None
    if kv is None:
        q, k, v = rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1)
    else:
        q = rearrange(qkv, 'b 1 h d -> b h d')
        k, v = rearrange(kv, 'b 1 two h d -> b two h d').unbind(dim=1)
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + q.shape[0]
    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
    lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                          if inference_params.lengths_per_sample is not None else None)
    context = ft_attention.single_query_attention(
        q, k, v,
        k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
        lengths_per_sample,
        None,  # rotary_cos_
        None,  # rotary_sin_
        None,  # nnz_head_idx
        inference_params.sequence_len_offset,
        rotary_emb_dim, rotary_emb_base,
        not rotary_emb_interleaved  # neox_rotary_style
    )
    return rearrange(context, 'b h d -> b 1 h d')

The same hunk continues into the MHA class, whose constructor gains a num_heads_kv argument:

class MHA(nn.Module):
    """Multi-head self-attention and cross-attention
    """

    def __init__(self, embed_dim, num_heads, num_heads_kv=None, cross_attn=False,
                 qkv_proj_bias=True, out_proj_bias=True,
                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
                 rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        ...

@@ -332,8 +369,12 @@ class MHA(nn.Module):

The head bookkeeping tracks num_heads_kv and derives the (possibly smaller) projection sizes; the old assert message "self.kdim must be divisible by num_heads" is corrected along the way:

        self.checkpointing = checkpointing

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'

@@ -347,31 +388,23 @@ class MHA(nn.Module):

The projection layers are built from qkv_dim and kv_dim instead of 3 * embed_dim and 2 * embed_dim, the duplicated return_residual branches collapse into a single wqkv_cls, and the depthwise-conv layers follow the same sizes:

        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (LinearResidual if not fused_bias_fc
                            else partial(FusedDense, return_residual=True))
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(qkv_dim, qkv_dim, kernel_size=3, padding=2,
                                            groups=qkv_dim)
            else:
                self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
                                          groups=embed_dim)
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2,
                                           groups=kv_dim)
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,

@@ -382,15 +415,15 @@ class MHA(nn.Module):

allocate_inference_cache sizes the KV cache with num_heads_kv instead of num_heads, for both the plain and the fused-kernel layouts:

        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim,
                               dtype=dtype, device=device)
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(batch_size, self.num_heads_kv, self.head_dim // packsize,
                                  max_seqlen, packsize, dtype=dtype, device=device)
            v_cache = torch.empty(batch_size, self.num_heads_kv, max_seqlen, self.head_dim,
                                  dtype=dtype, device=device)
            return k_cache, v_cache

@@ -401,6 +434,18 @@ class MHA(nn.Module):

A thin method wrapper around the new module-level helper is added right after _update_kv_cache; the signature of forward itself is unchanged:

        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
            q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)
        """
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base,
            kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0
            else False,
        )

    def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                mixer_subset=None, inference_params=None, **kwargs):
        """

@@ -438,7 +483,8 @@ class MHA(nn.Module):
@@ -448,71 +494,69 @@ class MHA(nn.Module):

The body of forward is restructured. The packed-QKV path is taken only when num_heads_kv == num_heads; otherwise self-attention also goes through the q/kv path, splitting the output of Wqkv into q and a packed kv with num_heads_kv heads. The decoding logic is reorganized so that the fused single-query kernel is reached through the new helper instead of the previous inline ft_attention.single_query_attention call:

        kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                  if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            ...
            if self.dwconv:
                qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
            if (inference_params is None or inference_params.sequence_len_offset == 0
                    or not inference_params.fused_ft_kernel):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    q = qkv[:, :, 0]
                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:  # MQA/GQA self-attention
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., :self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim:]
            q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
            kv = rearrange(kv, '... (two hkv d) -> ... two hkv d', two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                              'b d s -> b s d').contiguous()
                kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
            if (inference_params is None or inference_params.sequence_len_offset == 0
                    or not inference_params.fused_ft_kernel):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn,
                                                                    q, kv, **kwargs)
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
        return out if not self.return_residual else (out, x)

@@ -521,7 +565,8 @@ class ParallelMHA(nn.Module):

ParallelMHA gets the same num_heads_kv option:

    """Multi-head self-attention and cross-attention
    """

    def __init__(self, embed_dim, num_heads, process_group, num_heads_kv=None,
                 qkv_proj_bias=True, out_proj_bias=True,
                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,

@@ -534,10 +579,19 @@ class ParallelMHA(nn.Module):

The per-rank head counts are computed up front, with new divisibility checks on num_heads_kv:

        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size() if process_group is not None else 1

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        self.num_heads_per_rank = num_heads // self.world_size
        self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert self.num_heads_kv % self.world_size == 0, "num_heads_kv must be divisible by world_size"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'

@@ -547,7 +601,7 @@ class ParallelMHA(nn.Module):

The column-parallel Wqkv projection uses qkv_dim instead of 3 * embed_dim:

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError('fused_dense is not installed')
        self.Wqkv = ColumnParallelLinear(embed_dim, qkv_dim, process_group,
                                         bias=qkv_proj_bias,
                                         sequence_parallel=sequence_parallel, **factory_kwargs)
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention

@@ -560,6 +614,41 @@ class ParallelMHA(nn.Module):

ParallelMHA gains its own allocate_inference_cache, _update_kv_cache and _apply_rotary_single_query_attention methods, all sized with the per-rank KV head count:

                                        bias=out_proj_bias,
                                        sequence_parallel=sequence_parallel, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if not fused_ft_kernel:
            return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv_per_rank,
                               self.head_dim, dtype=dtype, device=device)
        else:
            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
            packsize = 4 if dtype == torch.float32 else 8
            assert self.head_dim % packsize == 0
            k_cache = torch.empty(batch_size, self.num_heads_kv_per_rank,
                                  self.head_dim // packsize, max_seqlen, packsize,
                                  dtype=dtype, device=device)
            v_cache = torch.empty(batch_size, self.num_heads_kv_per_rank, max_seqlen,
                                  self.head_dim, dtype=dtype, device=device)
            return k_cache, v_cache

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
        """
        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
        """
        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
            q of shape (batch_size, 1, nheads, head_dim)
        kv: (batch_size, 1, 2, nheads_kv, head_dim)
        """
        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
        return _apply_rotary_single_query_attention(
            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base,
            kv=kv,
            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0
            else False,
        )

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:

@@ -569,55 +658,54 @@ class ParallelMHA(nn.Module):

ParallelMHA.forward is restructured along the same lines as MHA.forward: a packed-QKV branch when num_heads_kv == num_heads, a split q/kv branch otherwise, and the fused single-query decoding path routed through the shared helper instead of an inline ft_attention call:

            (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
            if (inference_params is None or inference_params.sequence_len_offset == 0
                    or not inference_params.fused_ft_kernel):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    q = qkv[:, :, 0]
                    kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
        else:
            q = rearrange(qkv[..., :self.num_heads_per_rank * self.head_dim],
                          "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(qkv[..., self.num_heads_per_rank * self.head_dim:],
                           "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if (inference_params is None or inference_params.sequence_len_offset == 0
                    or not inference_params.fused_ft_kernel):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn,
                                                                    q, kv, **kwargs)
                else:
                    kv = self._update_kv_cache(kv, inference_params)
                    # If we're processing the prompt, causal=None (use self.causal).
                    # If we're decoding, then causal=False.
                    causal = None if inference_params.sequence_len_offset == 0 else False
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
        context = rearrange(context, 'b s h d -> b s (h d)')
        if seqlen is not None:
            context = rearrange(context, 'b s d -> (b s) d')
        out = self.out_proj(context)
        return out
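A hypothetical usage sketch of the updated module (it assumes a flash-attn build that includes this commit; use_flash_attn=False selects the reference attention path, so the snippet does not need the fused CUDA kernels): passing num_heads_kv smaller than num_heads switches MHA into GQA mode, with Wqkv producing 8*64 query features plus 2*2*64 key/value features.

import torch
from flash_attn.modules.mha import MHA

# GQA: 8 query heads sharing 2 key/value heads (hypothetical sizes).
mha = MHA(embed_dim=512, num_heads=8, num_heads_kv=2, causal=True,
          use_flash_attn=False, dtype=torch.float32)
x = torch.randn(2, 128, 512)
out = mha(x)
print(out.shape)  # torch.Size([2, 128, 512])

With num_heads_kv=1 this is multi-query attention; leaving it at None keeps the original multi-head behavior.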
tests/models/test_gpt_generation.py (+4, -4)

@@ -60,9 +60,9 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):

The GPT-2 greedy-decoding test uses a slightly longer prompt ("Hello, my dog is cute and he" instead of "Hello, my dog is cute and") and max_length drops from 30 to 25:

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and he",
                          return_tensors="pt").input_ids.to(device=device)
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

@@ -143,9 +143,9 @@ def test_greedy_decode_opt(model_name):

The OPT test gets the same prompt change, with max_length reduced from 60 to 25:

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id
    input_ids = tokenizer("Hello, my dog is cute and he",
                          return_tensors="pt").input_ids.to(device=device)
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40
tests/models/test_gptj.py (+0, -1)

@@ -48,7 +48,6 @@ def test_gptj_optimized(model_name):

An unused seqlens line (seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)) is removed from test_gptj_optimized:

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
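Closing the loop on why the generation tests exercise this commit: the KV cache that allocate_inference_cache builds is sized by the number of KV heads, so MQA/GQA shrinks per-token decoding state roughly by a factor of num_heads / num_heads_kv. A back-of-the-envelope sketch with made-up numbers (not from the commit):

import torch

batch_size, max_seqlen, head_dim = 4, 2048, 128
num_heads, num_heads_kv = 64, 8                    # hypothetical GQA setting

def kv_cache_bytes(n_heads_in_cache, dtype=torch.float16):
    # Mirrors the non-fused cache layout: (batch, seqlen, 2, heads, head_dim).
    numel = batch_size * max_seqlen * 2 * n_heads_in_cache * head_dim
    return numel * torch.finfo(dtype).bits // 8

print(kv_cache_bytes(num_heads) // 2**20, "MiB for full MHA")        # 256 MiB
print(kv_cache_bytes(num_heads_kv) // 2**20, "MiB with 8 KV heads")  # 32 MiB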