OpenDAS / text-generation-inference · Commits · 8aece3bd

Unverified commit 8aece3bd, authored Jun 05, 2024 by OlivierDehaene; committed by GitHub, Jun 05, 2024
feat: move allocation logic to rust (#1835)
Close #2007
Parent: 9ffe1f1e
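What the diffs show on the Python side: the precomputed block-count argument disappears from set_sliding_window, and models read self.kv_cache instead of reaching into a process-global cache manager, because block allocation now happens in the Rust router. A toy sketch of that ownership change, with hypothetical stand-in classes (the real kv_cache is a list of paged CUDA tensors, elided here):

    # Old style: a module-level manager that every model reached into.
    class CacheManager:
        def __init__(self) -> None:
            self.kv_cache = ["<paged kv tensors>"]  # placeholder, not real tensors

    _CACHE_MANAGER = CacheManager()

    def get_cache_manager() -> CacheManager:
        return _CACHE_MANAGER

    # New style: the model owns its cache; the Rust-side allocator hands out
    # block indices into it, so a forward pass reads self.kv_cache directly.
    class Model:
        def __init__(self) -> None:
            self.kv_cache = ["<paged kv tensors>"]  # placeholder

        def forward_kv(self):
            return self.kv_cache  # was: get_cache_manager().kv_cache

    print(Model().forward_kv())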
Changes: 25 files in the commit; showing 5 changed files with 174 additions and 613 deletions (+174 -613):
server/text_generation_server/models/flash_causal_lm.py    +163 -105
server/text_generation_server/models/flash_mistral.py        +5 -492
server/text_generation_server/models/flash_qwen2.py          +1   -4
server/text_generation_server/models/flash_starcoder2.py     +1   -4
server/text_generation_server/models/vlm_causal_lm.py        +4   -8
server/text_generation_server/models/flash_causal_lm.py

(diff collapsed; not expanded on this page)

server/text_generation_server/models/flash_mistral.py

(diff collapsed; not expanded on this page)
server/text_generation_server/models/flash_qwen2.py

@@ -7,7 +7,6 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral):
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)

         torch.distributed.barrier(group=self.process_group)
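The argument deleted above was the number of KV-cache blocks needed to cover the sliding window; after this commit the Rust-side allocator derives that budget itself. For reference, the arithmetic the Python side used to do, as a minimal sketch (BLOCK_SIZE = 16 is an assumed value here, matching the old cache_manager module; flash_starcoder2.py below receives the identical change):

    import math

    BLOCK_SIZE = 16  # assumed: the old cache_manager block size

    def blocks_for_window(sliding_window: int) -> int:
        # The expression removed in the diff above:
        # math.ceil(config.sliding_window / BLOCK_SIZE)
        return math.ceil(sliding_window / BLOCK_SIZE)

    assert blocks_for_window(4096) == 256  # window divides evenly into blocks
    assert blocks_for_window(4100) == 257  # otherwise round up to a whole block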
server/text_generation_server/models/flash_starcoder2.py

@@ -6,7 +6,6 @@ from typing import Optional
 from transformers.models.gpt2 import GPT2TokenizerFast
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral):
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)

         torch.distributed.barrier(group=self.process_group)
server/text_generation_server/models/vlm_causal_lm.py

@@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
-    FlashMistralBatch,
 )
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-)

 tracer = trace.get_tracer(__name__)
@@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image


-class VlmCausalLMBatch(FlashMistralBatch):
+class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
@@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral):
         input_ids = batch.input_ids
         position_ids = batch.position_ids
         cu_seqlen_prefill = batch.cu_seqlen_prefill
-        kv_cache = get_cache_manager().kv_cache
+        kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
@@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral):
         input_ids = batch.input_ids
         position_ids = batch.position_ids
         cu_seqlen_prefill = batch.cu_seqlen_prefill
-        kv_cache = get_cache_manager().kv_cache
+        kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
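This file also rebases the VLM batch from FlashMistralBatch onto FlashCausalLMBatch while keeping its vision-specific fields, which works because those fields never depended on Mistral-specific state. A self-contained sketch of the pattern, with stand-in dataclasses (the real batch classes carry many more fields; tensors are replaced with plain lists so this runs without torch):

    from dataclasses import dataclass, field
    from typing import List, Optional, Tuple

    @dataclass
    class FlashCausalLMBatch:  # hypothetical stand-in for the real base class
        input_ids: List[int] = field(default_factory=list)

    @dataclass
    class VlmCausalLMBatch(FlashCausalLMBatch):
        # The three vision fields from the diff, with torch.Tensor simplified away.
        pixel_values: Optional[List[List[float]]] = None
        pixel_attention_mask: Optional[List[List[int]]] = None
        image_sizes: Optional[List[Tuple[int, int]]] = None

    batch = VlmCausalLMBatch(input_ids=[1, 2, 3], image_sizes=[(336, 336)])
    print(batch.image_sizes)  # [(336, 336)]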