Unverified Commit 9ba093b4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[CI/Build] Simplify model loading for `HfRunner` (#5251)

parent 27208be6
import contextlib
import gc
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TypeVar
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
LlavaConfig, LlavaForConditionalGeneration)
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoProcessor, AutoTokenizer, BatchEncoding)
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
......@@ -144,16 +145,12 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"float": torch.float,
}
AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)
_EMBEDDING_MODELS = [
"intfloat/e5-mistral-7b-instruct",
]
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
class HfRunner:
def wrap_device(self, input: any):
def wrap_device(self, input: _T) -> _T:
if not is_cpu():
return input.to("cuda")
else:
......@@ -163,13 +160,16 @@ class HfRunner:
self,
model_name: str,
dtype: str = "half",
*,
is_embedding_model: bool = False,
is_vision_model: bool = False,
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name
if model_name in _EMBEDDING_MODELS:
if is_embedding_model:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = self.wrap_device(
......@@ -178,8 +178,13 @@ class HfRunner:
device="cpu",
).to(dtype=torch_dtype))
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
else:
auto_cls = AutoModelForCausalLM
self.model = self.wrap_device(
AutoModelForCausalLM.from_pretrained(
auto_cls.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
......
......@@ -28,7 +28,7 @@ def test_models(
model: str,
dtype: str,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_model = hf_runner(model, dtype=dtype, is_embedding_model=True)
hf_outputs = hf_model.encode(example_prompts)
del hf_model
......
......@@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
"""
model_id, vision_language_config = model_and_config
hf_model = hf_runner(model_id, dtype=dtype)
hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
hf_outputs = hf_model.generate_greedy(hf_image_prompts,
max_tokens,
images=hf_images)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment