import warnings

from .base import BaseModel
from ..smp import *


class VisualGLM(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = False

    def __init__(self, model_path='THUDM/visualglm-6b', **kwargs):
        try:
            import sat  # provided by the SwissArmyTransformer package
        except Exception as err:
            logging.critical('Please install SwissArmyTransformer to use VisualGLM')
            raise err

        assert model_path is not None
        self.model_path = model_path

        from transformers import AutoModel, AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()

        self.kwargs = kwargs
        warnings.warn(f'Following kwargs received: {self.kwargs}, will be used as generation config.')

    def generate_inner(self, message, dataset=None):
        # Flatten the interleaved message into a single text prompt plus one image path
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)

        output, _ = self.model.chat(
            image_path=image_path,
            tokenizer=self.tokenizer,
            query=prompt,
            history=[],
            **self.kwargs
        )
        return output
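
# Minimal usage sketch (an assumption, not part of the original module): it assumes
# the VLMEvalKit-style message format of {'type': 'image'|'text', 'value': ...} dicts,
# that `max_length` is accepted by visualglm-6b's chat(), and a placeholder image path.
if __name__ == '__main__':
    model = VisualGLM(max_length=2048)  # extra kwargs are forwarded to model.chat()
    message = [
        dict(type='image', value='assets/demo.jpg'),  # hypothetical image path
        dict(type='text', value='Describe this image in detail.'),
    ]
    print(model.generate_inner(message))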