Commit 7b7dd042 authored by Baber

modularize vllm

parent 8fada609
import copy
from typing import Dict, List, Optional
import json
from typing import Dict, List, NamedTuple, Optional
import transformers
from more_itertools import distribute
@@ -29,6 +30,13 @@ except ModuleNotFoundError:
DEFAULT_IMAGE_PLACEHOLDER = "<image>"
class JsonChatStr(NamedTuple):
prompt: str
def encode(self, encoding):
return self.prompt.encode(encoding)
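JsonChatStr keeps an unrendered chat history as a JSON string while still exposing encode(), so it can travel through code paths that expect an encodable string. A quick illustration of the intended round trip, using the class and DEFAULT_IMAGE_PLACEHOLDER defined above (the message content is made up):

import json

history = [{"role": "user", "content": f"{DEFAULT_IMAGE_PLACEHOLDER} What is shown here?"}]
wrapped = JsonChatStr(json.dumps(history))

wrapped.encode("utf-8")      # behaves like an encodable string
json.loads(wrapped.prompt)   # the structured messages remain recoverable later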
@register_model("vllm-vlm")
class VLLM_VLM(VLLM):
MULTIMODAL = True
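register_model adds the class to the harness registry under the name "vllm-vlm", so it can be selected by that string. A hedged sketch using the harness's Python entry point; the checkpoint and task names are placeholders, not taken from this commit:

import lm_eval

# "vllm-vlm" resolves to VLLM_VLM through the register_model decorator above.
results = lm_eval.simple_evaluate(
    model="vllm-vlm",
    model_args="pretrained=Qwen/Qwen2-VL-7B-Instruct,max_images=4",
    tasks=["mmmu_val"],
)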
@@ -43,6 +51,7 @@ class VLLM_VLM(VLLM):
max_images: int = 999,
**kwargs,
):
self.pretrained = pretrained
if max_images != 999:
kwargs["limit_mm_per_prompt"] = {"image": max_images}
eval_logger.info(f"Setting limit_mm_per_prompt[image] to {max_images}")
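When max_images is set, the constructor forwards it to the vLLM engine as limit_mm_per_prompt, which caps how many image items a single prompt may carry. A minimal sketch of the resulting engine call; the checkpoint name is illustrative:

from vllm import LLM

# What the kwargs above translate to when max_images=4: vLLM rejects prompts
# that attach more than 4 images.
llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    limit_mm_per_prompt={"image": 4},
)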
@@ -90,6 +99,12 @@ class VLLM_VLM(VLLM):
outputs.append(inputs)
return outputs
def _generate(self, model, *args, **kwargs):
    # Pixtral checkpoints must go through vLLM's chat() API so the model's own
    # chat template is applied; all other models keep using generate().
    if "pixtral" not in self.pretrained:
        return model.generate(*args, **kwargs)
    else:
        return model.chat(*args, **kwargs)
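For context, the two vLLM entry points differ in what they accept: generate() takes already-rendered prompts, while chat() takes OpenAI-style messages and applies the model's own chat template, which Pixtral requires. A hedged sketch, assuming a recent vLLM that exposes LLM.chat(); the checkpoint name is illustrative:

from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Pixtral-12B-2409", tokenizer_mode="mistral")  # illustrative
params = SamplingParams(max_tokens=64)

# generate() consumes pre-rendered prompt strings (or token ids).
outputs = llm.generate(["Describe the sky."], sampling_params=params)

# chat() consumes OpenAI-style messages and applies the model's chat template
# itself, which is why Pixtral-style checkpoints are routed through it.
messages = [{"role": "user", "content": "Describe the sky."}]
chat_outputs = llm.chat(messages, sampling_params=params)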
def _model_generate(
self,
requests: List[List[dict]] = None,
@@ -116,7 +131,7 @@ class VLLM_VLM(VLLM):
model_args: dict, sampling_params, requests: List[List[dict]]
):
llm = LLM(**model_args)
return llm.generate(requests, sampling_params=sampling_params)
return self._generate(llm, requests, sampling_params=sampling_params)
# dispatch requests to all self.data_parallel_size workers, in interleaved fashion
# interleaved important to balance context lengths across workers
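As the comments above note, requests are dealt out round-robin rather than in contiguous slices so that each worker sees a similar mix of context lengths. A minimal sketch of that dealing with more_itertools.distribute; undistribute, used below, is the harness helper that restores the original order and is assumed here rather than reimplemented:

from more_itertools import distribute

requests = ["r0", "r1", "r2", "r3", "r4"]

# Round-robin dealing: each of the 2 workers gets every other request, so long
# and short contexts end up spread evenly across workers.
per_worker = [list(chunk) for chunk in distribute(2, requests)]
assert per_worker == [["r0", "r2", "r4"], ["r1", "r3"]]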
@@ -130,14 +145,16 @@ class VLLM_VLM(VLLM):
return undistribute(results)
if self.lora_request is not None:
outputs = self.model.generate(
outputs = self._generate(
self.model,
requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
lora_request=self.lora_request,
)
else:
outputs = self.model.generate(
outputs = self._generate(
self.model,
requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
@@ -194,12 +211,14 @@ class VLLM_VLM(VLLM):
raise ValueError(
f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
)
return self.processor.apply_chat_template(
chat_history,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
if hasattr(self.processor, "apply_chat_template"):
return self.processor.apply_chat_template(
chat_history,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
else:
return JsonChatStr(json.dumps(chat_history))
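Downstream, a JsonChatStr returned here still holds the unrendered messages, so the chat() path has to decode it before the model sees it. A hypothetical helper sketching that unpacking (the actual handling in the harness may differ):

import json

def unpack_chat_requests(requests):
    # Hypothetical: turn JsonChatStr wrappers back into the message lists that
    # vLLM's LLM.chat() accepts; anything else passes through unchanged.
    return [
        json.loads(req.prompt) if isinstance(req, JsonChatStr) else req
        for req in requests
    ]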
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
......