Unverified Commit 3212c2ad authored by Mick's avatar Mick Committed by GitHub
Browse files

vlm: optimize tensor transport (#6003)


Co-authored-by: default avatarXinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent 53475674
...@@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor): ...@@ -34,8 +34,10 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
server_args: ServerArgs, server_args: ServerArgs,
_processor: VILAProcessor, _processor: VILAProcessor,
*args,
**kwargs,
) -> None: ) -> None:
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens( self.mm_tokens = MultimodalSpecialTokens(
image_token=self._processor.tokenizer.image_token, image_token=self._processor.tokenizer.image_token,
image_token_id=hf_config.image_token_id, image_token_id=hf_config.image_token_id,
......
...@@ -14,6 +14,7 @@ import traceback ...@@ -14,6 +14,7 @@ import traceback
import urllib.request import urllib.request
import weakref import weakref
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from io import BytesIO from io import BytesIO
from json import dumps from json import dumps
from typing import Any, Callable, List, Optional, Tuple, Type, Union from typing import Any, Callable, List, Optional, Tuple, Type, Union
...@@ -28,6 +29,24 @@ from tqdm import tqdm ...@@ -28,6 +29,24 @@ from tqdm import tqdm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def execute_once(func):
has_run = None
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal has_run
if not has_run:
func(*args, **kwargs)
has_run = True
return wrapper
@execute_once
def info_once(message: str):
logger.info(message)
def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
"""Convert a JSON schema to a string. """Convert a JSON schema to a string.
Parameters Parameters
......
...@@ -24,7 +24,7 @@ class VLMInputTestBase: ...@@ -24,7 +24,7 @@ class VLMInputTestBase:
model_path = None model_path = None
chat_template = None chat_template = None
processor = None processor = None
visual = None # Should be a callable for precomputed features visual = None # Should be a callable for precomputed embeddings
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -41,7 +41,7 @@ class VLMInputTestBase: ...@@ -41,7 +41,7 @@ class VLMInputTestBase:
@classmethod @classmethod
def _init_visual(cls): def _init_visual(cls):
"""Override in subclass to set up cls.visual as a callable for precomputed features.""" """Override in subclass to set up cls.visual as a callable for precomputed embeddings."""
raise NotImplementedError raise NotImplementedError
def setUp(self): def setUp(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment