OpenDAS / text-generation-inference · Commit 8aece3bd
"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "1df6100c77d389cf5af602e1ae17d0bf865c45f4"
Unverified commit 8aece3bd (parent 9ffe1f1e), authored Jun 05, 2024 by OlivierDehaene and committed via GitHub on Jun 05, 2024.

feat: move allocation logic to rust (#1835)

Close #2007
Showing 5 changed files with 174 additions and 613 deletions (+174 −613):
server/text_generation_server/models/flash_causal_lm.py    +163 −105
server/text_generation_server/models/flash_mistral.py        +5 −492
server/text_generation_server/models/flash_qwen2.py          +1 −4
server/text_generation_server/models/flash_starcoder2.py     +1 −4
server/text_generation_server/models/vlm_causal_lm.py        +4 −8
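The heart of this change: block and slot allocation for the paged KV cache moves out of the Python server (the old CacheManager) and into the Rust router, which now ships `blocks` and `slots` with each request. The Python side only allocates locally when a request carries no blocks, which happens during warmup. Below is a minimal, hedged sketch of that fallback, distilled from the new `from_tokenized` code in the first diff; the helper name `resolve_request_allocation` and the standalone shape of the function are illustrative only, not part of the commit.

    import math

    BLOCK_SIZE = 16  # tokens per paged-attention block (value used in the diff below)


    def resolve_request_allocation(r, num_blocks, total_tokens):
        # `r.blocks` / `r.slots` are assumed to be pre-allocated by the Rust router;
        # when they are empty (e.g. during warmup) the server falls back to a local,
        # contiguous allocation, mirroring the new from_tokenized logic.
        if not r.blocks:
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            request_blocks = list(range(num_blocks, num_blocks + needed_blocks))
            request_slots = [
                s
                for b in request_blocks
                for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
            ]
        else:
            request_blocks = r.blocks
            request_slots = r.slots
        # Only the first `total_tokens` slots are ever written for this request.
        return request_blocks, request_slots[:total_tokens]


    if __name__ == "__main__":
        from types import SimpleNamespace

        warmup_request = SimpleNamespace(blocks=[], slots=[])  # router sent nothing
        print(resolve_request_allocation(warmup_request, num_blocks=0, total_tokens=20))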
server/text_generation_server/models/flash_causal_lm.py
@@ -25,11 +25,6 @@ from text_generation_server.models.types import (
     Generation,
     GeneratedText,
 )
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-    set_cache_manager,
-    BLOCK_SIZE,
-)
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 import text_generation_server.models.globals as tgi_globals
@@ -44,6 +39,21 @@ from text_generation_server.utils.import_utils import (
 tracer = trace.get_tracer(__name__)
 
+BLOCK_SIZE: int = 16
+
+# Will be set in init
+SLIDING_WINDOW: Optional[int] = None
+
+
+def set_sliding_window(sliding_window: int):
+    global SLIDING_WINDOW
+    SLIDING_WINDOW = sliding_window
+
+
+def get_sliding_windows() -> int:
+    global SLIDING_WINDOW
+    return SLIDING_WINDOW
+
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -55,12 +65,15 @@ class FlashCausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
-    speculative_ids: torch.Tensor
+    speculative_ids: Optional[torch.Tensor]
 
     # Flash Attention values
 
     # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
+    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
+    # as we only keep SLIDING_WINDOW values instead of the whole tensor
+    prefill_cache_indices: Optional[torch.Tensor]
 
     # Paged Attention values
@@ -69,16 +82,13 @@ class FlashCausalLMBatch(Batch):
     start_slots: torch.Tensor
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     slot_indices: torch.Tensor
 
-    # List of tuple of ints representing the number of blocks and slots needed by each sequence
-    needed_blocks_slots: Optional[List[Tuple[int, int]]]
-
-    # Set in prefill by the CacheManager
     # list of length b of list of length s_i // block_size
-    block_tables: Optional[List[List[int]]]
+    block_tables: List[List[int]]
     # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
-    block_tables_tensor: Optional[torch.Tensor]
+    block_tables_tensor: torch.Tensor
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots: Optional[torch.Tensor]
+    slots: torch.Tensor
 
     max_seqlen: int
@@ -104,7 +114,7 @@ class FlashCausalLMBatch(Batch):
     top_n_tokens_tensor: torch.Tensor
 
     # Number of blocks in this batch
-    blocks: int
+    num_blocks: int
     # Maximum number of blocks
     max_blocks: int
@@ -113,7 +123,7 @@ class FlashCausalLMBatch(Batch):
             id=self.batch_id,
             request_ids=[r.id for r in self.requests],
             size=len(self),
-            max_tokens=self.blocks * BLOCK_SIZE,
+            max_tokens=self.num_blocks * BLOCK_SIZE,
         )
 
     @classmethod
@@ -129,17 +139,6 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]
         return batch_tokenized_inputs
 
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
-        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
-        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-
     @classmethod
     def from_tokenized(
         cls,
@@ -149,12 +148,12 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        sliding_window = get_sliding_windows()
         position_ids = []
-        speculative_ids = []
         cu_seqlen_prefill = [0]
-        needed_blocks_slots = []
         start_slots = []
         slot_indices = []
+        prefill_cache_indices = []
 
         input_lengths = []
         prefix_offsets = []
@@ -177,11 +176,14 @@ class FlashCausalLMBatch(Batch):
         cumulative_max_length = 0
         prefill_out_cumulative_length = 0
 
-        blocks = 0
+        num_blocks = 0
         max_seqlen = 0
         max_length = 0
         max_blocks = 0
 
+        block_tables = []
+        slots = []
+
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
@@ -225,9 +227,25 @@ class FlashCausalLMBatch(Batch):
             speculative_length = get_speculate()
             speculative_length = 0 if speculative_length is None else speculative_length
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
-            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            blocks += needed_blocks
-            needed_blocks_slots.append((needed_blocks, total_tokens))
+
+            # blocks and slots can be empty (for example in warmup)
+            if not r.blocks:
+                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+                request_blocks = [
+                    b for b in range(num_blocks, num_blocks + needed_blocks)
+                ]
+                request_slots = [
+                    s
+                    for b in request_blocks
+                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+                ]
+            else:
+                request_blocks = r.blocks
+                request_slots = r.slots
+
+            block_tables.append(request_blocks)
+            slots.extend(request_slots[:total_tokens])
+            num_blocks += len(request_blocks)
+
             start_slots.append(cumulative_max_length)
 
             request_slot_indices = torch.arange(
@@ -237,6 +255,15 @@ class FlashCausalLMBatch(Batch):
             )
             slot_indices.append(request_slot_indices)
 
+            # Create tensor to slice into the kv tensor in prefill
+            if sliding_window is not None:
+                request_prefill_cache_indices = torch.arange(
+                    cumulative_length + max(0, input_length - sliding_window),
+                    cumulative_length + input_length,
+                    dtype=torch.int64,
+                )
+                prefill_cache_indices.append(request_prefill_cache_indices)
+
             all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@@ -261,7 +288,7 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += input_length
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
-            max_blocks = max(max_blocks, needed_blocks)
+            max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
                 max_length, input_length + max_new_tokens + speculative_length
             )
@@ -287,16 +314,23 @@ class FlashCausalLMBatch(Batch):
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
+            if sliding_window is not None:
+                prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
+            if sliding_window is not None:
+                prefill_cache_indices = prefill_cache_indices[0]
 
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
+        prefill_cache_indices = (
+            prefill_cache_indices.to(device) if sliding_window is not None else None
+        )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
             input_lengths, dtype=torch.int32, device=device
@@ -319,6 +353,14 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )
 
+        slots = torch.tensor(slots, dtype=torch.int64, device=device)
+        block_tables_tensor = torch.zeros(
+            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
+        )
+        for i, request_blocks in enumerate(block_tables):
+            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+        block_tables_tensor = block_tables_tensor.to(device)
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -326,12 +368,12 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            prefill_cache_indices=prefill_cache_indices,
             start_slots=start_slots,
             slot_indices=slot_indices,
-            needed_blocks_slots=needed_blocks_slots,
-            block_tables=None,
-            block_tables_tensor=None,
-            slots=None,
+            block_tables=block_tables,
+            block_tables_tensor=block_tables_tensor,
+            slots=slots,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
@@ -346,11 +388,22 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
+            num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
         )
 
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "FlashCausalLMBatch":
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         if len(request_ids) == 0:
@@ -388,7 +441,7 @@ class FlashCausalLMBatch(Batch):
         stopping_criterias = []
         top_n_tokens = []
 
-        blocks = 0
+        num_blocks = 0
         max_blocks = 0
         # Cumulative length
         cumulative_max_length = 0
@@ -420,7 +473,7 @@ class FlashCausalLMBatch(Batch):
             )
 
             request_block_table = self.block_tables[idx]
-            blocks += len(request_block_table)
+            num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
             start_slots.append(cumulative_max_length)
@@ -439,17 +492,6 @@ class FlashCausalLMBatch(Batch):
             max_blocks = max(max_blocks, len(request_block_table))
 
-        block_indices_to_free = []
-        # Iterate on all requests
-        for i, r in enumerate(self.requests):
-            # Filter requests that are not part of the new batch
-            if r.id not in requests_idx_mapping.keys():
-                block_indices_to_free.extend(self.block_tables[i])
-        # Free blocks
-        get_cache_manager().free(block_indices_to_free)
-        # Needed to avoid dropping blocks when the batches will go out of scope
-        self.block_tables = None
-
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
@@ -475,9 +517,9 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
+            prefill_cache_indices=None,
             start_slots=start_slots,
             slot_indices=slot_indices,
-            needed_blocks_slots=None,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
@@ -495,7 +537,7 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
+            num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
@@ -507,7 +549,7 @@ class FlashCausalLMBatch(Batch):
         requests = []
         requests_idx_mapping = {}
 
-        blocks = 0
+        num_blocks = 0
         total_batch_size = 0
         total_slots = 0
         max_blocks = 0
@@ -516,7 +558,7 @@ class FlashCausalLMBatch(Batch):
         for b in batches:
             total_batch_size += len(b)
             total_slots += len(b.slots)
-            blocks += b.blocks
+            num_blocks += b.num_blocks
             speculative_length = (
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             )
@@ -635,11 +677,6 @@ class FlashCausalLMBatch(Batch):
             else None
         )
 
-        # Needed to avoid dropping blocks when the batches will go out of scope
-        for b in batches:
-            b.block_tables = None
-            del b
-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -647,9 +684,9 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
+            prefill_cache_indices=None,
             start_slots=start_slots,
             slot_indices=slot_indices,
-            needed_blocks_slots=None,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
@@ -667,18 +704,11 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
+            num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
 
-    def __del__(self):
-        if self.block_tables is not None and self.block_tables:
-            # Free blocks
-            get_cache_manager().free(
-                list(itertools.chain.from_iterable(self.block_tables))
-            )
-
     def __len__(self):
         return len(self.requests)
@@ -702,6 +732,7 @@ class FlashCausalLM(Model):
         self.head_size = head_size
 
         self.cuda_graphs = {}
+        self.kv_cache = []
 
         super(FlashCausalLM, self).__init__(
             model=model,
@@ -718,6 +749,43 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
+    def max_past(self) -> int:
+        return getattr(self.model, "max_past", None)
+
+    def init_kv_cache(
+        self,
+        num_blocks: int,
+        num_layers: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        self.kv_cache = []
+        empty_cache()
+
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        if SYSTEM == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size
+
+        self.kv_cache = [
+            (
+                torch.empty(
+                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
+                    dtype=dtype,
+                    device=device,
+                ),
+                torch.empty(
+                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
+                    dtype=dtype,
+                    device=device,
+                ),
+            )
+            for _ in range(num_layers)
+        ]
+
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
@@ -728,12 +796,11 @@ class FlashCausalLM(Model):
             .repeat(bs)
             .reshape((bs, max_bt))
         )
-        kv_cache = get_cache_manager().kv_cache
 
         self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
-           "kv_cache": kv_cache,
+           "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
@@ -747,11 +814,12 @@ class FlashCausalLM(Model):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
-            kv_cache=kv_cache,
+            kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
             input_lengths=input_lengths,
             max_s=max_s,
+            prefill_cache_indices=None,
             lm_head_indices=None,
         )
         torch.cuda.synchronize()
@@ -761,11 +829,12 @@ class FlashCausalLM(Model):
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
-                kv_cache=kv_cache,
+                kv_cache=self.kv_cache,
                 block_tables=block_tables,
                 slots=slots,
                 input_lengths=input_lengths,
                 max_s=max_s,
+                prefill_cache_indices=None,
                 lm_head_indices=None,
             )
             self.cuda_graphs[bs]["logits"] = logits
@@ -777,17 +846,16 @@ class FlashCausalLM(Model):
         empty_cache()
 
         try:
-            cache_manager = set_cache_manager(
-                batch.blocks,
+            self.init_kv_cache(
+                batch.num_blocks,
                 self.num_layers,
                 self.num_kv_heads,
                 self.head_size,
-                self.sliding_window is not None,
                 self.dtype,
                 self.device,
             )
             max_bt = batch.max_blocks
-            max_s = max_bt * get_cache_manager().block_size
+            max_s = max_bt * BLOCK_SIZE
 
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
@@ -811,19 +879,17 @@ class FlashCausalLM(Model):
         num_blocks = (
             # Leave 5% for some wiggle room
             int((free_memory * 0.95) // total_cache_size)
-            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + cache_manager.num_blocks
+            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+            + batch.num_blocks
         )
 
         del batch
-        del cache_manager
 
-        set_cache_manager(
+        self.init_kv_cache(
             num_blocks,
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.sliding_window is not None,
             self.dtype,
             self.device,
         )
@@ -889,7 +955,6 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-        kv_cache = get_cache_manager().kv_cache
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
@@ -901,12 +966,13 @@ class FlashCausalLM(Model):
             cu_seqlen_prefill=torch.tensor(
                 [0, seqlen], device=self.device, dtype=torch.int32
             ),
-            kv_cache=get_cache_manager().kv_cache,
+            kv_cache=self.kv_cache,
             block_tables=None,
             input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
+            prefill_cache_indices=None,
         )
 
     def forward(
@@ -917,7 +983,7 @@ class FlashCausalLM(Model):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
@@ -956,13 +1022,19 @@ class FlashCausalLM(Model):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
         if sorted_padded_bs:
@@ -972,7 +1044,7 @@ class FlashCausalLM(Model):
             cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            return self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
@@ -981,8 +1053,12 @@ class FlashCausalLM(Model):
                 slots=slots,
                 input_lengths=input_lengths,
                 max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
             )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
@@ -1015,24 +1091,7 @@ class FlashCausalLM(Model):
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
-        if batch.needed_blocks_slots:
-            # Allocate blocks to this batch
-            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
-                batch.needed_blocks_slots,
-                batch.blocks,
-                batch.max_blocks,
-                batch.input_ids.device,
-            )
-            batch.needed_blocks_slots = None
-            batch.block_tables = block_tables
-            batch.block_tables_tensor = block_tables_tensor
-            batch.slots = slots
-
-        try:
-            out, speculative_logits = self.forward(batch)
-        except Exception as e:
-            del batch
-            raise e
+        out, speculative_logits = self.forward(batch)
 
         if prefill:
             next_token_logits = (
@@ -1327,7 +1386,6 @@ class FlashCausalLM(Model):
             batch.all_input_ids[i] = all_input_ids
 
         if stopped:
-            del batch
             # No need to return a batch if we know that all requests stopped
             forward_ns = start_decode - start
             decode_ns = time.time_ns() - start_decode
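With the CacheManager gone, each FlashCausalLM instance now owns its paged KV cache directly in `self.kv_cache`, allocated by the new `init_kv_cache` method shown above. The following self-contained sketch reproduces that layout on CPU so the shapes are easy to inspect; the concrete sizes are made-up example values, not ones the server actually uses.

    import torch

    # Example sizes only; the real values come from warmup and the model config.
    num_blocks, num_layers, num_heads, head_size = 8, 2, 4, 64
    BLOCK_SIZE = 16
    dtype = torch.float16

    element_size = torch.tensor([], dtype=dtype).element_size()
    x = BLOCK_SIZE // element_size  # CUDA layout; the diff uses x = 1 on xpu

    kv_cache = [
        (
            # key cache: blocks of BLOCK_SIZE tokens, head_size split into x-wide chunks
            torch.empty((num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), dtype=dtype),
            # value cache: one contiguous head_size vector per token slot
            torch.empty((num_blocks, num_heads, head_size, BLOCK_SIZE), dtype=dtype),
        )
        for _ in range(num_layers)
    ]
    assert len(kv_cache) == num_layers  # one (key, value) pair per transformer layer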
server/text_generation_server/models/flash_mistral.py
-import math
 import torch
 import torch.distributed
 
-import numpy as np
-
-from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
-from typing import Optional, Tuple, Type
+from transformers import AutoTokenizer, AutoConfig
+from typing import Optional, Tuple
 
-from text_generation_server.pb import generate_pb2
 from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-)
+from text_generation_server.models.flash_causal_lm import set_sliding_window
 from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
     FlashMistralForCausalLM,
     MistralConfig,
 )
-from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
-    HeterogeneousNextTokenChooser,
-    StoppingCriteria,
 )
 
-tracer = trace.get_tracer(__name__)
-
-# Will be set in init
-SLIDING_WINDOW: Optional[int] = None
-SLIDING_WINDOW_BLOCKS: Optional[int] = None
-
 from text_generation_server.utils.import_utils import SYSTEM
 
-MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-
-
-def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
-    global SLIDING_WINDOW
-    global SLIDING_WINDOW_BLOCKS
-    SLIDING_WINDOW = sliding_window
-    SLIDING_WINDOW_BLOCKS = sliding_window_blocks
-
-
-def get_sliding_windows() -> Tuple[int, int]:
-    global SLIDING_WINDOW
-    global SLIDING_WINDOW_BLOCKS
-    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
-
-
-# Adds windowing logic to FlashCausalLMBatch
-@dataclass
-class FlashMistralBatch(FlashCausalLMBatch):
-    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
-    # as we only keep SLIDING_WINDOW values instead of the whole tensor
-    prefill_cache_indices: Optional[torch.Tensor] = None
-
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
-        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
-        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-
-    @classmethod
-    def from_tokenized(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        batch_tokenized_inputs,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
-        sliding_window, sliding_window_blocks = get_sliding_windows()
-
-        position_ids = []
-        cu_seqlen_prefill = [0]
-        needed_blocks_slots = []
-        start_slots = []
-        slot_indices = []
-        prefill_cache_indices = []
-
-        input_lengths = []
-        prefix_offsets = []
-        read_offsets = []
-        all_input_ids = []
-        requests_idx_mapping = {}
-
-        all_prefill_logprobs = True
-        no_prefill_logprobs = True
-        prefill_head_indices = []
-        prefill_next_token_indices = []
-        prefill_cu_outlens = [0]
-
-        next_token_chooser_parameters = []
-        stopping_criterias = []
-        top_n_tokens = []
-
-        # Cumulative length
-        cumulative_length = 0
-        cumulative_max_length = 0
-        prefill_out_cumulative_length = 0
-
-        blocks = 0
-        max_seqlen = 0
-        max_length = 0
-        max_blocks = 0
-
-        # Parse batch
-        for i, (r, tokenized_input) in enumerate(
-            zip(pb.requests, batch_tokenized_inputs)
-        ):
-            # request id -> idx in list mapping
-            requests_idx_mapping[r.id] = i
-
-            tokenized_input = tokenized_input[-r.truncate :]
-            if (
-                tokenized_input[0] == tokenizer.bos_token_id
-                and tokenized_input[1] == tokenizer.bos_token_id
-            ):
-                tokenized_input = tokenized_input[1:]
-
-            input_length = len(tokenized_input)
-            input_lengths.append(input_length)
-
-            prefix_offsets.append(input_length - 5)
-            read_offsets.append(input_length)
-
-            all_input_ids.append(tokenized_input)
-
-            # Position ids
-            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
-            position_ids.append(request_position_ids)
-
-            # Add cumulative lengths of all previous inputs
-            cu_seqlen_prefill.append(cumulative_length + input_length)
-
-            next_token_chooser_parameters.append(r.parameters)
-
-            stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
-            max_new_tokens = stopping_criteria.max_new_tokens
-            stopping_criterias.append(stopping_criteria)
-            top_n_tokens.append(r.top_n_tokens)
-
-            # Paged attention
-            # Remove one as the first token des not have a past
-            speculative_length = get_speculate()
-            total_tokens = input_length + max_new_tokens - 1 + speculative_length
-
-            # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
-            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            if sliding_window_blocks is not None:
-                needed_blocks = min(needed_blocks, sliding_window_blocks)
-            blocks += needed_blocks
-
-            needed_blocks_slots.append((needed_blocks, total_tokens))
-            start_slots.append(cumulative_max_length)
-
-            request_slot_indices = torch.arange(
-                cumulative_max_length,
-                cumulative_max_length + input_length,
-                dtype=torch.int64,
-            )
-            slot_indices.append(request_slot_indices)
-
-            # Create tensor to slice into the kv tensor in prefill
-            if sliding_window is not None:
-                request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - sliding_window),
-                    cumulative_length + input_length,
-                    dtype=torch.int64,
-                )
-                prefill_cache_indices.append(request_prefill_cache_indices)
-
-            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
-            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
-
-            if r.prefill_logprobs:
-                prefill_head_indices.append(request_position_ids + cumulative_length)
-                prefill_next_token_indices.append(
-                    prefill_out_cumulative_length + input_length - 1
-                )
-                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
-                prefill_out_cumulative_length += input_length
-            else:
-                prefill_head_indices.append(
-                    torch.tensor(
-                        [cumulative_length + input_length - 1], dtype=torch.int32
-                    )
-                )
-                prefill_next_token_indices.append(prefill_out_cumulative_length)
-                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
-                prefill_out_cumulative_length += 1
-
-            # Update
-            cumulative_length += input_length
-            cumulative_max_length += total_tokens
-            max_seqlen = max(max_seqlen, input_length)
-            max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(
-                max_length, input_length + max_new_tokens + speculative_length
-            )
-
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
-        )
-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
-        # Padded all_input_ids_tensor
-        all_input_ids_tensor = np.zeros(
-            (len(all_input_ids), max_length), dtype=np.int64
-        )
-        for i, input_ids in enumerate(all_input_ids):
-            all_input_ids_tensor[i, : len(input_ids)] = input_ids
-
-        # Create tensors on device
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
-
-        if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
-            position_ids = torch.cat(position_ids)
-            slot_indices = torch.cat(slot_indices)
-            if sliding_window is not None:
-                prefill_cache_indices = torch.cat(prefill_cache_indices)
-        else:
-            input_ids = all_input_ids[0]
-            position_ids = position_ids[0]
-            slot_indices = slot_indices[0]
-            if sliding_window is not None:
-                prefill_cache_indices = prefill_cache_indices[0]
-
-        cu_seqlen_prefill = torch.tensor(
-            cu_seqlen_prefill, device=device, dtype=torch.int32
-        )
-        position_ids = position_ids.to(device)
-        slot_indices = slot_indices.to(device)
-        prefill_cache_indices = (
-            prefill_cache_indices.to(device) if sliding_window is not None else None
-        )
-        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        input_lengths_tensor = torch.tensor(
-            input_lengths, dtype=torch.int32, device=device
-        )
-
-        if all_prefill_logprobs:
-            prefill_head_indices = None
-            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
-        elif no_prefill_logprobs:
-            prefill_head_indices = cu_seqlen_prefill[1:] - 1
-            prefill_next_token_indices = None
-        else:
-            prefill_head_indices = torch.tensor(
-                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
-            )
-            prefill_next_token_indices = torch.tensor(
-                prefill_next_token_indices, dtype=torch.int64, device=device
-            )
-        top_n_tokens_tensor = torch.tensor(
-            top_n_tokens, device=device, dtype=torch.int64
-        )
-
-        return cls(
-            batch_id=pb.id,
-            requests=pb.requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            start_slots=start_slots,
-            slot_indices=slot_indices,
-            needed_blocks_slots=needed_blocks_slots,
-            block_tables=None,
-            block_tables_tensor=None,
-            slots=None,
-            max_seqlen=max_seqlen,
-            prefill_head_indices=prefill_head_indices,
-            prefill_next_token_indices=prefill_next_token_indices,
-            prefill_cu_outlens=prefill_cu_outlens,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
-            prefix_offsets=prefix_offsets,
-            read_offsets=read_offsets,
-            all_input_ids=all_input_ids,
-            all_input_ids_tensor=all_input_ids_tensor,
-            next_token_chooser=next_token_chooser,
-            stopping_criterias=stopping_criterias,
-            top_n_tokens=top_n_tokens,
-            top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
-            max_blocks=max_blocks,
-            prefill_cache_indices=prefill_cache_indices,
-            speculative_ids=None,
-        )
+tracer = trace.get_tracer(__name__)
 
 
 class BaseFlashMistral(FlashCausalLM):
@@ -344,9 +60,7 @@ class BaseFlashMistral(FlashCausalLM):
         # Set context windows
         if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
         else:
             config.sliding_window = None
@@ -384,207 +98,6 @@ class BaseFlashMistral(FlashCausalLM):
             model.model.head_size,
         )
 
-    def max_past(self) -> int:
-        return self.model.max_past
-
-    @property
-    def batch_type(self) -> Type[FlashMistralBatch]:
-        return FlashMistralBatch
-
-    def tunableop_warmup(self, seqlen: int):
-        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
-        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-        kv_cache = get_cache_manager().kv_cache
-
-        # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-
-        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
-        self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor(
-                [0, seqlen], device=self.device, dtype=torch.int32
-            ),
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=None,
-            input_lengths=input_lengths,
-            slots=slots,
-            max_s=seqlen,
-            lm_head_indices=None,
-            prefill_cache_indices=None,
-        )
-
-    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
-        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        block_tables = (
-            torch.arange(max_bt, dtype=torch.int32, device=self.device)
-            .repeat(bs)
-            .reshape((bs, max_bt))
-        )
-        kv_cache = get_cache_manager().kv_cache
-
-        self.cuda_graphs[bs] = {
-            "input_ids": input_ids,
-            "position_ids": position_ids,
-            "kv_cache": kv_cache,
-            "block_tables": block_tables,
-            "slots": slots,
-            "input_lengths": input_lengths,
-        }
-        graph = torch.cuda.CUDAGraph()
-        self.cuda_graphs[bs]["graph"] = graph
-
-        torch.cuda.synchronize()
-        # Run once outside to warmup
-        self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=None,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
-            prefill_cache_indices=None,
-            lm_head_indices=None,
-        )
-        torch.cuda.synchronize()
-
-        with torch.cuda.graph(graph, pool=MEM_POOL):
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=None,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=None,
-                lm_head_indices=None,
-            )
-            self.cuda_graphs[bs]["logits"] = logits
-            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
-        torch.cuda.synchronize()
-
-    def forward(
-        self, batch: FlashMistralBatch
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # Model Forward
-        if batch.speculative_ids is not None:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-
-            speculative_ids = batch.speculative_ids
-
-            B, speculative_length = speculative_ids.shape
-            new_length = speculative_length + 1
-            new_input_ids = torch.cat(
-                [input_ids.unsqueeze(-1), speculative_ids], dim=1
-            ).reshape(-1)
-            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
-            arange_int = arange.to(dtype=torch.int32)
-            new_position_ids = (
-                position_ids.unsqueeze(-1).expand(B, new_length) + arange
-            ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
-            ).view(-1)
-
-            # Add Copy the block tables for all members
-            block_tables = (
-                block_tables.unsqueeze(1)
-                .expand(B, new_length, -1)
-                .reshape(B * new_length, -1)
-                .contiguous()
-            )
-            max_s = max_s + speculative_length
-
-            input_ids = new_input_ids
-            position_ids = new_position_ids
-        else:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
-        bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
-        if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-            )
-            if batch.prefill_cache_indices is not None:
-                batch.prefill_cache_indices = None
-            return logits, speculative_logits
-
-        # Copy inputs to the static inputs of the cuda graph
-        # Static inputs are potentially padded
-        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
-        cuda_graph["slots"].fill_(-1)
-        cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
-
-        # Replay the graph
-        cuda_graph["graph"].replay()
-
-        # Slice output to the correct shape
-        speculative_logits = (
-            cuda_graph["speculative_logits"][:bs]
-            if cuda_graph["speculative_logits"] is not None
-            else None
-        )
-        logits = cuda_graph["logits"][:bs]
-        return logits, speculative_logits
-
 
 class FlashMistral(BaseFlashMistral):
     def __init__(
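After this commit the Mistral family no longer keeps its own sliding-window globals: a single `SLIDING_WINDOW` value and its setter/getter live in flash_causal_lm.py, and the model classes simply call the shared `set_sliding_window(config.sliding_window)`. A tiny standalone sketch of that pattern, simplified from the added code above; the 4096 value is only an example, and the `Optional[int]` return annotation is a simplification of the diff's `-> int`.

    from typing import Optional

    SLIDING_WINDOW: Optional[int] = None


    def set_sliding_window(sliding_window: int):
        global SLIDING_WINDOW
        SLIDING_WINDOW = sliding_window


    def get_sliding_windows() -> Optional[int]:
        return SLIDING_WINDOW


    # What a model __init__ now does when its config defines a window:
    set_sliding_window(4096)
    assert get_sliding_windows() == 4096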
server/text_generation_server/models/flash_qwen2.py
@@ -7,7 +7,6 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral):
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
server/text_generation_server/models/flash_starcoder2.py
@@ -6,7 +6,6 @@ from typing import Optional
 from transformers.models.gpt2 import GPT2TokenizerFast
 
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral):
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
server/text_generation_server/models/vlm_causal_lm.py
@@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
-    FlashMistralBatch,
 )
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-)
 
 tracer = trace.get_tracer(__name__)
@@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image
 
 
-class VlmCausalLMBatch(FlashMistralBatch):
+class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
@@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
@@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor