OpenDAS / text-generation-inference / Commit 8aece3bd

Unverified commit 8aece3bd, authored Jun 05, 2024 by OlivierDehaene, committed by GitHub on Jun 05, 2024.

feat: move allocation logic to rust (#1835)

Close #2007

parent 9ffe1f1e
Changes: 25 files in the commit (5 shown on this page)

Showing 5 changed files with 174 additions and 613 deletions:

  server/text_generation_server/models/flash_causal_lm.py    +163  -105
  server/text_generation_server/models/flash_mistral.py        +5  -492
  server/text_generation_server/models/flash_qwen2.py          +1    -4
  server/text_generation_server/models/flash_starcoder2.py     +1    -4
  server/text_generation_server/models/vlm_causal_lm.py        +4    -8
server/text_generation_server/models/flash_causal_lm.py
@@ -25,11 +25,6 @@ from text_generation_server.models.types import (
     Generation,
     GeneratedText,
 )
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-    set_cache_manager,
-    BLOCK_SIZE,
-)
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 import text_generation_server.models.globals as tgi_globals
@@ -44,6 +39,21 @@ from text_generation_server.utils.import_utils import (
 
 tracer = trace.get_tracer(__name__)
 
+BLOCK_SIZE: int = 16
+
+# Will be set in init
+SLIDING_WINDOW: Optional[int] = None
+
+
+def set_sliding_window(sliding_window: int):
+    global SLIDING_WINDOW
+    SLIDING_WINDOW = sliding_window
+
+
+def get_sliding_windows() -> int:
+    global SLIDING_WINDOW
+    return SLIDING_WINDOW
+
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -55,12 +65,15 @@ class FlashCausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
-    speculative_ids: torch.Tensor
+    speculative_ids: Optional[torch.Tensor]
 
     # Flash Attention values
 
     # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
+    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
+    # as we only keep SLIDING_WINDOW values instead of the whole tensor
+    prefill_cache_indices: Optional[torch.Tensor]
 
     # Paged Attention values
@@ -69,16 +82,13 @@ class FlashCausalLMBatch(Batch):
     start_slots: torch.Tensor
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     slot_indices: torch.Tensor
-    # List of tuple of ints representing the number of blocks and slots needed by each sequence
-    needed_blocks_slots: Optional[List[Tuple[int, int]]]
-    # Set in prefill by the CacheManager
     # list of length b of list of length s_i // block_size
-    block_tables: Optional[List[List[int]]]
+    block_tables: List[List[int]]
     # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
-    block_tables_tensor: Optional[torch.Tensor]
+    block_tables_tensor: torch.Tensor
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots: Optional[torch.Tensor]
+    slots: torch.Tensor
 
     max_seqlen: int
@@ -104,7 +114,7 @@ class FlashCausalLMBatch(Batch):
     top_n_tokens_tensor: torch.Tensor
 
     # Number of blocks in this batch
-    blocks: int
+    num_blocks: int
     # Maximum number of blocks
     max_blocks: int
@@ -113,7 +123,7 @@ class FlashCausalLMBatch(Batch):
             id=self.batch_id,
             request_ids=[r.id for r in self.requests],
             size=len(self),
-            max_tokens=self.blocks * BLOCK_SIZE,
+            max_tokens=self.num_blocks * BLOCK_SIZE,
         )
 
     @classmethod
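The max_tokens budget reported to the router above is pure block arithmetic. A standalone sketch (not part of the diff; `needed_blocks` is a hypothetical helper, BLOCK_SIZE fixed at 16 as in this file):

    import math

    BLOCK_SIZE = 16  # same granularity as the paged attention cache in this file

    def needed_blocks(input_length: int, max_new_tokens: int, speculative_length: int = 0) -> int:
        # Mirrors total_tokens = input_length + max_new_tokens - 1 + speculative_length
        total_tokens = input_length + max_new_tokens - 1 + speculative_length
        return math.ceil(total_tokens / BLOCK_SIZE)

    # A request with a 35-token prompt generating up to 20 tokens needs
    # ceil(54 / 16) = 4 blocks, i.e. a 64-slot budget (num_blocks * BLOCK_SIZE).
    assert needed_blocks(35, 20) == 4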
@@ -129,17 +139,6 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]
         return batch_tokenized_inputs
 
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
-        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
-        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-
     @classmethod
     def from_tokenized(
         cls,
@@ -149,12 +148,12 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        sliding_window = get_sliding_windows()
+
         position_ids = []
-        speculative_ids = []
         cu_seqlen_prefill = [0]
-        needed_blocks_slots = []
         start_slots = []
         slot_indices = []
+        prefill_cache_indices = []
 
         input_lengths = []
         prefix_offsets = []
@@ -177,11 +176,14 @@ class FlashCausalLMBatch(Batch):
         cumulative_max_length = 0
         prefill_out_cumulative_length = 0
 
-        blocks = 0
+        num_blocks = 0
         max_seqlen = 0
         max_length = 0
         max_blocks = 0
 
+        block_tables = []
+        slots = []
+
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
@@ -225,9 +227,25 @@ class FlashCausalLMBatch(Batch):
             speculative_length = get_speculate()
             speculative_length = 0 if speculative_length is None else speculative_length
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
-            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            blocks += needed_blocks
-            needed_blocks_slots.append((needed_blocks, total_tokens))
+
+            # blocks and slots can be empty (for example in warmup)
+            if not r.blocks:
+                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+                request_blocks = [
+                    b for b in range(num_blocks, num_blocks + needed_blocks)
+                ]
+                request_slots = [
+                    s
+                    for b in request_blocks
+                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+                ]
+            else:
+                request_blocks = r.blocks
+                request_slots = r.slots
+
+            block_tables.append(request_blocks)
+            slots.extend(request_slots[:total_tokens])
+            num_blocks += len(request_blocks)
+
             start_slots.append(cumulative_max_length)
 
             request_slot_indices = torch.arange(
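With allocation now done by the Rust router, each request arrives with its block ids and the Python side only expands them into slots, as in the branch above. A standalone sketch of that expansion (hypothetical helper, not code from the diff):

    BLOCK_SIZE = 16

    def slots_for_blocks(block_ids, total_tokens):
        # Block b owns the contiguous slot range [b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE).
        slots = [s for b in block_ids for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]
        # Only the first total_tokens slots are actually used by the request.
        return slots[:total_tokens]

    # Blocks 7 and 9 cover slots 112..127 and 144..159; a 20-token budget keeps
    # the first 20 of those 32 slots.
    print(slots_for_blocks([7, 9], 20))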
@@ -237,6 +255,15 @@ class FlashCausalLMBatch(Batch):
             )
             slot_indices.append(request_slot_indices)
 
+            # Create tensor to slice into the kv tensor in prefill
+            if sliding_window is not None:
+                request_prefill_cache_indices = torch.arange(
+                    cumulative_length + max(0, input_length - sliding_window),
+                    cumulative_length + input_length,
+                    dtype=torch.int64,
+                )
+                prefill_cache_indices.append(request_prefill_cache_indices)
+
             all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
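prefill_cache_indices picks which prefill positions get written to the paged KV cache when a sliding window is active: only the last sliding_window tokens of each prompt. A standalone sketch mirroring the torch.arange above (hypothetical helper, not code from the diff):

    def prefill_cache_positions(cumulative_length, input_length, sliding_window):
        # Mirrors torch.arange(cumulative_length + max(0, input_length - sliding_window),
        #                      cumulative_length + input_length)
        start = cumulative_length + max(0, input_length - sliding_window)
        return list(range(start, cumulative_length + input_length))

    # A 10-token prompt starting at offset 100 with a window of 4 only caches
    # positions 106..109 (the last 4 tokens of the prompt).
    print(prefill_cache_positions(100, 10, 4))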
@@ -261,7 +288,7 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += input_length
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
-            max_blocks = max(max_blocks, needed_blocks)
+            max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
                 max_length, input_length + max_new_tokens + speculative_length
             )
@@ -287,16 +314,23 @@ class FlashCausalLMBatch(Batch):
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
+            if sliding_window is not None:
+                prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
+            if sliding_window is not None:
+                prefill_cache_indices = prefill_cache_indices[0]
 
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
+        prefill_cache_indices = (
+            prefill_cache_indices.to(device) if sliding_window is not None else None
+        )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
             input_lengths, dtype=torch.int32, device=device
@@ -319,6 +353,14 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )
 
+        slots = torch.tensor(slots, dtype=torch.int64, device=device)
+        block_tables_tensor = torch.zeros(
+            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
+        )
+        for i, request_blocks in enumerate(block_tables):
+            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+        block_tables_tensor = block_tables_tensor.to(device)
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -326,12 +368,12 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            prefill_cache_indices=prefill_cache_indices,
             start_slots=start_slots,
             slot_indices=slot_indices,
-            needed_blocks_slots=needed_blocks_slots,
-            block_tables=None,
-            block_tables_tensor=None,
-            slots=None,
+            block_tables=block_tables,
+            block_tables_tensor=block_tables_tensor,
+            slots=slots,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
@@ -346,11 +388,22 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
+            num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
         )
 
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "FlashCausalLMBatch":
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         if len(request_ids) == 0:
@@ -388,7 +441,7 @@ class FlashCausalLMBatch(Batch):
         stopping_criterias = []
         top_n_tokens = []
 
-        blocks = 0
+        num_blocks = 0
         max_blocks = 0
         # Cumulative length
         cumulative_max_length = 0
@@ -420,7 +473,7 @@ class FlashCausalLMBatch(Batch):
             )
 
             request_block_table = self.block_tables[idx]
-            blocks += len(request_block_table)
+            num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
             start_slots.append(cumulative_max_length)
@@ -439,17 +492,6 @@ class FlashCausalLMBatch(Batch):
             max_blocks = max(max_blocks, len(request_block_table))
 
-        block_indices_to_free = []
-        # Iterate on all requests
-        for i, r in enumerate(self.requests):
-            # Filter requests that are not part of the new batch
-            if r.id not in requests_idx_mapping.keys():
-                block_indices_to_free.extend(self.block_tables[i])
-        # Free blocks
-        get_cache_manager().free(block_indices_to_free)
-        # Needed to avoid dropping blocks when the batches will go out of scope
-        self.block_tables = None
-
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
@@ -475,9 +517,9 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
+            prefill_cache_indices=None,
             start_slots=start_slots,
             slot_indices=slot_indices,
-            needed_blocks_slots=None,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
@@ -495,7 +537,7 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
+            num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
@@ -507,7 +549,7 @@ class FlashCausalLMBatch(Batch):
         requests = []
         requests_idx_mapping = {}
 
-        blocks = 0
+        num_blocks = 0
         total_batch_size = 0
         total_slots = 0
         max_blocks = 0
@@ -516,7 +558,7 @@ class FlashCausalLMBatch(Batch):
         for b in batches:
             total_batch_size += len(b)
             total_slots += len(b.slots)
-            blocks += b.blocks
+            num_blocks += b.num_blocks
             speculative_length = (
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             )
@@ -635,11 +677,6 @@ class FlashCausalLMBatch(Batch):
             else None
         )
 
-        # Needed to avoid dropping blocks when the batches will go out of scope
-        for b in batches:
-            b.block_tables = None
-            del b
-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -647,9 +684,9 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
+            prefill_cache_indices=None,
             start_slots=start_slots,
             slot_indices=slot_indices,
-            needed_blocks_slots=None,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
@@ -667,18 +704,11 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
+            num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
 
-    def __del__(self):
-        if self.block_tables is not None and self.block_tables:
-            # Free blocks
-            get_cache_manager().free(
-                list(itertools.chain.from_iterable(self.block_tables))
-            )
-
     def __len__(self):
         return len(self.requests)
@@ -702,6 +732,7 @@ class FlashCausalLM(Model):
         self.head_size = head_size
 
         self.cuda_graphs = {}
+        self.kv_cache = []
 
         super(FlashCausalLM, self).__init__(
             model=model,
@@ -718,6 +749,43 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
+    def max_past(self) -> int:
+        return getattr(self.model, "max_past", None)
+
+    def init_kv_cache(
+        self,
+        num_blocks: int,
+        num_layers: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        self.kv_cache = []
+        empty_cache()
+
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        if SYSTEM == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size
+
+        self.kv_cache = [
+            (
+                torch.empty(
+                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
+                    dtype=dtype,
+                    device=device,
+                ),
+                torch.empty(
+                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
+                    dtype=dtype,
+                    device=device,
+                ),
+            )
+            for _ in range(num_layers)
+        ]
+
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
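init_kv_cache allocates one (key, value) tensor pair per layer; the value tensor has shape (num_blocks, num_heads, head_size, BLOCK_SIZE), so a flat slot index splits into a block id and an offset inside the block. A hedged sketch of that addressing, assuming the usual paged-attention layout (not code from the diff):

    import torch

    BLOCK_SIZE = 16
    num_blocks, num_heads, head_size = 8, 2, 4

    # Value cache for a single layer, matching the shape allocated above.
    value_cache = torch.zeros((num_blocks, num_heads, head_size, BLOCK_SIZE))

    def write_value(slot: int, value: torch.Tensor):
        # A flat slot index splits into (block id, offset inside the block).
        block_id, block_offset = divmod(slot, BLOCK_SIZE)
        value_cache[block_id, :, :, block_offset] = value

    # Slot 37 lands in block 2 at offset 5.
    write_value(37, torch.ones(num_heads, head_size))
    assert value_cache[2, 0, 0, 5] == 1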
@@ -728,12 +796,11 @@ class FlashCausalLM(Model):
             .repeat(bs)
             .reshape((bs, max_bt))
         )
-        kv_cache = get_cache_manager().kv_cache
 
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
             "position_ids": position_ids,
-            "kv_cache": kv_cache,
+            "kv_cache": self.kv_cache,
             "block_tables": block_tables,
             "slots": slots,
             "input_lengths": input_lengths,
@@ -747,11 +814,12 @@ class FlashCausalLM(Model):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
-            kv_cache=kv_cache,
+            kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
             input_lengths=input_lengths,
             max_s=max_s,
+            prefill_cache_indices=None,
             lm_head_indices=None,
         )
         torch.cuda.synchronize()
@@ -761,11 +829,12 @@ class FlashCausalLM(Model):
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
-                kv_cache=kv_cache,
+                kv_cache=self.kv_cache,
                 block_tables=block_tables,
                 slots=slots,
                 input_lengths=input_lengths,
                 max_s=max_s,
+                prefill_cache_indices=None,
                 lm_head_indices=None,
             )
             self.cuda_graphs[bs]["logits"] = logits
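cuda_graph_warmup captures one decode forward per batch size and later replays it after copying fresh inputs into the same static tensors. A minimal generic sketch of the torch.cuda.CUDAGraph capture/replay pattern (requires a CUDA device; not the model code itself):

    import torch

    if torch.cuda.is_available():
        static_x = torch.zeros(8, device="cuda")
        graph = torch.cuda.CUDAGraph()

        # Warm up once outside the graph, then capture the computation.
        _ = static_x * 2
        torch.cuda.synchronize()
        with torch.cuda.graph(graph):
            static_y = static_x * 2

        # Replaying re-runs the captured kernels on whatever is in static_x.
        static_x.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
        graph.replay()
        torch.cuda.synchronize()
        print(static_y)  # tensor([0., 2., 4., ..., 14.])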
@@ -777,17 +846,16 @@ class FlashCausalLM(Model):
         empty_cache()
         try:
-            cache_manager = set_cache_manager(
-                batch.blocks,
+            self.init_kv_cache(
+                batch.num_blocks,
                 self.num_layers,
                 self.num_kv_heads,
                 self.head_size,
-                self.sliding_window is not None,
                 self.dtype,
                 self.device,
             )
             max_bt = batch.max_blocks
-            max_s = max_bt * get_cache_manager().block_size
+            max_s = max_bt * BLOCK_SIZE
 
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
@@ -811,19 +879,17 @@ class FlashCausalLM(Model):
         num_blocks = (
             # Leave 5% for some wiggle room
             int((free_memory * 0.95) // total_cache_size)
-            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + cache_manager.num_blocks
+            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+            + batch.num_blocks
         )
 
         del batch
-        del cache_manager
 
-        set_cache_manager(
+        self.init_kv_cache(
             num_blocks,
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.sliding_window is not None,
             self.dtype,
             self.device,
         )
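The block budget above converts free device memory into paged-attention blocks and adds back the blocks used by the warmup batch. A standalone sketch of that estimate under stated assumptions (two tensors per layer, fp16 element size of 2, BLOCK_SIZE 16; the exact total_cache_size formula is not shown in this excerpt):

    def estimate_num_blocks(free_memory, num_layers, num_kv_heads, head_size,
                            warmup_batch_blocks, block_size=16, element_size=2):
        # Bytes needed to hold one block of keys and values across all layers (assumed formula).
        cache_block_size = block_size * num_kv_heads * head_size
        total_cache_size = num_layers * cache_block_size * 2 * element_size
        # Leave 5% for some wiggle room, then add back the warmup batch's blocks
        # since they were part of the measured peak memory.
        return int((free_memory * 0.95) // total_cache_size) + warmup_batch_blocks

    # Example: 8 GiB free, 32 layers, 8 KV heads of size 128, fp16.
    print(estimate_num_blocks(8 * 1024**3, 32, 8, 128, warmup_batch_blocks=16))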
@@ -889,7 +955,6 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-        kv_cache = get_cache_manager().kv_cache
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
@@ -901,12 +966,13 @@ class FlashCausalLM(Model):
             cu_seqlen_prefill=torch.tensor(
                 [0, seqlen], device=self.device, dtype=torch.int32
             ),
-            kv_cache=get_cache_manager().kv_cache,
+            kv_cache=self.kv_cache,
             block_tables=None,
             input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
+            prefill_cache_indices=None,
         )
 
     def forward(
@@ -917,7 +983,7 @@ class FlashCausalLM(Model):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
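This branch (only the kv_cache source changes here) feeds the verification pass of speculative decoding: each sequence is expanded to 1 + speculative_length rows so the draft tokens can be scored in one forward, as the surrounding (unchanged) code does with unsqueeze/expand/view. A standalone sketch of that expansion, not taken verbatim from the diff:

    import torch

    B, speculative_length = 2, 3
    input_ids = torch.tensor([10, 20])                    # last accepted token per sequence
    speculative_ids = torch.tensor([[11, 12, 13],
                                    [21, 22, 23]])        # draft tokens to verify
    position_ids = torch.tensor([5, 9])                   # next position per sequence

    new_length = speculative_length + 1
    # Interleave the accepted token with its draft continuation, then flatten.
    new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
    arange = torch.arange(new_length).unsqueeze(0)
    new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)

    print(new_input_ids)     # tensor([10, 11, 12, 13, 20, 21, 22, 23])
    print(new_position_ids)  # tensor([ 5,  6,  7,  8,  9, 10, 11, 12])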
@@ -956,13 +1022,19 @@ class FlashCausalLM(Model):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
         if sorted_padded_bs:
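The forward pass now looks up the smallest captured CUDA graph whose batch size can hold the live batch, instead of a fixed rounding rule. A small sketch of that lookup (hypothetical dictionary of captured graphs, not code from the diff):

    def pick_cuda_graph(cuda_graphs: dict, bs: int):
        # Smallest captured batch size that is >= the live batch size, if any.
        sorted_padded_bs = sorted(k for k in cuda_graphs.keys() if k >= bs)
        return cuda_graphs[sorted_padded_bs[0]] if sorted_padded_bs else None

    graphs = {1: "graph_bs1", 2: "graph_bs2", 4: "graph_bs4", 8: "graph_bs8"}
    print(pick_cuda_graph(graphs, 3))   # graph_bs4
    print(pick_cuda_graph(graphs, 16))  # None (falls back to a regular forward)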
@@ -972,7 +1044,7 @@ class FlashCausalLM(Model):
             cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            return self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
@@ -981,8 +1053,12 @@ class FlashCausalLM(Model):
                 slots=slots,
                 input_lengths=input_lengths,
                 max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
             )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
@@ -1015,24 +1091,7 @@ class FlashCausalLM(Model):
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
-        if batch.needed_blocks_slots:
-            # Allocate blocks to this batch
-            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
-                batch.needed_blocks_slots,
-                batch.blocks,
-                batch.max_blocks,
-                batch.input_ids.device,
-            )
-            batch.needed_blocks_slots = None
-            batch.block_tables = block_tables
-            batch.block_tables_tensor = block_tables_tensor
-            batch.slots = slots
-
-        try:
-            out, speculative_logits = self.forward(batch)
-        except Exception as e:
-            del batch
-            raise e
+        out, speculative_logits = self.forward(batch)
 
         if prefill:
             next_token_logits = (
@@ -1327,7 +1386,6 @@ class FlashCausalLM(Model):
             batch.all_input_ids[i] = all_input_ids
 
         if stopped:
-            del batch
             # No need to return a batch if we know that all requests stopped
             forward_ns = start_decode - start
             decode_ns = time.time_ns() - start_decode
server/text_generation_server/models/flash_mistral.py
-import math
 import torch
 import torch.distributed
 
-import numpy as np
-
-from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
-from typing import Optional, Tuple, Type
+from transformers import AutoTokenizer, AutoConfig
+from typing import Optional, Tuple
 
-from text_generation_server.pb import generate_pb2
 from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-)
+from text_generation_server.models.flash_causal_lm import set_sliding_window
 from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
     FlashMistralForCausalLM,
     MistralConfig,
 )
-from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
-    HeterogeneousNextTokenChooser,
-    StoppingCriteria,
 )
 
 tracer = trace.get_tracer(__name__)
 
-# Will be set in init
-SLIDING_WINDOW: Optional[int] = None
-SLIDING_WINDOW_BLOCKS: Optional[int] = None
-
 from text_generation_server.utils.import_utils import SYSTEM
 
-MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-
-
-def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
-    global SLIDING_WINDOW
-    global SLIDING_WINDOW_BLOCKS
-    SLIDING_WINDOW = sliding_window
-    SLIDING_WINDOW_BLOCKS = sliding_window_blocks
-
-
-def get_sliding_windows() -> Tuple[int, int]:
-    global SLIDING_WINDOW
-    global SLIDING_WINDOW_BLOCKS
-    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
-
-
-# Adds windowing logic to FlashCausalLMBatch
-@dataclass
-class FlashMistralBatch(FlashCausalLMBatch):
-    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
-    # as we only keep SLIDING_WINDOW values instead of the whole tensor
-    prefill_cache_indices: Optional[torch.Tensor] = None
-
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
-        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
-        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-
-    @classmethod
-    def from_tokenized(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        batch_tokenized_inputs,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
-        sliding_window, sliding_window_blocks = get_sliding_windows()
-
-        position_ids = []
-        cu_seqlen_prefill = [0]
-        needed_blocks_slots = []
-        start_slots = []
-        slot_indices = []
-        prefill_cache_indices = []
-
-        input_lengths = []
-        prefix_offsets = []
-        read_offsets = []
-        all_input_ids = []
-        requests_idx_mapping = {}
-
-        all_prefill_logprobs = True
-        no_prefill_logprobs = True
-        prefill_head_indices = []
-        prefill_next_token_indices = []
-        prefill_cu_outlens = [0]
-
-        next_token_chooser_parameters = []
-        stopping_criterias = []
-        top_n_tokens = []
-
-        # Cumulative length
-        cumulative_length = 0
-        cumulative_max_length = 0
-        prefill_out_cumulative_length = 0
-
-        blocks = 0
-        max_seqlen = 0
-        max_length = 0
-        max_blocks = 0
-
-        # Parse batch
-        for i, (r, tokenized_input) in enumerate(
-            zip(pb.requests, batch_tokenized_inputs)
-        ):
-            # request id -> idx in list mapping
-            requests_idx_mapping[r.id] = i
-
-            tokenized_input = tokenized_input[-r.truncate :]
-            if (
-                tokenized_input[0] == tokenizer.bos_token_id
-                and tokenized_input[1] == tokenizer.bos_token_id
-            ):
-                tokenized_input = tokenized_input[1:]
-
-            input_length = len(tokenized_input)
-            input_lengths.append(input_length)
-
-            prefix_offsets.append(input_length - 5)
-            read_offsets.append(input_length)
-
-            all_input_ids.append(tokenized_input)
-
-            # Position ids
-            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
-            position_ids.append(request_position_ids)
-
-            # Add cumulative lengths of all previous inputs
-            cu_seqlen_prefill.append(cumulative_length + input_length)
-
-            next_token_chooser_parameters.append(r.parameters)
-
-            stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
-            max_new_tokens = stopping_criteria.max_new_tokens
-            stopping_criterias.append(stopping_criteria)
-            top_n_tokens.append(r.top_n_tokens)
-
-            # Paged attention
-            # Remove one as the first token des not have a past
-            speculative_length = get_speculate()
-            total_tokens = input_length + max_new_tokens - 1 + speculative_length
-
-            # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
-            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            if sliding_window_blocks is not None:
-                needed_blocks = min(needed_blocks, sliding_window_blocks)
-            blocks += needed_blocks
-
-            needed_blocks_slots.append((needed_blocks, total_tokens))
-            start_slots.append(cumulative_max_length)
-
-            request_slot_indices = torch.arange(
-                cumulative_max_length,
-                cumulative_max_length + input_length,
-                dtype=torch.int64,
-            )
-            slot_indices.append(request_slot_indices)
-
-            # Create tensor to slice into the kv tensor in prefill
-            if sliding_window is not None:
-                request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - sliding_window),
-                    cumulative_length + input_length,
-                    dtype=torch.int64,
-                )
-                prefill_cache_indices.append(request_prefill_cache_indices)
-
-            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
-            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
-
-            if r.prefill_logprobs:
-                prefill_head_indices.append(request_position_ids + cumulative_length)
-                prefill_next_token_indices.append(
-                    prefill_out_cumulative_length + input_length - 1
-                )
-                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
-                prefill_out_cumulative_length += input_length
-            else:
-                prefill_head_indices.append(
-                    torch.tensor([cumulative_length + input_length - 1], dtype=torch.int32)
-                )
-                prefill_next_token_indices.append(prefill_out_cumulative_length)
-                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
-                prefill_out_cumulative_length += 1
-
-            # Update
-            cumulative_length += input_length
-            cumulative_max_length += total_tokens
-            max_seqlen = max(max_seqlen, input_length)
-            max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(
-                max_length, input_length + max_new_tokens + speculative_length
-            )
-
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
-        )
-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
-        # Padded all_input_ids_tensor
-        all_input_ids_tensor = np.zeros((len(all_input_ids), max_length), dtype=np.int64)
-        for i, input_ids in enumerate(all_input_ids):
-            all_input_ids_tensor[i, : len(input_ids)] = input_ids
-
-        # Create tensors on device
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
-
-        if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
-            position_ids = torch.cat(position_ids)
-            slot_indices = torch.cat(slot_indices)
-            if sliding_window is not None:
-                prefill_cache_indices = torch.cat(prefill_cache_indices)
-        else:
-            input_ids = all_input_ids[0]
-            position_ids = position_ids[0]
-            slot_indices = slot_indices[0]
-            if sliding_window is not None:
-                prefill_cache_indices = prefill_cache_indices[0]
-
-        cu_seqlen_prefill = torch.tensor(
-            cu_seqlen_prefill, device=device, dtype=torch.int32
-        )
-        position_ids = position_ids.to(device)
-        slot_indices = slot_indices.to(device)
-        prefill_cache_indices = (
-            prefill_cache_indices.to(device) if sliding_window is not None else None
-        )
-        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        input_lengths_tensor = torch.tensor(
-            input_lengths, dtype=torch.int32, device=device
-        )
-
-        if all_prefill_logprobs:
-            prefill_head_indices = None
-            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
-        elif no_prefill_logprobs:
-            prefill_head_indices = cu_seqlen_prefill[1:] - 1
-            prefill_next_token_indices = None
-        else:
-            prefill_head_indices = torch.tensor(
-                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
-            )
-            prefill_next_token_indices = torch.tensor(
-                prefill_next_token_indices, dtype=torch.int64, device=device
-            )
-        top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
-
-        return cls(
-            batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill,
-            start_slots=start_slots, slot_indices=slot_indices, needed_blocks_slots=needed_blocks_slots,
-            block_tables=None, block_tables_tensor=None, slots=None,
-            max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices,
-            prefill_next_token_indices=prefill_next_token_indices, prefill_cu_outlens=prefill_cu_outlens,
-            input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor,
-            prefix_offsets=prefix_offsets, read_offsets=read_offsets,
-            all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor,
-            next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias,
-            top_n_tokens=top_n_tokens, top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks, max_blocks=max_blocks,
-            prefill_cache_indices=prefill_cache_indices, speculative_ids=None,
-        )
-
-
-tracer = trace.get_tracer(__name__)
 
 
 class BaseFlashMistral(FlashCausalLM):
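The deleted Mistral-specific batch capped each request's block count at SLIDING_WINDOW_BLOCKS, since a sliding-window cache never needs more than ceil(sliding_window / BLOCK_SIZE) blocks per sequence. A standalone sketch of that (now removed) cap, not part of the diff:

    import math
    from typing import Optional

    BLOCK_SIZE = 16

    def capped_needed_blocks(total_tokens: int, sliding_window: Optional[int]) -> int:
        needed = math.ceil(total_tokens / BLOCK_SIZE)
        if sliding_window is not None:
            # Needed blocks can not go over the number of blocks in the window.
            needed = min(needed, math.ceil(sliding_window / BLOCK_SIZE))
        return needed

    # A 10k-token budget under a 4096-token window only ever needs 256 blocks.
    print(capped_needed_blocks(10_000, 4096))   # 256
    print(capped_needed_blocks(10_000, None))   # 625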
@@ -344,9 +60,7 @@ class BaseFlashMistral(FlashCausalLM):
         # Set context windows
         if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
         else:
             config.sliding_window = None
@@ -384,207 +98,6 @@ class BaseFlashMistral(FlashCausalLM):
             model.model.head_size,
         )
 
-    def max_past(self) -> int:
-        return self.model.max_past
-
-    @property
-    def batch_type(self) -> Type[FlashMistralBatch]:
-        return FlashMistralBatch
-
-    def tunableop_warmup(self, seqlen: int):
-        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
-        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-        kv_cache = get_cache_manager().kv_cache
-
-        # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-
-        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
-        self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor(
-                [0, seqlen], device=self.device, dtype=torch.int32
-            ),
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=None,
-            input_lengths=input_lengths,
-            slots=slots,
-            max_s=seqlen,
-            lm_head_indices=None,
-            prefill_cache_indices=None,
-        )
-
-    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
-        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        block_tables = (
-            torch.arange(max_bt, dtype=torch.int32, device=self.device)
-            .repeat(bs)
-            .reshape((bs, max_bt))
-        )
-        kv_cache = get_cache_manager().kv_cache
-
-        self.cuda_graphs[bs] = {
-            "input_ids": input_ids,
-            "position_ids": position_ids,
-            "kv_cache": kv_cache,
-            "block_tables": block_tables,
-            "slots": slots,
-            "input_lengths": input_lengths,
-        }
-        graph = torch.cuda.CUDAGraph()
-        self.cuda_graphs[bs]["graph"] = graph
-
-        torch.cuda.synchronize()
-        # Run once outside to warmup
-        self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=None,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
-            prefill_cache_indices=None,
-            lm_head_indices=None,
-        )
-        torch.cuda.synchronize()
-
-        with torch.cuda.graph(graph, pool=MEM_POOL):
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=None,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=None,
-                lm_head_indices=None,
-            )
-            self.cuda_graphs[bs]["logits"] = logits
-            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
-        torch.cuda.synchronize()
-
-    def forward(
-        self, batch: FlashMistralBatch
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # Model Forward
-        if batch.speculative_ids is not None:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-
-            speculative_ids = batch.speculative_ids
-
-            B, speculative_length = speculative_ids.shape
-            new_length = speculative_length + 1
-            new_input_ids = torch.cat(
-                [input_ids.unsqueeze(-1), speculative_ids], dim=1
-            ).reshape(-1)
-            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
-            arange_int = arange.to(dtype=torch.int32)
-            new_position_ids = (
-                position_ids.unsqueeze(-1).expand(B, new_length) + arange
-            ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
-            ).view(-1)
-
-            # Add Copy the block tables for all members
-            block_tables = (
-                block_tables.unsqueeze(1)
-                .expand(B, new_length, -1)
-                .reshape(B * new_length, -1)
-                .contiguous()
-            )
-            max_s = max_s + speculative_length
-
-            input_ids = new_input_ids
-            position_ids = new_position_ids
-        else:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
-        bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
-        if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-            )
-            if batch.prefill_cache_indices is not None:
-                batch.prefill_cache_indices = None
-            return logits, speculative_logits
-
-        # Copy inputs to the static inputs of the cuda graph
-        # Static inputs are potentially padded
-        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
-        cuda_graph["slots"].fill_(-1)
-        cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
-
-        # Replay the graph
-        cuda_graph["graph"].replay()
-
-        # Slice output to the correct shape
-        speculative_logits = (
-            cuda_graph["speculative_logits"][:bs]
-            if cuda_graph["speculative_logits"] is not None
-            else None
-        )
-        logits = cuda_graph["logits"][:bs]
-        return logits, speculative_logits
-
 
 class FlashMistral(BaseFlashMistral):
     def __init__(
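The deleted forward rounded the live batch size up to one of the captured CUDA-graph sizes (4, 8, then multiples of 8) before looking up a graph. A standalone sketch of that (now removed) rounding rule:

    def padded_batch_size(bs: int) -> int:
        if bs == 3:
            return 4
        if 3 < bs <= 8:
            return 8
        if bs > 8:
            return (bs + 7) // 8 * 8
        return bs

    assert [padded_batch_size(b) for b in (1, 2, 3, 5, 8, 9, 17)] == [1, 2, 4, 8, 8, 16, 24]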
server/text_generation_server/models/flash_qwen2.py
@@ -7,7 +7,6 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral):
 
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
server/text_generation_server/models/flash_starcoder2.py
@@ -6,7 +6,6 @@ from typing import Optional
 from transformers.models.gpt2 import GPT2TokenizerFast
 
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral):
 
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
server/text_generation_server/models/vlm_causal_lm.py
@@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
-    FlashMistralBatch,
 )
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-)
 
 tracer = trace.get_tracer(__name__)
@@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image
 
 
-class VlmCausalLMBatch(FlashMistralBatch):
+class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
@@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
@@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor