import os import sys from typing import Optional, Dict import logging import torch from accelerate import Accelerator from datasets import load_dataset from transformers import AutoModelForCausalLM, HfArgumentParser, BitsAndBytesConfig, AutoTokenizer from dataclasses import dataclass, field logger = logging.getLogger(__name__) @dataclass class ModelArguments: """ Arguments pertaining to what data we are going to input our model for training and eval. """ model_name_or_path: str = field( default=None, metadata={"help": "The name of the model to use (via the transformers library) for the prompt annotation."}, ) model_variant: str = field( default=None, metadata={"help": "If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. "}, ) model_revision: str = field( default="main", metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, ) cache_dir: Optional[str] = field( default=None, metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, ) torch_dtype: Optional[str] = field( default="float16", metadata={ "help": ( "Floating-point format in which the model weights should be initialized" " and the computations run. Choose one of `[float32, float16, bfloat16]`." ) }, ) attn_implementation: Optional[str] = field( default="sdpa", metadata={"help": "Which attn type to use: ['eager', 'sdpa', 'flash_attention_2']"}, ) load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "Whether to use 8-bit precision for inference."}) load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "Whether to use 4-bit precision for inference."}) bnb_4bit_quant_type: Optional[str] = field( default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"} ) use_bnb_nested_quant: Optional[bool] = field(default=False, metadata={"help": "use nested quantization"}) trust_remote_code: Optional[bool] = field( default=False, metadata={ "help": ( "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option " "should only be set to `True` for repositories you trust and in which you have read the code, as it will " "execute code present on the Hub on your local machine." ) }, ) use_fast_tokenizer: Optional[bool] = field(default=True, metadata={"help": "Use fast tokenizer for encoding/decoding input ids"}) @dataclass class DataArguments: """ Arguments pertaining to what data we are going to input our model for training and eval. """ dataset_name: str = field( default=None, metadata={ "help": "The name of the dataset to use (via the datasets library)" }, ) dataset_config_name: Optional[str] = field( default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}, ) dataset_split_name: Optional[str] = field( default=None, metadata={"help": "The split name of the dataset to use (via the datasets library)."}, ) dataset_cache_dir: Optional[str] = field( default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}, ) samples_per_dataset: Optional[int] = field( default=None, metadata={"help": "Number of samples per dataset used to measure speed."}, ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}, ) preprocessing_num_workers: Optional[int] = field( default=None, metadata={"help": "The number of processes to use for the preprocessing."}, ) def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | None: if model_args.load_in_4bit: compute_dtype = torch.float16 if model_args.torch_dtype not in {"auto", None}: compute_dtype = getattr(torch, model_args.torch_dtype) quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, ) elif model_args.load_in_8bit: quantization_config = BitsAndBytesConfig( load_in_8bit=True, ) else: quantization_config = None return quantization_config def get_current_device() -> int: """Get the current device. For GPU we return the local process index to enable multiple GPU training.""" return Accelerator().local_process_index if torch.cuda.is_available() else "cpu" def get_kbit_device_map() -> Dict[str, int] | None: """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`""" return {"": get_current_device()} if torch.cuda.is_available() else None def main(): # 1. Parse input arguments parser = HfArgumentParser((ModelArguments, DataArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args = parser.parse_args_into_dataclasses() # 2. Setup logging # Make one log on every process with the configuration for debugging. logger.setLevel(logging.INFO) logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", handlers=[logging.StreamHandler(sys.stdout)], ) # 3. Load pre-trained model logger.info("*** Load pretrained model ***") torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) quantization_config = get_quantization_config(model_args) model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, revision=model_args.model_revision, variant=model_args.model_variant, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, low_cpu_mem_usage=True, ) tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, use_fast=model_args.use_fast_tokenizer, ) # 4. Load annotation dataset raw_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config)