OpenDAS / text-generation-inference · Commit db4cb5e4
"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "2e544bd77afe019c4bb9d8c6882879c48d3ac65f"
Unverified commit db4cb5e4, authored Apr 21, 2023 by OlivierDehaene, committed by GitHub on Apr 21, 2023
fix(server): fix past key values logic (#216)

@njhill fyi

parent 343437c7
Showing 7 changed files with 123 additions and 20 deletions.
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py  +24 -5
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py  +23 -4
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py  +24 -5
server/text_generation_server/models/flash_causal_lm.py  +47 -6
server/text_generation_server/models/flash_llama.py  +2 -0
server/text_generation_server/models/flash_neox.py  +1 -0
server/text_generation_server/models/flash_santacoder.py  +2 -0
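The overall idea of the fix, as a rough sketch: when a single request is prefilled, the server asks the model to allocate the key/value cache at its final size (prompt tokens plus max_new_tokens) up front, and later steps write into slices of that buffer instead of re-building it. The snippet below only illustrates that allocate-then-slice pattern; it is not the text-generation-inference API, and num_layers, num_heads, head_size, n_prompt_tokens and max_new_tokens are placeholder values.

    import torch

    # Illustrative sketch (not the TGI API): pre-allocate the KV cache once,
    # then let the prefill pass fill only the first n_prompt_tokens rows.
    num_layers, num_heads, head_size = 32, 32, 128   # placeholder model dims
    n_prompt_tokens, max_new_tokens = 10, 20         # placeholder request sizes

    pre_allocate_past_size = n_prompt_tokens + max_new_tokens
    past_key_values = torch.zeros(
        (num_layers, pre_allocate_past_size, 2, num_heads, head_size)
    )

    # During prefill only the first n_prompt_tokens positions are valid;
    # slicing returns a view, so writes land in the pre-allocated buffer.
    prefill_view = past_key_values[:, :n_prompt_tokens]
    prefill_view.fill_(1.0)  # stand-in for the attention kernel writing K/V

    assert past_key_values[:, :n_prompt_tokens].eq(1.0).all()
    assert past_key_values[:, n_prompt_tokens:].eq(0.0).all()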
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -25,6 +25,7 @@ from torch.nn import functional as F
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional

 # Flash attention imports
 import rotary_emb

@@ -554,7 +555,8 @@ class FlashLlamaModel(torch.nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)

@@ -564,7 +566,9 @@ class FlashLlamaModel(torch.nn.Module):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.layers),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,

@@ -572,6 +576,7 @@ class FlashLlamaModel(torch.nn.Module):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths

@@ -579,6 +584,7 @@ class FlashLlamaModel(torch.nn.Module):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer

@@ -588,6 +594,13 @@ class FlashLlamaModel(torch.nn.Module):
         residual = None
         for i, layer in enumerate(self.layers):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,

@@ -595,7 +608,7 @@ class FlashLlamaModel(torch.nn.Module):
                 sin,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )

@@ -638,10 +651,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.model(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
         )
         logits = self.lm_head(hidden_states)
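The slicing added in the layer loop above works because integer and slice indexing on a torch tensor return views of the same storage: during prefill (slice_past_index == len(hidden_states)) each layer sees only the valid prefix of the padded cache, yet anything the attention kernel writes through that view still lands in the pre-allocated buffer. A minimal sketch of that behaviour, with placeholder dimensions that are not the model's real sizes:

    import torch

    # (layers, padded_len, 2, heads, head_size) with placeholder dims
    past_key_values = torch.zeros((4, 30, 2, 8, 64))

    slice_past_index = 10  # prefill: only the first 10 positions hold real tokens
    for i in range(past_key_values.shape[0]):
        layer_past_key_values = (
            past_key_values[i]
            if slice_past_index is None
            else past_key_values[i, :slice_past_index]
        )
        # The slice is a view: writing through it updates the padded buffer.
        layer_past_key_values[...] = 1.0

    assert past_key_values[:, :10].eq(1.0).all()
    assert past_key_values[:, 10:].eq(0.0).all()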
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

@@ -27,6 +27,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
+from typing import Optional

 # Flash attention imports
 import rotary_emb

@@ -618,6 +619,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         cu_seqlens,
         max_s,
         past_key_values=None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_in(input_ids)

@@ -627,7 +629,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.layers),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,

@@ -635,6 +639,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths

@@ -642,6 +647,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer

@@ -651,6 +657,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         residual = None
         for i, layer in enumerate(self.layers):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,

@@ -658,7 +671,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 sin,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )

@@ -714,10 +727,16 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.gpt_neox(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
         )
         logits = self.embed_out(hidden_states)
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional

 # Flash attention imports
 import flash_attn_cuda

@@ -484,7 +485,8 @@ class FlashSantacoderModel(nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)

         if self.tp_embeddings:

@@ -496,7 +498,9 @@ class FlashSantacoderModel(nn.Module):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.h),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     1,
                     self.head_size,

@@ -504,6 +508,7 @@ class FlashSantacoderModel(nn.Module):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths

@@ -511,15 +516,23 @@ class FlashSantacoderModel(nn.Module):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None

         residual = None
         for i, layer in enumerate(self.h):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )

@@ -554,10 +567,16 @@ class FlashSantacoderForCausalLM(nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.transformer(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
         )
         logits = self.lm_head(hidden_states)
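One detail worth noting in the SantaCoder variant above: the cache shape uses 1 where the Llama and NeoX versions use self.num_heads, since the model keeps a single shared key/value head (multi-query attention), so the pre-allocated buffer is correspondingly smaller per token. A rough sketch of what that means for memory, with placeholder dimensions that are not taken from any real checkpoint:

    import torch

    # Placeholder dims for illustration only.
    layers, padded_len, head_size, num_heads = 24, 2048, 128, 16

    mha_cache = torch.empty((layers, padded_len, 2, num_heads, head_size))  # Llama/NeoX-style layout
    mqa_cache = torch.empty((layers, padded_len, 2, 1, head_size))          # SantaCoder-style layout

    print(mha_cache.numel() // mqa_cache.numel())  # 16: one shared KV head instead of num_heads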
server/text_generation_server/models/flash_causal_lm.py

@@ -142,6 +142,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            past_pad=None,
         )

     @tracer.start_as_current_span("filter")

@@ -188,8 +189,10 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
             if not single_request:
+                # True index for past
                 past_key_values.append(self.past_key_values[2 * idx])
-                past_key_values.append(self.past_key_values[1])
+                # Add one padding
+                past_key_values.append(self.past_pad)

             all_input_ids.append(self.all_input_ids[idx])
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])

@@ -207,7 +210,17 @@ class FlashCausalLMBatch(Batch):
             # Preallocate tensor for bs = 1 case
             past_key_values = torch.nn.functional.pad(
                 self.past_key_values[0],
-                (0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
+                (
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    stopping_criterias[0].max_new_tokens
+                    - stopping_criterias[0].current_tokens,
+                ),
             )

         return FlashCausalLMBatch(

@@ -270,10 +283,16 @@ class FlashCausalLMBatch(Batch):
             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)
             if len(batch) != 1:
                 past_key_values.extend(batch.past_key_values)
             else:
-                past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
+                # past was pre-allocated for this batch
+                # We need to slice to remove the padding
+                past_key_values.append(
+                    batch.past_key_values[:, : batch.input_lengths[0]]
+                )
+                # Add one padding
                 past_key_values.append(batch.past_pad)

             all_input_ids.extend(batch.all_input_ids)

@@ -366,6 +385,7 @@ class FlashCausalLM(Model):
         cu_seqlens: torch.Tensor,
         max_s: int,
         past_key_values: Optional = None,
+        pre_allocate_past_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(

@@ -374,6 +394,7 @@ class FlashCausalLM(Model):
             cu_seqlens=cu_seqlens,
             max_s=max_s,
             past_key_values=past_key_values,
+            pre_allocate_past_size=pre_allocate_past_size,
         )

     @tracer.start_as_current_span("generate_token")

@@ -382,7 +403,9 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         # Shortcut when batch_size == 1
         if len(batch) == 1:
+            # No need to slice this down
             input_ids = batch.input_ids[0].view(-1)
-            # Slice to remove extra padding
+            # past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
             past_key_values = batch.past_key_values
         else:
             # Concatenate tensors

@@ -393,6 +416,16 @@ class FlashCausalLM(Model):
             else None
         )

+        # if prefill and bs == 1
+        if past_key_values is None and len(batch) == 1:
+            # Ask to pre-allocate kv to its max size
+            # == number of tokens + max_new_tokens
+            pre_allocate_past_size = (
+                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
+            )
+        else:
+            pre_allocate_past_size = None
+
         # Concatenate when prefill, torch.tensor when decode
         position_ids = (
             torch.tensor(batch.position_ids, device=self.device)

@@ -409,21 +442,28 @@ class FlashCausalLM(Model):
             cu_seqlens,
             batch.max_seqlen,
             past_key_values,
+            pre_allocate_past_size,
         )

         # Initialize past_key_values in prefill
         if batch.past_key_values is None:
             # Initialize past padding tensor
             if self.past_pad is None:
-                self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
+                self.past_pad = present.new_zeros(
+                    present.shape[0], 1, *present.shape[2:]
+                )
             # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
             if len(batch) == 1:
                 # Preallocate tensor for bs = 1 case
                 batch.past_key_values = torch.nn.functional.pad(
-                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
+                    present,
+                    (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
                 )
             else:
+                # Add padding after each sequence
+                # This will have the correct shape after the final past_key_values concatenation before the model
+                # forward
                 batch.past_key_values = [None, self.past_pad] * len(batch)

         # Cumulative length

@@ -555,6 +595,7 @@ class FlashCausalLM(Model):
                 batch.all_input_ids_tensor[i] = all_input_ids_tensor
                 batch.max_seqlen = max(batch.max_seqlen, new_input_length)
                 if len(batch) != 1:
+                    # Add each sequence before its padding
                     batch.past_key_values[i * 2] = present[:, start_index:end_index]
                 # Cumulative sum
                 batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
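The single-request paths above grow the cache with torch.nn.functional.pad. Since pad sizes are read from the last dimension backwards in (left, right) pairs, a tuple of seven zeros followed by max_new_tokens extends dimension 1, the token axis of a (layers, tokens, 2, heads, head_size) tensor, on the right, and the padded region defaults to zeros. A small check of that behaviour, using placeholder dimensions rather than any real model's sizes:

    import torch
    import torch.nn.functional as F

    # Placeholder "present" tensor: (layers, tokens, 2, heads, head_size)
    present = torch.zeros((4, 10, 2, 8, 64))
    max_new_tokens = 20

    # Pad pairs apply from the last dim backwards, so the final value
    # pads dim 1 (the token axis) on the right by max_new_tokens.
    padded = F.pad(present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens))
    print(padded.shape)  # torch.Size([4, 30, 2, 8, 64])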
server/text_generation_server/models/flash_llama.py

@@ -29,6 +29,7 @@ tracer = trace.get_tracer(__name__)
 class FlashLlama(FlashCausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

@@ -146,6 +147,7 @@ class FlashLlamaSharded(FlashLlama):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
server/text_generation_server/models/flash_neox.py

@@ -33,6 +33,7 @@ class FlashNeoXSharded(FlashNeoX):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
server/text_generation_server/models/flash_santacoder.py

@@ -28,6 +28,7 @@ tracer = trace.get_tracer(__name__)
 class FlashSantacoder(FlashCausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

@@ -172,6 +173,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():