OpenDAS / text-generation-inference · Commits · 5ce89059

Commit 5ce89059 (unverified), authored Jun 12, 2023 by OlivierDehaene, committed via GitHub on Jun 12, 2023
Parent: ca650e5b

feat(server): pre-allocate past key values for flash causal LM (#412)

6 changed files with 494 additions and 345 deletions (+494 -345)
server/Makefile-flash-att                                                            +2   -2
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py        +73  -41
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py         +78  -43
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py           +114 -68
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py   +73  -52
server/text_generation_server/models/flash_causal_lm.py                             +154 -139
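The core change: instead of concatenating per-request key/value tensors at every decode step, the server now pre-allocates a single past tensor with room for every prompt token plus every token that can still be generated, and copies the freshly computed key/value pairs into it at explicit per-token indices (past_present_indices). A minimal sketch of that idea in PyTorch, with hypothetical sizes and a single layer (illustrative only, not the exact tensors used in the diffs below):

import torch

# Hypothetical request: a 5-token prompt that may generate up to 3 new tokens.
input_length, max_new_tokens = 5, 3
num_heads, head_size = 4, 8

# Pre-allocate slots for prompt tokens + future tokens (dim 1 holds key and value).
pre_allocate_past_size = input_length + max_new_tokens
past_key_values = torch.empty(pre_allocate_past_size, 2, num_heads, head_size)

# Prefill: the prompt's key/value pairs land in the first input_length slots.
present = torch.randn(input_length, 2, num_heads, head_size)
past_present_indices = torch.arange(0, input_length)
past_key_values[past_present_indices] = present

# Decode: each new token writes its key/value pair one slot further; no torch.cat
# or re-allocation is needed until the request finishes.
next_index = torch.tensor([input_length])
past_key_values[next_index] = torch.randn(1, 2, num_heads, head_size)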
server/Makefile-flash-att  (view file @ 5ce89059)

-flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
+flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f

 flash-attention:
 	# Clone flash attention
 	pip install packaging
-	git clone https://github.com/HazyResearch/flash-attention.git
+	git clone https://github.com/OlivierDehaene/flash-attention.git

 build-flash-attention: flash-attention
 	cd flash-attention && git fetch && git checkout $(flash_att_commit)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py  (view file @ 5ce89059)

@@ -128,11 +128,14 @@ class FlashLlamaAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)

@@ -142,7 +145,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 1], cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = qkv[:, 1:]

@@ -154,8 +157,10 @@ class FlashLlamaAttention(torch.nn.Module):
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,

@@ -170,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module):
         else:
             query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(query)

@@ -180,8 +185,10 @@ class FlashLlamaAttention(torch.nn.Module):
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,

@@ -258,11 +265,14 @@ class FlashLlamaLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -271,11 +281,14 @@ class FlashLlamaLayer(nn.Module):
             normed_hidden_states,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         # faster post attention rms norm

@@ -322,35 +335,37 @@ class FlashLlamaModel(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)

         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer

@@ -360,24 +375,35 @@ class FlashLlamaModel(torch.nn.Module):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_key_values[:, i],
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.norm(hidden_states, residual)

@@ -399,9 +425,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,

@@ -409,9 +438,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         hidden_states, present = self.model(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
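With sequences now placed at fixed, padded offsets inside the pre-allocated past tensor, the flash-attention kernels are driven by per-sequence start_seq/end_seq offset tensors rather than a single cumulative cu_seqlens tensor. A rough sketch of how the two representations relate, with made-up lengths (an assumption for illustration, not taken from the diff):

import torch

# Two sequences of 3 and 5 tokens packed back to back (plain prefill layout).
input_lengths = torch.tensor([3, 5], dtype=torch.int32)

# Old representation: cumulative lengths, one tensor of length batch + 1.
cu_seqlens = torch.zeros(3, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(input_lengths, 0)   # [0, 3, 8]

# New representation: explicit start and end offsets, one of each per sequence.
start_seq = cu_seqlens[:-1].clone()                # [0, 3]
end_seq = cu_seqlens[1:].clone()                   # [3, 8]

# Once each request also reserves headroom for its future tokens, the sequences
# are no longer contiguous, e.g. start_seq = [0, 7], end_seq = [3, 12] if the
# first request keeps 4 extra decode slots; a single cumulative tensor cannot
# describe that gap, which is why the kernels now take start/end pairs.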
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py  (view file @ 5ce89059)

@@ -113,11 +113,14 @@ class FlashNeoxAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)

@@ -127,7 +130,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 1], cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = qkv[:, 1:]

@@ -139,8 +142,10 @@ class FlashNeoxAttention(torch.nn.Module):
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,

@@ -155,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module):
         else:
             query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(query)

@@ -165,8 +170,10 @@ class FlashNeoxAttention(torch.nn.Module):
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,

@@ -240,11 +247,14 @@ class FlashNeoXLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, _ = self.input_layernorm(hidden_states)

@@ -253,11 +263,14 @@ class FlashNeoXLayer(nn.Module):
                 ln1_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)

@@ -276,11 +289,14 @@ class FlashNeoXLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             hidden_states, residual = self.post_attention_layernorm(

@@ -329,9 +345,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):

@@ -339,25 +358,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer

@@ -367,24 +385,35 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_key_values[:, i],
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)

@@ -404,9 +433,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,

@@ -414,9 +446,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         hidden_states, present = self.gpt_neox(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py  (view file @ 5ce89059)

@@ -130,11 +130,14 @@ class FlashRWAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)

@@ -150,10 +153,10 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape

@@ -164,11 +167,13 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,

@@ -182,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)

@@ -191,11 +196,13 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,

@@ -261,11 +268,14 @@ class FlashRWLargeAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)

@@ -280,10 +290,10 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, :, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape

@@ -298,11 +308,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,

@@ -316,7 +328,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = (
                 layer_past.unsqueeze(2)

@@ -329,11 +341,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,

@@ -417,11 +431,14 @@ class FlashRWLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)

@@ -430,11 +447,14 @@ class FlashRWLayer(nn.Module):
                 ln_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             mlp_output = self.mlp(ln_hidden_states)

@@ -451,11 +471,14 @@ class FlashRWLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             hidden_states, residual = self.post_attention_layernorm(

@@ -499,11 +522,14 @@ class FlashRWLargeLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
         ln_mlp, _ = self.ln_mlp(residual)

@@ -513,11 +539,14 @@ class FlashRWLargeLayer(nn.Module):
             ln_attn,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         # MLP.

@@ -584,9 +613,12 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):

@@ -594,23 +626,22 @@ class FlashRWModel(FlashRWPreTrainedModel):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     *self.cache_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer

@@ -620,24 +651,33 @@ class FlashRWModel(FlashRWPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.h),
+                    *self.cache_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)

@@ -658,9 +698,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,

@@ -668,9 +711,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
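Several call sites in this file, and in the Santacoder model below, replace kv[:, 0] style indexing with torch.select(kv, dim=..., index=...). Both forms return a view into the same storage without copying, so the in-place rotary embedding and the kernel arguments see the same memory; a small, purely illustrative sanity check:

import torch

kv = torch.randn(10, 2, 4, 8)  # (tokens, key/value, heads, head_size)

a = kv[:, 0]
b = torch.select(kv, dim=1, index=0)

assert torch.equal(a, b)
assert a.data_ptr() == b.data_ptr()  # both are views over the same storage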
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py  (view file @ 5ce89059)

@@ -7,6 +7,7 @@ from typing import Optional
 # Flash attention imports
 import flash_attn_cuda
+
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,

@@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.c_attn(hidden_states)

@@ -166,7 +170,7 @@ class FlashMQAttention(torch.nn.Module):
         key_value = key_value.view(-1, 2, 1, self.head_size)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = key_value
             # Expand from 1 to num_heads

@@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,

@@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = key_value
+            layer_past[past_present_indices] = key_value
             # Expand from 1 to num_heads
             key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)

@@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,

@@ -277,21 +285,27 @@ class Block(nn.Module):
         self,
         hidden_states,
         residual,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)

         hidden_states = self.attn(
             hidden_states,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         hidden_states, residual = self.ln_2(hidden_states, residual)

@@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)

@@ -352,44 +369,42 @@ class FlashSantacoderModel(nn.Module):
         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
-            past_key_values = hidden_states.new_empty(
-                (
-                    len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
-                    2,
-                    1,
-                    self.head_size,
-                )
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
+            past_key_values = hidden_states.new_zeros(
+                (len(input_ids), len(self.h), 2, 1, self.head_size)
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)

@@ -408,9 +423,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,

@@ -418,9 +436,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
server/text_generation_server/models/flash_causal_lm.py  (view file @ 5ce89059)

@@ -3,8 +3,6 @@ import torch.distributed
 import numpy as np

-from torch.nn import functional as F
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel

@@ -34,10 +32,21 @@ class FlashCausalLMBatch(Batch):
     input_ids: torch.Tensor
     position_ids: torch.Tensor
-    # cumulative sequence lengths
-    cu_seqlens: torch.Tensor
-    # cumulative query sequence lengths, only used in decode
-    cu_seqlens_q: Optional[torch.Tensor]
+
+    # Indices to copy present to the correct indices is the pre-allocated past key values
+    past_present_indices: torch.Tensor
+
+    # tensor of length b holding starting offset of each sequence
+    start_seq: torch.Tensor
+    # tensor of length b holding ending offset of each sequence
+    end_seq: torch.Tensor
+    # tensor of length b holding starting offset of each sequence, only used in prefill
+    start_seq_prefill: Optional[torch.Tensor]
+    # tensor of length b holding ending offset of each sequence, only used in prefill
+    end_seq_prefill: Optional[torch.Tensor]
+    # tensor of length b holding starting offset of each query sequence, only used in decode
+    start_seq_q: Optional[torch.Tensor]
+    # tensor of length b holding ending offset of each query sequence, only used in decode
+    end_seq_q: Optional[torch.Tensor]
+
     # past key values, only used in decode
     past_key_values: Optional[torch.Tensor]
     max_seqlen: int

@@ -90,7 +99,11 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]

         position_ids = []
-        cu_seqlens = [0]
+        past_present_indices = []
+        start_seq = []
+        end_seq = []
+        start_seq_prefill = []
+        end_seq_prefill = []

         max_seqlen = 0
         input_lengths = []

@@ -110,9 +123,9 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_length = 0
+        cumulative_max_length = 0
         prefill_out_cumulative_length = 0

-        max_tokens = 0
         max_length = 0

         # Parse batch

@@ -138,7 +151,10 @@ class FlashCausalLMBatch(Batch):
             position_ids.append(request_position_ids)

             # Add cumulative lengths of all previous inputs
-            cu_seqlens.append(cumulative_length + input_length)
+            start_seq_prefill.append(cumulative_length)
+            end_seq_prefill.append(cumulative_length + input_length)
+            start_seq.append(cumulative_max_length)
+            end_seq.append(cumulative_max_length + input_length)

             next_token_chooser_parameters.append(r.parameters)

@@ -168,9 +184,17 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1

+            request_past_present_indices = torch.arange(
+                cumulative_max_length,
+                cumulative_max_length + input_length,
+                dtype=torch.int64,
+            )
+            past_present_indices.append(request_past_present_indices)
+
             # Update
+            # Remove one as the first token des not have a past
             cumulative_length += input_length
-            max_tokens += input_length + max_new_tokens
+            cumulative_max_length += input_length + max_new_tokens - 1
             max_length = max(max_length, input_length + max_new_tokens)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(

@@ -184,26 +208,45 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids

-        # Create tensors on device
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
+        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
+        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)

         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
+            past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
+            start_seq_prefill = torch.tensor(
+                start_seq_prefill, device=device, dtype=torch.int32
+            )
+            end_seq_prefill = torch.tensor(
+                end_seq_prefill, device=device, dtype=torch.int32
+            )
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
+            past_present_indices = past_present_indices[0]
+            start_seq_prefill = start_seq
+            end_seq_prefill = end_seq

+        # Create tensors on device
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
+        past_present_indices = torch.tensor(
+            past_present_indices, device=device, dtype=torch.int64
+        )

         if all_prefill_logprobs:
             prefill_head_indices = None
-            prefill_next_token_indices = cu_seqlens[1:] - 1
+            prefill_next_token_indices = end_seq_prefill - 1
         elif no_prefill_logprobs:
-            prefill_head_indices = cu_seqlens[1:] - 1
+            prefill_head_indices = end_seq_prefill - 1
             prefill_next_token_indices = None
         else:
             prefill_head_indices = torch.tensor(

@@ -219,8 +262,13 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            cu_seqlens_q=None,
+            past_present_indices=past_present_indices,
+            start_seq=start_seq,
+            end_seq=end_seq,
+            start_seq_prefill=start_seq_prefill,
+            end_seq_prefill=end_seq_prefill,
+            start_seq_q=None,
+            end_seq_q=None,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,

@@ -233,7 +281,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
-            max_tokens=max_tokens,
+            max_tokens=cumulative_max_length,
         )

     @tracer.start_as_current_span("filter")

@@ -244,10 +292,10 @@ class FlashCausalLMBatch(Batch):
         if len(request_ids) == len(self):
             return self

-        single_request = len(request_ids) == 1
+        device = self.input_ids.device

         # Cumulative length
-        cumulative_length = 0
+        cumulative_max_length = 0

         # New values after filtering
         requests_idx_mapping = {}

@@ -255,11 +303,17 @@ class FlashCausalLMBatch(Batch):
         # Used to index into tensors
         indices = []

+        # past indices to keep
+        past_indices = torch.zeros(
+            self.past_key_values.shape[0], dtype=torch.bool, device=device
+        )
+
         # Create on CPU to only move to GPU once instead of at every copy
-        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
-        cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
+        start_seq = torch.empty(len(request_ids), dtype=torch.int32)
+        end_seq = torch.empty(len(request_ids), dtype=torch.int32)
+        start_seq_q = self.start_seq_q[: len(request_ids)]
+        end_seq_q = self.end_seq_q[: len(request_ids)]
         max_seqlen = 0
-        past_key_values = []

         requests = []
         all_input_ids = []

@@ -270,8 +324,6 @@ class FlashCausalLMBatch(Batch):
         stopping_criterias = []

-        max_tokens = 0
-
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)

@@ -281,16 +333,8 @@ class FlashCausalLMBatch(Batch):
             # Get length
             request_input_length = self.input_lengths[idx]

-            # Copy to tensor (CPU)
-            cu_seqlens[i + 1] = cumulative_length + request_input_length
             max_seqlen = max(max_seqlen, request_input_length)

-            # Slice from past
-            past_key_values.append(
-                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
-            )
-
             all_input_ids.append(self.all_input_ids[idx])

             input_lengths.append(request_input_length)

@@ -300,39 +344,32 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)

-            cumulative_length += request_input_length
-            max_tokens += request_input_length + (
+            remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )

-        if single_request:
-            # Preallocate tensor for bs = 1 case
-            past_key_values = F.pad(
-                past_key_values[0],
-                (0, 0, 0, 0, 0, 0, 0,
-                 stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens),
-            )
-        else:
-            # Cat all past
-            past_key_values = torch.cat(past_key_values, dim=1)
+            # Copy to tensor (CPU)
+            start_seq[i] = cumulative_max_length
+            end_seq[i] = cumulative_max_length + request_input_length
+
+            # Set slice
+            past_indices[
+                self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
+            ] = True
+
+            cumulative_max_length += request_input_length + remaining_tokens - 1

         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
+        past_key_values = self.past_key_values[past_indices]

         # Move to GPU now that we have the whole tensor
-        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
+        start_seq = start_seq.to(device)
+        end_seq = end_seq.to(device)
+        past_present_indices = end_seq - 1

         return FlashCausalLMBatch(
             batch_id=self.batch_id,

@@ -340,8 +377,13 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            cu_seqlens_q=cu_seqlens_q,
+            past_present_indices=past_present_indices,
+            start_seq=start_seq,
+            end_seq=end_seq,
+            start_seq_prefill=None,
+            end_seq_prefill=None,
+            start_seq_q=start_seq_q,
+            end_seq_q=end_seq_q,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,

@@ -354,7 +396,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
-            max_tokens=max_tokens,
+            max_tokens=cumulative_max_length,
         )

     @classmethod

@@ -371,10 +413,12 @@ class FlashCausalLMBatch(Batch):
         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        cu_seqlens = [0]
-        cu_seqlens_q = torch.arange(
-            0, total_batch_size + 1, device=device, dtype=torch.int32
-        )
+        start_seq = batches[0].start_seq.new_empty(total_batch_size)
+        end_seq = batches[0].end_seq.new_empty(total_batch_size)
+        start_seq_q = torch.arange(
+            0, total_batch_size, device=device, dtype=torch.int32
+        )
+        end_seq_q = start_seq_q + 1
         max_seqlen = 0
         past_key_values = []

@@ -389,7 +433,6 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_batch_size = 0
-        cumulative_length = 0
         max_tokens = 0
         max_length = 0

@@ -410,18 +453,10 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids

-            # Add cumulative lengths of all previous inputs
-            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
-            max_seqlen = max(max_seqlen, batch.max_seqlen)
+            start_seq[start_index:end_index] = batch.start_seq + max_tokens
+            end_seq[start_index:end_index] = batch.end_seq + max_tokens

-            if len(batch) != 1:
-                past_key_values.append(batch.past_key_values)
-            else:
-                # past was pre-allocated for this batch
-                # We need to slice to remove the padding
-                past_key_values.append(
-                    batch.past_key_values[:, : batch.input_lengths[0]]
-                )
+            max_seqlen = max(max_seqlen, batch.max_seqlen)

             all_input_ids.extend(batch.all_input_ids)

@@ -431,9 +466,9 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
+            past_key_values.append(batch.past_key_values)

             # Update
-            cumulative_length += batch.cu_seqlens[-1]
             cumulative_batch_size += len(batch)
             max_tokens += batch.max_tokens
             max_length = max(

@@ -448,6 +483,9 @@ class FlashCausalLMBatch(Batch):
             ),
         )

+        past_key_values = torch.cat(past_key_values, dim=0)
+        past_present_indices = end_seq - 1
+
         all_input_ids_tensor = torch.zeros(
             (total_batch_size, max_length), dtype=torch.int64, device=device
         )

@@ -463,11 +501,6 @@ class FlashCausalLMBatch(Batch):
             cumulative_batch_size += len(batch)

-        # Cat past
-        past_key_values = torch.cat(past_key_values, dim=1)
-        # Create final tensor on GPU
-        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
-
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype=dtype, device=device
         )

@@ -478,8 +511,13 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            cu_seqlens_q=cu_seqlens_q,
+            past_present_indices=past_present_indices,
+            start_seq=start_seq,
+            end_seq=end_seq,
+            start_seq_prefill=None,
+            end_seq_prefill=None,
+            start_seq_q=start_seq_q,
+            end_seq_q=end_seq_q,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,

@@ -550,9 +588,12 @@ class FlashCausalLM(Model):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        cu_seqlens_q: Optional[torch.Tensor],
+        start_seq: torch.Tensor,
+        end_seq: torch.Tensor,
+        start_seq_q: Optional[torch.Tensor],
+        end_seq_q: Optional[torch.Tensor],
         max_s: int,
+        past_present_indices: torch.Tensor,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,

@@ -561,9 +602,12 @@ class FlashCausalLM(Model):
         return self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            cu_seqlens_q=cu_seqlens_q,
+            start_seq=start_seq,
+            end_seq=end_seq,
+            start_seq_q=start_seq_q,
+            end_seq_q=end_seq_q,
             max_s=max_s,
+            past_present_indices=past_present_indices,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
             lm_head_indices=lm_head_indices,

@@ -575,23 +619,27 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         prefill_logprobs = batch.prefill_next_token_indices is not None
-        single_request = len(batch) == 1

-        if prefill and single_request:
+        if prefill:
             # Ask to pre-allocate kv to its max size
-            # == number of tokens + max_new_tokens
-            pre_allocate_past_size = (
-                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
-            )
+            # == Sum over batch size (number of tokens + max_new_tokens) - batch size
+            pre_allocate_past_size = batch.max_tokens
+            start_seq = batch.start_seq_prefill
+            end_seq = batch.end_seq_prefill
         else:
             pre_allocate_past_size = None
+            start_seq = batch.start_seq
+            end_seq = batch.end_seq

         out, present = self.forward(
             batch.input_ids,
             batch.position_ids,
-            batch.cu_seqlens,
-            batch.cu_seqlens_q,
+            start_seq,
+            end_seq,
+            batch.start_seq_q,
+            batch.end_seq_q,
             batch.max_seqlen,
+            batch.past_present_indices,
             batch.past_key_values,
             pre_allocate_past_size,
             batch.prefill_head_indices,

@@ -614,55 +662,19 @@ class FlashCausalLM(Model):
             # When batch == 1, we will just use the batch.input_ids values directly
             prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            # Create batch.cu_seqlens_q for decode
-            batch.cu_seqlens_q = torch.arange(
-                0, len(batch) + 1, device=self.device, dtype=torch.int32
-            )
+            # Create batch.start_seq_q and batch.end_seq_q for decode
+            batch.start_seq_q = torch.arange(
+                0, len(batch), device=self.device, dtype=torch.int32
+            )
+            batch.end_seq_q = batch.start_seq_q + 1
             next_position_ids = batch.position_ids.new_empty(len(batch))
+            # We do not need start_seq_prefill and end_seq_prefill anymore
+            batch.start_seq_prefill = None
+            batch.end_seq_prefill = None
         else:
             prefill_logprobs = None
             next_position_ids = batch.position_ids

-        # Prepare past for next decode
-        if len(batch) > 1:
-            # Used to slice next batch past
-            past_indices = torch.empty(
-                present.shape[1], dtype=torch.int64, device=self.device
-            )
-            batch.past_key_values = present.new_empty(
-                (
-                    present.shape[0],
-                    present.shape[1] + len(batch.requests),
-                    *present.shape[2:],
-                )
-            )
-
-            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
-            # and will run asynchronously while we do the next for loop
-            cumulative_length = 0
-            for i, input_length in enumerate(batch.input_lengths):
-                # Indexing metadata
-                start_index = cumulative_length
-                end_index = cumulative_length + input_length
-
-                # Indices to copy present at the correct place in past_key_values
-                torch.arange(
-                    start_index + i,
-                    end_index + i,
-                    dtype=torch.int64,
-                    device=self.device,
-                    out=past_indices[start_index:end_index],
-                )
-                cumulative_length += input_length
-
-            # Copy from present to past_key_values
-            batch.past_key_values[:, past_indices] = present
-        # Initialize past_key_values in prefill for len(batch) == 1
-        elif prefill:
-            # present is already pre-padded
-            batch.past_key_values = present
-
         # Cumulative length
         cumulative_length = 0

@@ -685,6 +697,7 @@ class FlashCausalLM(Model):
             input_length,
             all_input_ids,
         ) in enumerate(iterator):
+            # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length

@@ -718,7 +731,8 @@ class FlashCausalLM(Model):
         # Set values in batch
         batch.input_ids = next_input_ids
         batch.position_ids = next_position_ids + 1
-        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
+        batch.past_present_indices = batch.end_seq
+        batch.end_seq = batch.end_seq + 1

         if prefill and prefill_logprobs:
             # Get prefill logprobs

@@ -843,6 +857,7 @@ class FlashCausalLM(Model):
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
         batch.max_seqlen = batch.max_seqlen + 1
+        batch.past_key_values = present

         # No need to return a batch if we know that all requests stopped
         return generations, batch if not stopped else None
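On the batching side, max_tokens now carries the cumulative pre-allocated length (cumulative_max_length), which the prefill path later passes to the model as pre_allocate_past_size. A small worked example of that bookkeeping with two invented requests (a sketch of the arithmetic in from_pb, not the full code path):

import torch

# Two hypothetical requests as (input_length, max_new_tokens).
requests = [(4, 10), (6, 2)]

cumulative_max_length = 0
start_seq, end_seq, past_present_indices = [], [], []

for input_length, max_new_tokens in requests:
    start_seq.append(cumulative_max_length)
    end_seq.append(cumulative_max_length + input_length)
    # The prompt tokens of this request map to these slots of the padded past.
    past_present_indices.append(
        torch.arange(cumulative_max_length, cumulative_max_length + input_length)
    )
    # Reserve input_length + max_new_tokens - 1 slots; per the diff's own comment,
    # one slot is removed because the first token does not have a past.
    cumulative_max_length += input_length + max_new_tokens - 1

pre_allocate_past_size = cumulative_max_length   # 13 + 7 = 20 slots in total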