OpenDAS / text-generation-inference / Commits / ad66f6ef

Unverified commit ad66f6ef, authored May 09, 2023 by OlivierDehaene, committed via GitHub on May 09, 2023.

feat(server): optim flash causal lm decode_token (#285)

parent bc5c0723
Showing 7 changed files with 258 additions and 140 deletions (+258 / -140).
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py        +3   -4
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py         +3   -4
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py   +3   -4
server/text_generation_server/models/flash_causal_lm.py                             +244 -123
server/text_generation_server/models/flash_llama.py                                 +2   -2
server/text_generation_server/models/flash_neox.py                                  +1   -1
server/text_generation_server/models/flash_santacoder.py                            +2   -2
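The whole optimization revolves around the two cumulative-length tensors consumed by the flash-attention kernels: cu_seqlens marks where each request's tokens start and end in the flattened batch, and cu_seqlens_q does the same for the query tokens of the current step. The sketch below is not taken from the diff; it is a minimal, CPU-only illustration of what these tensors contain for an invented batch of three prompts, and of why adding cu_seqlens_q to cu_seqlens advances every sequence by one token after a decode step.

import torch

# Toy batch of three prompts with 5, 3 and 7 tokens.
input_lengths = [5, 3, 7]

# cu_seqlens: cumulative sequence lengths, one entry per boundary.
# The i-th sequence occupies the flattened token dimension at
# [cu_seqlens[i], cu_seqlens[i + 1]).
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)

# During decode every request contributes exactly one query token, so the
# cumulative query lengths are simply 0, 1, 2, ..., batch_size.
batch_size = len(input_lengths)
cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32)

# After a decode step each sequence grew by one token, which is exactly
# what adding cu_seqlens_q to cu_seqlens expresses:
cu_seqlens = cu_seqlens + cu_seqlens_q
print(cu_seqlens)  # tensor([ 0,  6, 10, 18], dtype=torch.int32)

This commit threads cu_seqlens_q through the model forwards and keeps it on the batch object, so it is built once per batch instead of on every forward pass.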
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -554,6 +554,7 @@ class FlashLlamaModel(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -575,15 +576,11 @@ class FlashLlamaModel(torch.nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
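Before this change, every decode forward rebuilt cu_seqlens_q with a fresh torch.arange inside the model; now the caller builds it once per batch and passes it in. The snippet below is an illustrative sketch of that hoisting pattern, not the model's actual API (decode_attention_metadata and cached_q are invented names):

import torch
from typing import Optional

def decode_attention_metadata(cu_seqlens: torch.Tensor,
                              cu_seqlens_q: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Old behaviour: rebuild the query offsets on every forward.
    if cu_seqlens_q is None:
        cu_seqlens_q = torch.arange(cu_seqlens.shape[0], dtype=torch.int32)
    return cu_seqlens_q

cu_seqlens = torch.tensor([0, 6, 10, 18], dtype=torch.int32)

# New behaviour: build it once per batch and reuse it across layers and decode steps.
cached_q = torch.arange(cu_seqlens.shape[0], dtype=torch.int32)
for _ in range(3):  # three decode steps
    decode_attention_metadata(cu_seqlens, cached_q)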
@@ -650,6 +647,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -658,6 +656,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -617,6 +617,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
@@ -638,15 +639,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
@@ -726,6 +723,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -734,6 +732,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -484,6 +484,7 @@ class FlashSantacoderModel(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -507,15 +508,11 @@ class FlashSantacoderModel(nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         residual = None
@@ -566,6 +563,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -574,6 +572,7 @@ class FlashSantacoderForCausalLM(nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
server/text_generation_server/models/flash_causal_lm.py
 import torch
 import torch.distributed
+import numpy as np

 from torch.nn import functional as F

 from dataclasses import dataclass
@@ -33,12 +35,16 @@ class FlashCausalLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]

     # Decoder values
-    input_ids: List[torch.Tensor]
-    position_ids: List[torch.Tensor]
+    input_ids: torch.Tensor
+    position_ids: torch.Tensor
     # cumulative sequence lengths
-    cu_seqlens: List[int]
+    cu_seqlens: torch.Tensor
+    # cumulative query sequence lengths, only used in decode
+    cu_seqlens_q: Optional[torch.Tensor]
+    # past key values, only used in decode
+    past_key_values: Optional[torch.Tensor]
     max_seqlen: int
-    past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]

     # All tokens
     all_input_ids: List[List[int]]
@@ -53,9 +59,6 @@ class FlashCausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]

     # Constant shared tensor, ref here just so that it's accessible in concatentate()
     past_pad: Optional[torch.Tensor]

     # Maximum number of tokens this batch will grow to
     max_tokens: int
@@ -74,7 +77,6 @@ class FlashCausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        input_ids = []
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -83,7 +85,6 @@ class FlashCausalLMBatch(Batch):
         offsets = []
         token_offsets = []
         all_input_ids = []
-        all_input_ids_tensor = []
         requests_idx_mapping = {}

         next_token_choosers = []
@@ -109,15 +110,11 @@ class FlashCausalLMBatch(Batch):
             offsets.append(None)
             token_offsets.append(None)
-            all_input_ids.append(tokenized_input)
-
-            tokenized_input = torch.tensor(tokenized_input, device=device)
-            input_ids.append(tokenized_input)
+
+            all_input_ids.append(tokenized_input)

             # Position ids
-            position_ids.append(
-                torch.arange(0, input_length, dtype=torch.int32, device=device)
-            )
+            position_ids.append(np.arange(0, input_length))

             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
@@ -130,14 +127,19 @@ class FlashCausalLMBatch(Batch):
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
-            all_input_ids_tensor.append(
-                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
-            )

             # Update
             cumulative_length += input_length
             max_tokens += input_length + max_new_tokens

+        # Create tensors on device
+        input_ids = torch.tensor(
+            np.concatenate(all_input_ids), dtype=torch.int64, device=device
+        )
+        position_ids = torch.tensor(
+            np.concatenate(position_ids), dtype=torch.int32, device=device
+        )
+        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -145,16 +147,16 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=None,
             max_seqlen=max_seqlen,
             past_key_values=None,
             input_lengths=input_lengths,
             offsets=offsets,
             token_offsets=token_offsets,
             all_input_ids=all_input_ids,
-            all_input_ids_tensor=all_input_ids_tensor,
+            all_input_ids_tensor=[],
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             past_pad=None,
             max_tokens=max_tokens,
         )
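from_pb now accumulates plain Python lists and NumPy arrays and materializes each device tensor once, instead of creating one small GPU tensor per request. A hedged, CPU-only sketch of the pattern (the token values are made up, and device is pinned to CPU so the snippet runs anywhere):

import numpy as np
import torch

device = torch.device("cpu")  # assumption so the sketch runs without a GPU

tokenized_inputs = [[101, 7592, 102], [101, 2088, 999, 102]]

# Accumulate on the host ...
all_input_ids = []
position_ids = []
cu_seqlens = [0]
cumulative_length = 0
for tokens in tokenized_inputs:
    all_input_ids.append(tokens)
    position_ids.append(np.arange(0, len(tokens)))
    cumulative_length += len(tokens)
    cu_seqlens.append(cumulative_length)

# ... then move to the target device with a handful of large copies.
input_ids = torch.tensor(np.concatenate(all_input_ids), dtype=torch.int64, device=device)
position_ids_t = torch.tensor(np.concatenate(position_ids), dtype=torch.int32, device=device)
cu_seqlens_t = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

print(input_ids.shape, cu_seqlens_t)  # torch.Size([7]) tensor([0, 3, 7], dtype=torch.int32)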
@@ -174,9 +176,13 @@ class FlashCausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}
-        input_ids = []
-        position_ids = []
-        cu_seqlens = [0]
+        input_ids = self.input_ids.new_empty(len(requests))
+        position_ids = self.position_ids.new_empty(len(requests))
+        # Create on CPU to only move to GPU once instead of at every copy
+        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
+        cu_seqlens_q = torch.arange(
+            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
+        )
         max_seqlen = 0
         past_key_values = []
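filter() now pre-allocates the per-request tensors with Tensor.new_empty, which inherits dtype and device from the source tensor, and assembles cu_seqlens on the CPU so it can be shipped to the GPU in a single copy at the end. A small CPU-only sketch of those two ideas with invented values:

import torch

# Stand-ins for the existing batch tensors.
old_input_ids = torch.tensor([11, 22, 33, 44], dtype=torch.int64)
kept = [0, 2, 3]  # indices of the requests that survive filtering

# new_empty keeps dtype/device of old_input_ids, only the shape changes.
input_ids = old_input_ids.new_empty(len(kept))
for i, idx in enumerate(kept):
    input_ids[i] = old_input_ids[idx]

# cu_seqlens is assembled on CPU and would be shipped to the GPU in one go.
lengths = [5, 7, 2]  # lengths of the kept requests
cu_seqlens = torch.zeros(len(kept) + 1, dtype=torch.int32)
cumulative = 0
for i, length in enumerate(lengths):
    cumulative += length
    cu_seqlens[i + 1] = cumulative

# cu_seqlens = cu_seqlens.to(device)  # single host-to-device copy in the real code
print(input_ids.tolist(), cu_seqlens.tolist())  # [11, 33, 44] [0, 5, 12, 14]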
@@ -199,16 +205,18 @@ class FlashCausalLMBatch(Batch):
             # Get length
             request_input_length = self.input_lengths[idx]

-            input_ids.append(self.input_ids[idx])
-            position_ids.append(self.position_ids[idx])
-            cu_seqlens.append(cumulative_length + request_input_length)
+            # Copy tensors (GPU)
+            input_ids[i] = self.input_ids[idx]
+            position_ids[i] = self.position_ids[idx]
+            # Copy to tensor (CPU)
+            cu_seqlens[i + 1] = cumulative_length + request_input_length
             max_seqlen = max(max_seqlen, request_input_length)

-            # True index for past
-            past_key_values.append(self.past_key_values[2 * idx])
-            if not single_request:
-                # Add one padding
-                past_key_values.append(self.past_pad)
+            # Slice from past
+            past_key_values.append(
+                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
+            )

             all_input_ids.append(self.all_input_ids[idx])
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@@ -229,7 +237,7 @@ class FlashCausalLMBatch(Batch):
         if single_request:
             # Preallocate tensor for bs = 1 case
-            past_key_values = torch.nn.functional.pad(
+            past_key_values = F.pad(
                 past_key_values[0],
                 (
                     0,
@@ -243,15 +251,21 @@ class FlashCausalLMBatch(Batch):
                     - stopping_criterias[0].current_tokens,
                 ),
             )
         else:
             # Cat all past
             past_key_values = torch.cat(past_key_values, dim=1)

+        # Move to GPU now that we have the whole tensor
+        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

         return FlashCausalLMBatch(
             batch_id=self.batch_id,
             past_pad=self.past_pad,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
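The single-request branch above grows the past with F.pad instead of the fully qualified torch.nn.functional.pad. As a reminder of the padding spec, F.pad takes (left, right) pairs that apply from the last dimension backwards; the toy values below are only for illustration:

import torch
from torch.nn import functional as F

t = torch.tensor([1, 2, 3])
# (0, 4): no padding on the left, four zero slots appended on the right.
print(F.pad(t, (0, 4)))  # tensor([1, 2, 3, 0, 0, 0, 0])

# For higher-rank tensors the pairs apply from the last dimension backwards,
# so a single (0, 2) pair pads only the final axis.
kv = torch.zeros(2, 3, 4)
print(F.pad(kv, (0, 2)).shape)  # torch.Size([2, 3, 6])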
@@ -271,9 +285,16 @@ class FlashCausalLMBatch(Batch):
         requests = []
         requests_idx_mapping = {}

-        input_ids = []
-        position_ids = []
+        total_batch_size = sum([len(b) for b in batches])

+        device = batches[0].input_ids.device

+        input_ids = batches[0].input_ids.new_empty(total_batch_size)
+        position_ids = batches[0].position_ids.new_empty(total_batch_size)
         cu_seqlens = [0]
+        cu_seqlens_q = torch.arange(
+            0, total_batch_size + 1, device=device, dtype=torch.int32
+        )
         max_seqlen = 0
         past_key_values = []
@@ -302,22 +323,25 @@ class FlashCausalLMBatch(Batch):
             for k, v in batch.requests_idx_mapping.items():
                 requests_idx_mapping[k] = v + cumulative_batch_size

-            input_ids.extend(batch.input_ids)
-            position_ids.extend(batch.position_ids)
+            start_index = cumulative_batch_size
+            end_index = cumulative_batch_size + len(batch)

+            # Copy tensors (GPU)
+            input_ids[start_index:end_index] = batch.input_ids
+            position_ids[start_index:end_index] = batch.position_ids

             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)

             if len(batch) != 1:
-                past_key_values.extend(batch.past_key_values)
+                past_key_values.append(batch.past_key_values)
             else:
                 # past was pre-allocated for this batch
                 # We need to slice to remove the padding
                 past_key_values.append(
                     batch.past_key_values[:, : batch.input_lengths[0]]
                 )
                 # Add one padding
                 past_key_values.append(batch.past_pad)

             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
@@ -334,14 +358,19 @@ class FlashCausalLMBatch(Batch):
             cumulative_batch_size += len(batch)
             max_tokens += batch.max_tokens

+        # Cat past
+        past_key_values = torch.cat(past_key_values, dim=1)
+        # Create final tensor on GPU
+        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)

         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
             past_pad=batches[0].past_pad,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
@@ -367,10 +396,9 @@ class FlashCausalLM(Model):
         quantize: bool = False,
         decode_buffer: int = 3,
     ):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashCausalLM is only available on GPU")
@@ -410,6 +438,7 @@ class FlashCausalLM(Model):
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor],
         max_s: int,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -419,6 +448,7 @@ class FlashCausalLM(Model):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
             max_s=max_s,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
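The generate_token hunks that follow lean on two ideas: prefill is detected simply by the absence of a past, and for a single prompt the KV cache is pre-allocated to its final size (prompt length plus max_new_tokens). The sketch below is an invented, self-contained rendering of that control flow, not the class's real interface (ToyBatch and plan_forward are assumptions):

import torch
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ToyBatch:
    input_ids: torch.Tensor            # flattened prompt tokens (prefill) or one token per request (decode)
    past_key_values: Optional[torch.Tensor]
    input_lengths: List[int]
    max_new_tokens: List[int]

def plan_forward(batch: ToyBatch):
    prefill = batch.past_key_values is None
    if prefill and len(batch.input_lengths) == 1:
        # Ask the model to pre-allocate the KV cache to its final size.
        pre_allocate_past_size = batch.input_lengths[0] + batch.max_new_tokens[0]
    else:
        pre_allocate_past_size = None
    # The batch tensors are passed straight through: no torch.cat, no torch.tensor per step.
    return prefill, pre_allocate_past_size

batch = ToyBatch(torch.tensor([101, 7592, 102]), None, [3], [16])
print(plan_forward(batch))  # (True, 19)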
@@ -428,22 +458,9 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-        # Shortcut when batch_size == 1
-        if len(batch) == 1:
-            input_ids = batch.input_ids[0].view(-1)
-            # No need to slice as flash attention will take care of it with cu_seqlens
-            past_key_values = batch.past_key_values
-        else:
-            # Concatenate tensors
-            input_ids = torch.cat(batch.input_ids).view(-1)
-            past_key_values = (
-                torch.cat(batch.past_key_values, dim=1)
-                if batch.past_key_values is not None
-                else None
-            )
+        prefill = batch.past_key_values is None

         # if prefill and bs == 1
-        if past_key_values is None and len(batch) == 1:
+        if prefill and len(batch) == 1:
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (
@@ -452,42 +469,74 @@ class FlashCausalLM(Model):
         else:
             pre_allocate_past_size = None

-        # Concatenate when prefill, torch.tensor when decode
-        position_ids = (
-            torch.tensor(batch.position_ids, device=self.device)
-            if batch.past_key_values is not None
-            else torch.cat(batch.position_ids)
-        )
-        cu_seqlens = torch.tensor(
-            batch.cu_seqlens, device=self.device, dtype=torch.int32
-        )
         out, present = self.forward(
-            input_ids,
-            position_ids,
-            cu_seqlens,
+            batch.input_ids,
+            batch.position_ids,
+            batch.cu_seqlens,
+            batch.cu_seqlens_q,
             batch.max_seqlen,
-            past_key_values,
+            batch.past_key_values,
             pre_allocate_past_size,
         )
-        # Initialize past_key_values in prefill
-        if batch.past_key_values is None:
-            # Initialize past padding tensor
-            if self.past_pad is None:
-                self.past_pad = present.new_zeros(
-                    present.shape[0], 1, *present.shape[2:]
-                )
+        if prefill:
+            if len(batch) > 1:
+                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
+                # When batch == 1, we will just use the batch.input_ids values directly
+                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
+
+            # Create batch.cu_seqlens_q for decode
+            batch.cu_seqlens_q = torch.arange(
+                0, len(batch) + 1, device=self.device, dtype=torch.int32
+            )
+            next_input_ids = batch.input_ids.new_empty(len(batch))
+            next_position_ids = batch.position_ids.new_empty(len(batch))
+        else:
+            prefill_logprobs = None
+            next_input_ids = batch.input_ids
+            next_position_ids = batch.position_ids
+
+        next_token_logprobs = out.new_empty(len(batch))
+
+        # Prepare past for next decode
+        if len(batch) > 1:
+            # Used to slice next batch past
+            past_indices = torch.empty(
+                present.shape[1], dtype=torch.int64, device=self.device
+            )
+            batch.past_key_values = present.new_empty(
+                (
+                    present.shape[0],
+                    present.shape[1] + len(batch.requests),
+                    *present.shape[2:],
+                )
+            )
-            # Set in batch in case it needs to be used later in concatenate()
-            batch.past_pad = self.past_pad
-            if len(batch) == 1:
-                # present is already pre-padded
-                batch.past_key_values = present
-            else:
-                # Add padding after each sequence
-                # This will have the correct shape after the final past_key_values concatenation before the model
-                # forward
-                batch.past_key_values = [None, self.past_pad] * len(batch)
+
+            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
+            # and will run asynchronously while we do the next for loop
+            cumulative_length = 0
+            for i, input_length in enumerate(batch.input_lengths):
+                # Indexing metadata
+                start_index = cumulative_length
+                end_index = cumulative_length + input_length
+
+                # Indices to copy present at the correct place in past_key_values
+                torch.arange(
+                    start_index + i,
+                    end_index + i,
+                    dtype=torch.int64,
+                    device=self.device,
+                    out=past_indices[start_index:end_index],
+                )
+                cumulative_length += input_length
+
+            # Copy from present to past_key_values
+            batch.past_key_values[:, past_indices] = present
+        # Initialize past_key_values in prefill for len(batch) == 1
+        elif prefill:
+            # present is already pre-padded
+            batch.past_key_values = present

         # Cumulative length
         cumulative_length = 0
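For multi-request batches, the new code reserves one padding slot per request inside a single pre-allocated past tensor and scatters present into it with an index tensor filled by torch.arange(..., out=...). The CPU-only sketch below illustrates only the indexing trick; the shapes are made up and much smaller than the real per-layer key/value tensors:

import torch

input_lengths = [3, 2]                      # two requests
present = torch.arange(1, 6).view(1, 5)     # [whatever, total_tokens] stand-in: tokens 1..5

# One extra slot per request so each sequence can grow by one token next step.
past = present.new_zeros(1, present.shape[1] + len(input_lengths))

# Build the destination indices: request i is shifted right by i padding slots.
past_indices = torch.empty(present.shape[1], dtype=torch.int64)
cumulative = 0
for i, length in enumerate(input_lengths):
    torch.arange(cumulative + i, cumulative + length + i,
                 dtype=torch.int64, out=past_indices[cumulative:cumulative + length])
    cumulative += length

past[:, past_indices] = present
print(past_indices.tolist())  # [0, 1, 2, 4, 5]
print(past.tolist())          # [[1, 2, 3, 0, 4, 5, 0]] -> a padding slot follows each request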
@@ -498,54 +547,134 @@ class FlashCausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.requests,
             batch.input_lengths,
             batch.offsets,
             batch.token_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
             batch.all_input_ids_tensor,
         )

+        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
+        # one, we need to first do a GPU <-> CPU sync
+        # It is faster if we delay this sync for the maximum amount of time

         # For each member of the batch
         for i, (
             request,
             input_length,
             offset,
             token_offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
             all_input_ids_tensor,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length

-            prefill = stopping_criteria.current_tokens == 0
             if prefill:
                 # Prefill mode
                 # out is of shape [cumulative_sequence_lengths, vocab_size]
-                logits = out[start_index:end_index]
+                # only take last token logit
+                logits = out[end_index - 1 : end_index]
+
+                # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
+                all_input_ids_tensor = batch.input_ids.new_empty(
+                    input_length + stopping_criteria.max_new_tokens
+                )
+                # Copy from batch.input_ids to all_input_ids_tensor
+                all_input_ids_tensor[:input_length] = batch.input_ids[
+                    start_index:end_index
+                ]
+                batch.all_input_ids_tensor.append(all_input_ids_tensor)
+
+                # Initialize position_ids
+                # In decode, we do not need this as we can just increment position ids
+                next_position_ids[i] = batch.position_ids[end_index - 1]
+
+                # Used to gather prefill logprobs
+                # Copy batch.input_ids to prefill_token_indices
+                if len(batch) > 1:
+                    prefill_tokens_indices[
+                        start_index : end_index - 1
+                    ] = batch.input_ids[start_index + 1 : end_index]
+                else:
+                    # Set prefill_tokens_indices to the correct slice
+                    prefill_tokens_indices = batch.input_ids[
+                        start_index + 1 : end_index
+                    ]
             else:
                 # Decode mode
                 # out is of shape [batch_size, vocab_size]
-                logits = out[i].unsqueeze(0)
+                logits = out[i].view(1, -1)
+
+                all_input_ids_tensor = batch.all_input_ids_tensor[i]

             # Select next token
-            next_token_id, logprobs = next_token_chooser(
+            next_token_id, logprob = next_token_chooser(
                 all_input_ids_tensor[None, :input_length], logits
             )
-            next_token_id_squeezed = next_token_id.squeeze()
-            next_token_id_item = next_token_id_squeezed.item()
+            # Add to all_input_ids_tensor
+            next_token_id_squeezed = next_token_id.view(1)
+            all_input_ids_tensor[input_length] = next_token_id_squeezed
+
+            # Set values
+            next_input_ids[i] = next_token_id_squeezed
+            next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
+
+            cumulative_length += input_length
+        # Set values in batch
+        batch.input_ids = next_input_ids
+        batch.position_ids = next_position_ids + 1
+        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
+
+        if prefill:
+            # Get prefill logprobs
+            prefill_logprobs_tensor = torch.log_softmax(out, -1)
+            prefill_logprobs = torch.gather(
+                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
+            )
+            # GPU <-> CPU sync
+            prefill_logprobs = prefill_logprobs.view(-1).tolist()
+
+        # GPU <-> CPU sync
+        next_token_logprobs = next_token_logprobs.tolist()
+        next_token_ids = batch.input_ids.tolist()
+
+        cumulative_length = 0
+        # Zipped iterator
+        iterator = zip(
+            batch.requests,
+            batch.input_lengths,
+            batch.offsets,
+            batch.token_offsets,
+            batch.next_token_choosers,
+            batch.stopping_criterias,
+            batch.all_input_ids,
+            batch.all_input_ids_tensor,
+            next_token_ids,
+            next_token_logprobs,
+        )
+
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            offset,
+            token_offset,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+            all_input_ids_tensor,
+            next_token_id,
+            next_token_logprob,
+        ) in enumerate(iterator):
+            start_index = cumulative_length
+            end_index = cumulative_length + input_length

             # Append next token to all tokens
-            all_input_ids.append(next_token_id_item)
-            all_input_ids_tensor[input_length] = next_token_id_item
+            all_input_ids.append(next_token_id)

             # Generated token
-            next_token_logprob = logprobs[-1, next_token_id_item]
             next_token_text, offset, token_offset = self.decode_token(
                 all_input_ids,
                 offset,
@@ -554,7 +683,7 @@ class FlashCausalLM(Model):
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token_id_item,
+                next_token_id,
                 next_token_text,
             )
@@ -579,9 +708,9 @@ class FlashCausalLM(Model):
             # Prefill
             if prefill:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
-                ).squeeze(1)[:-1].tolist()
+                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
+                    start_index : end_index - 1
+                ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
@@ -589,7 +718,7 @@ class FlashCausalLM(Model):
                     skip_special_tokens=False,
                 )
                 prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                    prefill_token_ids, request_prefill_logprobs, prefill_texts
                 )
             else:
                 prefill_tokens = None
@@ -597,31 +726,23 @@ class FlashCausalLM(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_item,
+                next_token_id,
                 next_token_logprob,
                 next_token_text,
-                next_token_id_item in self.all_special_ids,
+                next_token_id in self.all_special_ids,
                 generated_text,
             )

             generations.append(generation)
-            cumulative_length += input_length
             new_input_length = input_length + 1

             # Update values
-            batch.input_ids[i] = next_token_id
-            batch.position_ids[i] = input_length
             batch.input_lengths[i] = new_input_length
             batch.offsets[i] = offset
             batch.token_offsets[i] = token_offset
             batch.all_input_ids[i] = all_input_ids
-            batch.all_input_ids_tensor[i] = all_input_ids_tensor
-            batch.max_seqlen = max(batch.max_seqlen, new_input_length)
-            if len(batch) != 1:
-                # Add each sequence before its padding
-                batch.past_key_values[i * 2] = present[:, start_index:end_index]
-            # Cumulative sum
-            batch.cu_seqlens[i + 1] = batch.cu_seqlens[i] + new_input_length
+            cumulative_length += input_length
+
+        batch.max_seqlen = batch.max_seqlen + 1

         # No need to return a batch if we know that all requests stopped
         return generations, batch if not stopped else None
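Two patterns in the rewritten generate_token are worth calling out: prompt logprobs are computed with one log_softmax plus gather over the prefill output, and every transfer back to the CPU is deferred to a single .tolist() so the Python loop can run ahead of the GPU. The snippet below is a simplified, per-request sketch of the gather-with-shifted-indices idea, with random stand-in logits rather than anything from the diff:

import torch

torch.manual_seed(0)
vocab, prompt = 10, 5
out = torch.randn(prompt, vocab)                 # logits for every prompt position
input_ids = torch.randint(0, vocab, (prompt,))   # the prompt itself

# Position j predicts token j + 1, so gather each next token's logprob
# from the previous position's distribution (the first token has no logprob).
logprobs = torch.log_softmax(out, dim=-1)
indices = input_ids[1:].view(-1, 1)              # tokens 1..prompt-1
prefill_logprobs = torch.gather(logprobs[:-1], 1, indices).view(-1)

# One GPU <-> CPU sync for the whole prompt instead of one per token.
print([float("nan")] + prefill_logprobs.tolist())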
server/text_generation_server/models/flash_llama.py
@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
@@ -161,7 +161,7 @@ class FlashLlamaSharded(FlashLlama):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
server/text_generation_server/models/flash_neox.py
@@ -38,7 +38,7 @@ class FlashNeoXSharded(FlashNeoX):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
server/text_generation_server/models/flash_santacoder.py
@@ -31,7 +31,7 @@ class FlashSantacoder(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoder is only available on GPU")
@@ -178,7 +178,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")