#!/usr/bin/env python
# -*- coding: utf-8 -*-
from pathlib import Path
from typing import Annotated, Union

import typer
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

app = typer.Typer(pretty_exceptions_show_locals=False)


def _resolve_path(path: Union[str, Path]) -> Path:
    # Normalize user input such as '~/checkpoints/model' to an absolute path.
    return Path(path).expanduser().resolve()


def load_model_and_tokenizer(
    model_dir: Union[str, Path],
) -> tuple[ModelType, TokenizerType]:
    model_dir = _resolve_path(model_dir)
    if (model_dir / 'adapter_config.json').exists():
        # The directory holds a PEFT (e.g. LoRA) adapter: load the adapter on top
        # of its base model, and read the tokenizer from the base model path
        # recorded in the adapter's peft_config.
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, device_map='auto'
        )
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        # The directory holds a full (original or merged) checkpoint.
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=True
    )
    return model, tokenizer


@app.command()
def main(
    model_dir: Annotated[str, typer.Argument(help='Path to a full checkpoint or a PEFT adapter directory')],
    prompt: Annotated[str, typer.Option(help='Prompt to send to the model')],
):
    model, tokenizer = load_model_and_tokenizer(model_dir)
    # `chat` is provided by the model's remote code (e.g. ChatGLM-style models),
    # not by the generic transformers API; it returns the reply and the updated
    # conversation history, of which only the reply is printed here.
    response, _ = model.chat(tokenizer, prompt)
    print(response)


if __name__ == '__main__':
    app()
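
# A minimal usage sketch. The script filename and checkpoint path below are
# placeholders, not part of the original script; the directory may contain either
# a full model or a PEFT adapter (detected via adapter_config.json):
#
#   python cli_demo.py /path/to/checkpoint --prompt "Hello, who are you?"
#
# With an adapter directory, the adapter is loaded via AutoPeftModelForCausalLM
# and the tokenizer comes from the recorded base model; otherwise the directory
# is loaded directly with AutoModelForCausalLM.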