Unverified Commit 8aece3bd authored by OlivierDehaene, committed by GitHub

feat: move allocation logic to rust (#1835)

Close #2007
parent 9ffe1f1e
server/text_generation_server/models/flash_qwen2.py

@@ -7,7 +7,6 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral):
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
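In both FlashQwen2 and FlashStarcoder2 (the same change repeats in the next file), the call site stops computing how many KV-cache blocks the sliding window spans; only the window size crosses the Python boundary now that block accounting lives in the Rust router. A minimal sketch of the signature change, assuming `set_sliding_window` simply records module-level state as in `flash_mistral.py`; the function bodies and `BLOCK_SIZE = 16` here are assumptions, not code from the repository:

```python
import math
from typing import Optional

BLOCK_SIZE = 16  # assumption: the constant previously imported from cache_manager

SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None  # state only the old signature fed

# Old shape of the helper (sketch): callers did the block arithmetic themselves.
def set_sliding_window_old(sliding_window: int, sliding_window_blocks: int) -> None:
    global SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
    SLIDING_WINDOW = sliding_window
    SLIDING_WINDOW_BLOCKS = sliding_window_blocks

# New shape (sketch): only the window size; the Rust router owns block counts.
def set_sliding_window(sliding_window: int) -> None:
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window

sliding_window = 4096  # e.g. config.sliding_window
set_sliding_window_old(sliding_window, math.ceil(sliding_window / BLOCK_SIZE))  # before
set_sliding_window(sliding_window)  # after
```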
server/text_generation_server/models/flash_starcoder2.py

@@ -6,7 +6,6 @@ from typing import Optional
 
 from transformers.models.gpt2 import GPT2TokenizerFast
 
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral):
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
server/text_generation_server/models/vlm_causal_lm.py

@@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
-    FlashMistralBatch,
-)
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image
 
 
-class VlmCausalLMBatch(FlashMistralBatch):
+class VlmCausalLMBatch(FlashCausalLMBatch):
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]
@@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral):
         input_ids = batch.input_ids
         position_ids = batch.position_ids
         cu_seqlen_prefill = batch.cu_seqlen_prefill
-        kv_cache = get_cache_manager().kv_cache
+        kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
@@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral):
         input_ids = batch.input_ids
         position_ids = batch.position_ids
         cu_seqlen_prefill = batch.cu_seqlen_prefill
-        kv_cache = get_cache_manager().kv_cache
+        kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
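With the Python-side allocation gone, `VlmCausalLMBatch` no longer derives from `FlashMistralBatch` and inherits `FlashCausalLMBatch` directly. A stand-in sketch of the new hierarchy; the stub base class and its two fields are assumptions for illustration, and only the three vision fields come from the diff:

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch


@dataclass
class FlashCausalLMBatch:
    # Stub standing in for the real batch class; slot assignments now arrive
    # precomputed from the Rust router instead of a Python cache manager.
    input_ids: torch.Tensor
    slots: torch.Tensor


@dataclass
class VlmCausalLMBatch(FlashCausalLMBatch):
    # Vision-specific fields, as in the diff above.
    pixel_values: Optional[List[torch.Tensor]] = None
    pixel_attention_mask: Optional[List[torch.Tensor]] = None
    image_sizes: Optional[List[Tuple[int, int]]] = None
```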
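Both forward paths swap `get_cache_manager().kv_cache` for `self.kv_cache`: the KV cache becomes an attribute the model owns rather than state behind a global Python cache manager, since the allocator that hands out blocks and slots now lives in the Rust router. A sketch of that ownership, assuming one paged (key, value) tensor pair per layer; the shapes and constructor arguments are assumptions:

```python
import torch


class VlmCausalLMSketch:
    """Illustrative only: the model holds its paged KV cache directly."""

    def __init__(self, num_layers: int, num_blocks: int, block_size: int,
                 num_heads: int, head_dim: int,
                 dtype: torch.dtype = torch.float16, device: str = "cpu"):
        # One (key, value) buffer pair per layer. Which blocks/slots a request
        # writes into is decided by the Rust router and shipped with the batch.
        self.kv_cache = [
            (
                torch.zeros(num_blocks, block_size, num_heads, head_dim,
                            dtype=dtype, device=device),
                torch.zeros(num_blocks, block_size, num_heads, head_dim,
                            dtype=dtype, device=device),
            )
            for _ in range(num_layers)
        ]


model = VlmCausalLMSketch(num_layers=2, num_blocks=8, block_size=16,
                          num_heads=4, head_dim=64)
kv_cache = model.kv_cache  # what forward now reads, instead of a global lookup
```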