OpenDAS / text-generation-inference / Commits / 5ce89059

Unverified commit 5ce89059, authored Jun 12, 2023 by OlivierDehaene, committed by GitHub on Jun 12, 2023

feat(server): pre-allocate past key values for flash causal LM (#412)
parent ca650e5b

Showing 6 changed files with 494 additions and 345 deletions (+494 -345)
server/Makefile-flash-att                                                            +2    -2
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py        +73   -41
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py         +78   -43
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py           +114  -68
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py   +73   -52
server/text_generation_server/models/flash_causal_lm.py                             +154  -139
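The common thread of the diff: instead of building a fresh key/value tensor every decode step, each flash model now writes new key/value pairs into a single pre-allocated past tensor at the row positions given by past_present_indices. A minimal sketch of the pattern with toy sizes (the concrete numbers and index values here are ours, not from the commit):

import torch

# Toy sizes; real values come from the model config and the batch.
num_tokens, num_layers, num_heads, head_size = 5, 2, 4, 8
pre_allocate_past_size = 12  # slots reserved for prompts + future tokens

hidden_states = torch.zeros(num_tokens, 16)

# Prefill: a compact past, one row per input token (as in the diff).
past_key_values = hidden_states.new_empty(
    (num_tokens, num_layers, 2, num_heads, head_size)
)

# After the layers have filled it, scatter the compact past into the
# padded buffer once, instead of slicing inside every layer.
present = past_key_values
past_key_values = hidden_states.new_empty(
    (pre_allocate_past_size, num_layers, 2, num_heads, head_size)
)
past_present_indices = torch.tensor([0, 1, 2, 3, 4])  # hypothetical slot ids
past_key_values[past_present_indices] = present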
server/Makefile-flash-att

-flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
+flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f

 flash-attention:
 	# Clone flash attention
 	pip install packaging
-	git clone https://github.com/HazyResearch/flash-attention.git
+	git clone https://github.com/OlivierDehaene/flash-attention.git

 build-flash-attention: flash-attention
 	cd flash-attention && git fetch && git checkout $(flash_att_commit)
...
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -128,11 +128,14 @@ class FlashLlamaAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -142,7 +145,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 1], cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = qkv[:, 1:]
@@ -154,8 +157,10 @@ class FlashLlamaAttention(torch.nn.Module):
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
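The kernel arguments change from two shared cu_seqlens/max_s pairs to explicit per-sequence start/end offsets, which is presumably what the flash-attention fork pinned in the Makefile above accepts. If the old cumulative lengths are available, the new tensors can be derived like this (our sketch, not code from the diff):

import torch

# cu_seqlens for three packed sequences of lengths 3, 5 and 2
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)

# One (start, end) offset pair per sequence in the packed batch.
start_seq = cu_seqlens[:-1]  # tensor([0, 3, 8])
end_seq = cu_seqlens[1:]     # tensor([3, 8, 10])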
@@ -170,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module):
         else:
             query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(query)
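During decode each sequence contributes exactly one new token, so qkv[:, 1:] holds one key/value pair per sequence and the assignment above scatters them into the pre-allocated layer_past rows named by past_present_indices. A toy version of that write (sizes and indices are hypothetical):

import torch

num_slots, num_heads, head_size = 10, 4, 8
layer_past = torch.zeros(num_slots, 2, num_heads, head_size)

# One fresh key/value pair per decoding sequence (three sequences here).
new_kv = torch.randn(3, 2, num_heads, head_size)

# Each sequence writes its new token into its own reserved slot.
past_present_indices = torch.tensor([2, 6, 9])  # hypothetical slot ids
layer_past[past_present_indices] = new_kv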
@@ -180,8 +185,10 @@ class FlashLlamaAttention(torch.nn.Module):
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -258,11 +265,14 @@ class FlashLlamaLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -271,11 +281,14 @@ class FlashLlamaLayer(nn.Module):
             normed_hidden_states,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         # faster post attention rms norm
@@ -322,35 +335,37 @@ class FlashLlamaModel(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)

         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -360,24 +375,35 @@ class FlashLlamaModel(torch.nn.Module):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_key_values[:, i],
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.norm(hidden_states, residual)
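Passing past_key_values[:, i] (or torch.select(..., dim=1, index=i) in the RW and Santacoder files) hands each layer a view, not a copy, so the in-place writes inside the attention modules land directly in the shared buffer. A quick check of that property:

import torch

# (slots, layers, 2, heads, head_size)
past_key_values = torch.zeros(10, 3, 2, 4, 8)

layer_past = past_key_values[:, 0]   # a view over layer 0, not a copy
layer_past[torch.tensor([5])] = 1.0  # in-place write through the view

assert past_key_values[5, 0].eq(1.0).all()  # visible in the full buffer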
@@ -399,9 +425,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -409,9 +438,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         hidden_states, present = self.model(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

@@ -113,11 +113,14 @@ class FlashNeoxAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -127,7 +130,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 1], cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = qkv[:, 1:]
@@ -139,8 +142,10 @@ class FlashNeoxAttention(torch.nn.Module):
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -155,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module):
         else:
             query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(query)
@@ -165,8 +170,10 @@ class FlashNeoxAttention(torch.nn.Module):
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -240,11 +247,14 @@ class FlashNeoXLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@@ -253,11 +263,14 @@ class FlashNeoXLayer(nn.Module):
                 ln1_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@@ -276,11 +289,14 @@ class FlashNeoXLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             hidden_states, residual = self.post_attention_layernorm(
@@ -329,9 +345,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -339,25 +358,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -367,24 +385,35 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_key_values[:, i],
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@@ -404,9 +433,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -414,9 +446,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         hidden_states, present = self.gpt_neox(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

@@ -130,11 +130,14 @@ class FlashRWAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -150,10 +153,10 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
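torch.select(kv, dim=1, index=0) is the functional spelling of kv[:, 0]; both return views of the same storage, so the in-place rotary embedding still mutates kv itself. The diff seems to prefer the explicit form; the two are interchangeable:

import torch

kv = torch.randn(6, 2, 4, 8)

a = kv[:, 0]
b = torch.select(kv, dim=1, index=0)

assert torch.equal(a, b)
assert a.data_ptr() == b.data_ptr()  # same storage: both are views of kv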
@@ -164,11 +167,13 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -182,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@@ -191,11 +196,13 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -261,11 +268,14 @@ class FlashRWLargeAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@@ -280,10 +290,10 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, :, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
@@ -298,11 +308,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -316,7 +328,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = (
                 layer_past.unsqueeze(2)
@@ -329,11 +341,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -417,11 +431,14 @@ class FlashRWLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -430,11 +447,14 @@ class FlashRWLayer(nn.Module):
                 ln_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             mlp_output = self.mlp(ln_hidden_states)
@@ -451,11 +471,14 @@ class FlashRWLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             hidden_states, residual = self.post_attention_layernorm(
@@ -499,11 +522,14 @@ class FlashRWLargeLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
         ln_mlp, _ = self.ln_mlp(residual)
@@ -513,11 +539,14 @@ class FlashRWLargeLayer(nn.Module):
             ln_attn,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         # MLP.
@@ -584,9 +613,12 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -594,23 +626,22 @@ class FlashRWModel(FlashRWPreTrainedModel):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     *self.cache_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -620,24 +651,33 @@ class FlashRWModel(FlashRWPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (pre_allocate_past_size, len(self.h), *self.cache_size)
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -658,9 +698,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -668,9 +711,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

@@ -7,6 +7,7 @@ from typing import Optional

 # Flash attention imports
 import flash_attn_cuda

 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.c_attn(hidden_states)
@@ -166,7 +170,7 @@ class FlashMQAttention(torch.nn.Module):
         key_value = key_value.view(-1, 2, 1, self.head_size)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = key_value
             # Expand from 1 to num_heads
@@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = key_value
+            layer_past[past_present_indices] = key_value
             # Expand from 1 to num_heads
             key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
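Santacoder uses multi-query attention: the cache stores a single KV head (the 1 in the shape above), and expand broadcasts it across num_heads without copying, so the flash kernel sees per-head keys and values. A small illustration:

import torch

num_tokens, num_heads, head_size = 10, 4, 8

# MQA cache: a single shared KV head per token.
layer_past = torch.randn(num_tokens, 2, 1, head_size)

# Broadcast that head across all query heads; no memory is copied.
key_value = layer_past.expand(-1, 2, num_heads, head_size)
assert key_value.shape == (num_tokens, 2, num_heads, head_size)
assert key_value.data_ptr() == layer_past.data_ptr()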
@@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -277,21 +285,27 @@ class Block(nn.Module):
         self,
         hidden_states,
         residual,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)

         hidden_states = self.attn(
             hidden_states,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         hidden_states, residual = self.ln_2(hidden_states, residual)
@@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@@ -352,44 +369,42 @@ class FlashSantacoderModel(nn.Module):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
-            past_key_values = hidden_states.new_empty(
-                (
-                    len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
-                    2,
-                    1,
-                    self.head_size,
-                )
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
+            past_key_values = hidden_states.new_zeros(
+                (len(input_ids), len(self.h), 2, 1, self.head_size)
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -408,9 +423,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -418,9 +436,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
server/text_generation_server/models/flash_causal_lm.py

This diff is collapsed in the web view (+154 -139).
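The collapsed file is where the batch presumably builds the tensors consumed by the model signatures above. Based on those signatures only, one plausible construction for a decode step might look like the following; the helper and its region layout are hypothetical, not the actual FlashCausalLMBatch logic:

import torch

def decode_step_tensors(input_lengths, max_new_tokens):
    # Give each sequence a contiguous region of the past buffer, large
    # enough for its prompt plus every token it may still generate.
    region_sizes = torch.tensor([l + max_new_tokens for l in input_lengths])
    region_starts = torch.cat(
        [torch.zeros(1, dtype=torch.long), region_sizes.cumsum(0)[:-1]]
    )
    pre_allocate_past_size = int(region_sizes.sum())

    lengths = torch.tensor(input_lengths)
    # The next token of each sequence lands right after its current tokens.
    past_present_indices = region_starts + lengths

    # Single-token queries during decode.
    start_seq_q = torch.arange(len(input_lengths), dtype=torch.int32)
    end_seq_q = start_seq_q + 1

    # Keys/values span each region up to and including the new token.
    start_seq = region_starts.to(torch.int32)
    end_seq = (past_present_indices + 1).to(torch.int32)

    return (start_seq, end_seq, start_seq_q, end_seq_q,
            past_present_indices, pre_allocate_past_size)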