Commit 7b7dd042 authored by Baber

modularize vllm

parent 8fada609
 import copy
-from typing import Dict, List, Optional
+import json
+from typing import Dict, List, NamedTuple, Optional
 import transformers
 from more_itertools import distribute
@@ -29,6 +30,13 @@ except ModuleNotFoundError:
 DEFAULT_IMAGE_PLACEHOLDER = "<image>"
 
 
+class JsonChatStr(NamedTuple):
+    prompt: str
+
+    def encode(self, encoding):
+        return self.prompt.encode(encoding)
+
+
 @register_model("vllm-vlm")
 class VLLM_VLM(VLLM):
     MULTIMODAL = True
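The new JsonChatStr wrapper keeps the chat history around as JSON while still looking string-like to callers that only need encode(). A minimal usage sketch, not part of the commit (the chat_history value below is invented for illustration):

import json

chat_history = [{"role": "user", "content": "Describe this image. <image>"}]
wrapped = JsonChatStr(json.dumps(chat_history))

# Behaves enough like a string for byte-level encoding...
assert wrapped.encode("utf-8") == json.dumps(chat_history).encode("utf-8")

# ...while the structured messages stay recoverable downstream.
assert json.loads(wrapped.prompt) == chat_history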
@@ -43,6 +51,7 @@ class VLLM_VLM(VLLM):
         max_images: int = 999,
         **kwargs,
     ):
+        self.pretrained = pretrained
         if max_images != 999:
             kwargs["limit_mm_per_prompt"] = {"image": max_images}
             eval_logger.info(f"Setting limit_mm_per_prompt[image] to {max_images}")
@@ -90,6 +99,12 @@ class VLLM_VLM(VLLM):
             outputs.append(inputs)
         return outputs
 
+    def _generate(self, model, *args, **kwargs):
+        if "pixtral" not in self.pretrained:
+            return model.generate(*args, **kwargs)
+        else:
+            return model.chat(*args, **kwargs)
+
     def _model_generate(
         self,
         requests: List[List[dict]] = None,
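For context, the new _generate helper only routes pixtral checkpoints to vLLM's chat entry point and sends everything else through plain generation. A rough standalone sketch of the same dispatch, assuming a vllm.LLM instance (the function name and arguments here are illustrative, not part of the commit):

from vllm import LLM, SamplingParams

def generate_or_chat(pretrained: str, model: LLM, requests, sampling_params: SamplingParams):
    # Pixtral requests are passed as chat messages via LLM.chat;
    # other models take prompt inputs through the usual LLM.generate path.
    if "pixtral" not in pretrained:
        return model.generate(requests, sampling_params=sampling_params)
    return model.chat(requests, sampling_params=sampling_params)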
@@ -116,7 +131,7 @@ class VLLM_VLM(VLLM):
                 model_args: dict, sampling_params, requests: List[List[dict]]
             ):
                 llm = LLM(**model_args)
-                return llm.generate(requests, sampling_params=sampling_params)
+                return self._generate(llm, requests, sampling_params=sampling_params)
 
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
@@ -130,14 +145,16 @@ class VLLM_VLM(VLLM):
             return undistribute(results)
 
         if self.lora_request is not None:
-            outputs = self.model.generate(
+            outputs = self._generate(
+                self.model,
                 requests,
                 sampling_params=sampling_params,
                 use_tqdm=True if self.batch_size == "auto" else False,
                 lora_request=self.lora_request,
             )
         else:
-            outputs = self.model.generate(
+            outputs = self._generate(
+                self.model,
                 requests,
                 sampling_params=sampling_params,
                 use_tqdm=True if self.batch_size == "auto" else False,
@@ -194,12 +211,14 @@ class VLLM_VLM(VLLM):
             raise ValueError(
                 f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
             )
-        return self.processor.apply_chat_template(
-            chat_history,
-            add_generation_prompt=add_generation_prompt,
-            continue_final_message=not add_generation_prompt,
-        )
+        if hasattr(self.processor, "apply_chat_template"):
+            return self.processor.apply_chat_template(
+                chat_history,
+                add_generation_prompt=add_generation_prompt,
+                continue_final_message=not add_generation_prompt,
+            )
+        else:
+            return JsonChatStr(json.dumps(chat_history))
 
     def generate_until(
         self, requests: List[Instance], disable_tqdm: bool = False
...
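A short sketch of what the two branches of the modified chat-templating code return, assuming processor is the model's Hugging Face processor or tokenizer (the chat_history value is invented for illustration):

import json

chat_history = [{"role": "user", "content": "<image>\nWhat is shown here?"}]

if hasattr(processor, "apply_chat_template"):
    # Processors with a chat template render the messages into a prompt string here.
    prompt = processor.apply_chat_template(
        chat_history, add_generation_prompt=True, continue_final_message=False
    )
else:
    # No template available: defer formatting by handing the raw JSON
    # messages downstream as a JsonChatStr.
    prompt = JsonChatStr(json.dumps(chat_history))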