Commit 42b5aac6 authored by sanchit-gandhi

start llm prompts

parent 98482e58
import os
import sys
from typing import Optional, Dict
import logging
import torch
from accelerate import Accelerator
from datasets import load_dataset
from transformers import AutoModelForCausalLM, HfArgumentParser, BitsAndBytesConfig, AutoTokenizer
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
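

# The two dataclasses below define the script's arguments; they are parsed in main() via
# HfArgumentParser, either from the command line or from a single JSON config file.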
@dataclass
class ModelArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
model_name_or_path: str = field(
default=None,
metadata={"help": "The name of the model to use (via the transformers library) for the prompt annotation."},
)
model_variant: str = field(
default=None,
metadata={"help": "If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. "},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
torch_dtype: Optional[str] = field(
default="float16",
metadata={
"help": (
"Floating-point format in which the model weights should be initialized"
" and the computations run. Choose one of `[float32, float16, bfloat16]`."
)
},
)
attn_implementation: Optional[str] = field(
default="sdpa",
metadata={"help": "Which attn type to use: ['eager', 'sdpa', 'flash_attention_2']"},
)
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "Whether to use 8-bit precision for inference."})
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "Whether to use 4-bit precision for inference."})
bnb_4bit_quant_type: Optional[str] = field(
default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
)
use_bnb_nested_quant: Optional[bool] = field(default=False, metadata={"help": "use nested quantization"})
trust_remote_code: Optional[bool] = field(
default=False,
metadata={
"help": (
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
)
},
)
use_fast_tokenizer: Optional[bool] = field(default=True, metadata={"help": "Use fast tokenizer for encoding/decoding input ids"})
@dataclass
class DataArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: str = field(
default=None,
metadata={
"help": "The name of the dataset to use (via the datasets library)"
},
)
dataset_config_name: Optional[str] = field(
default=None,
metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
)
dataset_split_name: Optional[str] = field(
default=None,
metadata={"help": "The split name of the dataset to use (via the datasets library)."},
)
dataset_cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to cache directory for saving and loading datasets"},
)
samples_per_dataset: Optional[int] = field(
default=None,
metadata={"help": "Number of samples per dataset used to measure speed."},
)
overwrite_cache: bool = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets"},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | None:
    """Build the bitsandbytes quantization config from the model arguments (4-bit takes precedence over 8-bit)."""
    if model_args.load_in_4bit:
        compute_dtype = torch.float16
        if model_args.torch_dtype not in {"auto", None}:
            compute_dtype = getattr(torch, model_args.torch_dtype)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )
    elif model_args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
    else:
        quantization_config = None
    return quantization_config


def get_current_device() -> int | str:
    """Get the current device. For GPU, we return the local process index to enable multi-GPU inference; otherwise we fall back to the CPU."""
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"


def get_kbit_device_map() -> Dict[str, int] | None:
    """Useful for running inference with quantized models by setting `device_map=get_kbit_device_map()`."""
    return {"": get_current_device()} if torch.cuda.is_available() else None
def main():
    # 1. Parse input arguments
    parser = HfArgumentParser((ModelArguments, DataArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args = parser.parse_args_into_dataclasses()

    # 2. Setup logging
    # Make one log on every process with the configuration for debugging.
    logger.setLevel(logging.INFO)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # 3. Load pre-trained model
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        variant=model_args.model_variant,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
    )

    # 4. Load annotation dataset
    raw_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
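

# Standard script entry point, assumed here since `main()` is defined but not yet invoked in this commit.
# Example invocation (hypothetical script name; placeholders stand in for real model/dataset ids):
#   python run_prompts.py --model_name_or_path <model-id> --dataset_name <dataset-id>
if __name__ == "__main__":
    main()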