Commit 5a1cf2f0 authored by huangwb

Merge tag 'v2.0.2' into dev-rocm

parents 24f58bb6 6073ece4
@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


+def image_text_replacement(image_input, config, image_id) -> str:
+    if config.model_type == "idefics2":
+        # TODO technically depends on image splitting which is not implemented.
+        num_features = 320
+        return (
+            "<fake_token_around_image>"
+            + "<image>" * num_features
+            + "<fake_token_around_image>"
+        )
+    elif config.model_type == "llava_next":
+        height, width = image_input["image_sizes"][image_id]
+        num_features = get_number_of_features(height, width, config)
+        from loguru import logger
+
+        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
+        return "<image>" * num_features
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")


+def get_unpadded_features(
+    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
+) -> Tuple[int, int]:
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width
+
+    aspect_ratio: float = width / height
+    current_aspect_ratio: float = current_width / current_height
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (height * current_width) // width
+        current_height = new_height
+    else:
+        new_width = (width * current_height) // height
+        current_width = new_width
+
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)


 def get_number_of_features(height: int, width: int, config) -> int:
     # From config
     # Hardcoded for CLIP for now
@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
         image_grid_pinpoints,
         image_size,
     )
-    height_of_patch = math.ceil(height / width * npatches)
-
-    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
-    # They are only added after width
-    newline_features = height_of_patch * num_patch_width
+    unpadded_features, newline_features = get_unpadded_features(
+        height, width, npatches, num_patch_height, num_patch_width
+    )
     # The base patch covers the entire image
     base_features = npatches**2
     return unpadded_features + newline_features + base_features
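To make the feature-count arithmetic above concrete, here is a small worked call of get_unpadded_features; the patch count (npatches = 24, i.e. a 336 px CLIP tower with 14 px patches) and the 2x2 anyres grid are illustrative assumptions, not values taken from this diff:

    # Illustrative values only: npatches=24 and a 2x2 grid are assumptions.
    unpadded, newline = get_unpadded_features(
        height=480, width=640, npatches=24, num_patch_height=2, num_patch_width=2
    )
    # The 48x48 patch grid is squarer than the 4:3 image, so the padded height is
    # trimmed to 480 * 48 // 640 = 36 rows; one newline feature is added per row.
    assert (unpadded, newline) == (36 * 48, 36)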
@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image

-# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
-# assert get_number_of_features(640, 640) == 2928

 class VlmCausalLMBatch(FlashMistralBatch):
     pixel_values: Optional[List[torch.Tensor]]
+    pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]

     @classmethod
@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch
@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch
@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
         for r in requests:
             chunks = split(r.inputs)
             full_text = ""
+            image_id = 0
             for chunk in chunks:
                 if chunk["type"] == "text":
                     full_text += chunk["content"]
@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                             "Cannot process input image not starting with data:"
                         )
                     image_input = processor.image_processor(image, return_tensors="pt")
-                    height, width = image_input["image_sizes"][0]
-                    num_features = get_number_of_features(height, width, config)
-                    full_text += "<image>" * num_features
+                    full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
         if image_inputs:
-            image_inputs = {
+            image_input = image_inputs[0]
+            new_image_inputs = {
                 "pixel_values": torch.cat(
                     [img["pixel_values"] for img in image_inputs], dim=0
                 ),
-                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
             }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
         else:
             image_inputs = None
         return batch_tokenized_inputs, image_inputs
@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
-            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            if "pixel_attention_mask" in image_inputs:
+                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
+                    device=device
+                )
+            else:
+                batch.pixel_attention_mask = None
+            if "image_sizes" in image_inputs:
+                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            else:
+                batch.image_sizes = None
         else:
             batch.pixel_values = None
+            batch.pixel_attention_mask = None
             batch.image_sizes = None
         return batch
@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch

-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)
-
     def forward(
         self, batch: VlmCausalLMBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
             max_s = min(self.max_past(), max_s)

-        bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        # Try to find an associated cuda graph
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
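The new lookup simply picks the smallest captured CUDA graph whose padded batch size can hold the current batch. A tiny sketch of the same rule with made-up keys:

    # Keys are illustrative padded batch sizes for previously captured graphs.
    cuda_graphs = {1: "graph_1", 2: "graph_2", 4: "graph_4", 8: "graph_8"}
    bs = 3
    candidates = sorted(k for k in cuda_graphs if k >= bs)
    cuda_graph = cuda_graphs[candidates[0]] if candidates else None  # -> "graph_4"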
@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
                 pixel_values=batch.pixel_values,
+                pixel_attention_mask=batch.pixel_attention_mask,
                 image_sizes=batch.image_sizes,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
+            if batch.pixel_attention_mask is not None:
+                batch.pixel_attention_mask = None
             if batch.image_sizes is not None:
                 batch.image_sizes = None
             return logits, speculative_logits
@@ -2,6 +2,7 @@ import asyncio
 import os
 import torch
 import time
+import signal

 from grpc import aio
 from loguru import logger
@@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


+class SignalHandler:
+    KEEP_PROCESSING = True
+
+    def __init__(self):
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print(f"Exiting gracefully: Signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+signal_handler = SignalHandler()


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(
         self,
@@ -231,11 +247,8 @@ def serve(
         logger.info("Server started at {}".format(local_url))

-        try:
-            await server.wait_for_termination()
-        except KeyboardInterrupt:
-            logger.info("Signal received. Shutting down")
-            await server.stop(0)
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)

     asyncio.run(
         serve_inner(
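The serve loop now polls a signal flag instead of blocking in wait_for_termination(), so SIGINT and SIGTERM are handled in one place. A minimal standalone sketch of the same pattern, outside this repo's gRPC wiring:

    import asyncio
    import signal

    class SignalHandler:
        KEEP_PROCESSING = True

        def __init__(self):
            signal.signal(signal.SIGINT, self.exit_gracefully)
            signal.signal(signal.SIGTERM, self.exit_gracefully)

        def exit_gracefully(self, signum, frame):
            self.KEEP_PROCESSING = False

    async def main():
        handler = SignalHandler()
        # Wake up twice a second and stop once a signal has flipped the flag.
        while handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

    asyncio.run(main())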
@@ -57,7 +57,14 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        backend = "gloo"
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
+            backend = "gloo"
         options = None

     if WORLD_SIZE == 1:
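For reference, a self-contained sketch of the same optional-dependency fallback feeding torch.distributed; the rank/world-size and master address defaults are illustrative and not part of this change:

    import os
    import torch.distributed as dist

    try:
        import oneccl_bindings_for_pytorch  # noqa: F401  (Intel oneCCL bindings)
        backend = "ccl"
        os.environ.setdefault("CCL_WORKER_COUNT", "1")
    except ImportError:
        backend = "gloo"

    # Single-process defaults so the sketch runs standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend=backend, rank=0, world_size=1)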
 import torch


+def is_xpu_available():
+    try:
+        import intel_extension_for_pytorch
+    except ImportError:
+        return False
+
+    return hasattr(torch, "xpu") and torch.xpu.is_available()


 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
+IS_XPU_SYSTEM = is_xpu_available()
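A small illustrative consumer of these flags (not part of the commit), picking a default device:

    import torch

    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:  # ROCm builds also expose the "cuda" device
        device = torch.device("cuda")
    elif IS_XPU_SYSTEM:
        device = torch.device("xpu")
    else:
        device = torch.device("cpu")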