Unverified commit 8aece3bd, authored by OlivierDehaene, committed by GitHub

feat: move allocation logic to rust (#1835)

Close #2007
parent 9ffe1f1e
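The hunks below are the Python side of the change: the per-request KV-cache block accounting that previously lived in the Python server's cache_manager module now happens in the Rust router, so the Python code stops importing BLOCK_SIZE and the global cache manager. As a rough illustration of the kind of bookkeeping that moved (a toy sketch, not TGI's actual Rust allocator; all names here are made up), the core logic is a free list of fixed-size blocks:

import math
from typing import List, Optional

class ToyBlockAllocator:
    """Free-list allocator for fixed-size KV-cache blocks.

    Illustrative only: sketches the bookkeeping this PR moves into the
    Rust router; it is not TGI's implementation.
    """

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free_blocks: List[int] = list(range(num_blocks))

    def allocate(self, num_tokens: int) -> Optional[List[int]]:
        # Round the token budget up to whole blocks.
        needed = math.ceil(num_tokens / self.block_size)
        if needed > len(self.free_blocks):
            return None  # not enough room; the scheduler must wait
        blocks = self.free_blocks[:needed]
        self.free_blocks = self.free_blocks[needed:]
        return blocks

    def release(self, blocks: List[int]) -> None:
        # Blocks freed by a finished request become available again.
        self.free_blocks.extend(blocks)

With the router owning this accounting, it can decide up front whether a batch fits in the cache instead of the Python server discovering that at runtime.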
server/text_generation_server/models/flash_qwen2.py

@@ -7,7 +7,6 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral):
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
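In both flash_qwen2.py above and flash_starcoder2.py below, the call site used to convert the window from tokens to blocks with math.ceil(config.sliding_window / BLOCK_SIZE); after the change it passes only the token count. A minimal sketch of what the helper's simplification plausibly looks like (the real helper lives in flash_mistral.py; the module-level globals here are an assumption):

import math
from typing import Optional

BLOCK_SIZE = 16  # assumed paged-attention block size

SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None

# Before: every caller converted tokens to blocks itself.
def set_sliding_window_old(sliding_window: int, sliding_window_blocks: int) -> None:
    global SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
    SLIDING_WINDOW = sliding_window
    SLIDING_WINDOW_BLOCKS = sliding_window_blocks

# After: only the token count is recorded; block-level accounting is the
# Rust allocator's job.
def set_sliding_window(sliding_window: int) -> None:
    global SLIDING_WINDOW
    SLIDING_WINDOW = sliding_window

The old call set_sliding_window(4096, math.ceil(4096 / BLOCK_SIZE)) becomes simply set_sliding_window(4096).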
server/text_generation_server/models/flash_starcoder2.py

@@ -6,7 +6,6 @@ from typing import Optional
 from transformers.models.gpt2 import GPT2TokenizerFast
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral):
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
server/text_generation_server/models/vlm_causal_lm.py

@@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
-    FlashMistralBatch,
 )
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-)
 
 tracer = trace.get_tracer(__name__)
@@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image
 
 
-class VlmCausalLMBatch(FlashMistralBatch):
+class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
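Rebasing VlmCausalLMBatch onto FlashCausalLMBatch removes the Mistral-specific batch layer from the VLM path; the subclass only contributes the vision fields. A toy mock-up of the resulting shape (field lists abbreviated; the real FlashCausalLMBatch carries many more members, e.g. block tables and slots filled from the router's allocation):

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch

@dataclass
class FlashCausalLMBatch:
    # Abbreviated stand-in for the real batch class.
    input_ids: torch.Tensor
    block_tables: List[List[int]]  # per-request KV-cache blocks, chosen by the router

@dataclass
class VlmCausalLMBatch(FlashCausalLMBatch):
    # Vision-language extras layered on the generic flash batch.
    pixel_values: Optional[List[torch.Tensor]] = None
    pixel_attention_mask: Optional[List[torch.Tensor]] = None
    image_sizes: Optional[List[Tuple[int, int]]] = None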
@@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral):
         input_ids = batch.input_ids
         position_ids = batch.position_ids
         cu_seqlen_prefill = batch.cu_seqlen_prefill
-        kv_cache = get_cache_manager().kv_cache
+        kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
@@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral):
         input_ids = batch.input_ids
         position_ids = batch.position_ids
         cu_seqlen_prefill = batch.cu_seqlen_prefill
-        kv_cache = get_cache_manager().kv_cache
+        kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
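Both hunks make the same substitution: instead of reaching for a process-global cache manager, the model reads the KV tensors it owns. A plausible sketch of what self.kv_cache contains, assuming a simple paged layout of one (key, value) tensor pair per layer (the shapes and the allocate_kv_cache helper are assumptions for illustration; TGI's actual layout depends on the attention kernel):

from typing import List, Tuple

import torch

def allocate_kv_cache(
    num_blocks: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int = 16,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    # One (key, value) pair per transformer layer; each tensor is carved
    # into `num_blocks` fixed-size pages that block tables index into.
    return [
        (
            torch.empty(num_blocks, block_size, num_heads, head_size, dtype=dtype, device=device),
            torch.empty(num_blocks, block_size, num_heads, head_size, dtype=dtype, device=device),
        )
        for _ in range(num_layers)
    ]

The block count is typically fixed once at warmup; after that, holding the cache as self.kv_cache means forward passes need no global lookup.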