Commit 84987715 authored by chenych's avatar chenych
Browse files

update to v0.9.2

parent 317a82e2
......@@ -5,7 +5,7 @@ accelerate>=0.34.0,<=1.2.1
peft>=0.11.1,<=0.12.0
trl>=0.8.6,<=0.9.6
tokenizers>=0.19.0,<=0.21.0
gradio>=4.38.0,<=5.18.0
gradio>=4.38.0,<=5.21.0
pandas>=2.0.0
scipy
einops
......
......@@ -38,7 +38,7 @@ def vllm_infer(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 2048,
max_samples: int = None,
max_samples: Optional[int] = None,
vllm_config: str = "{}",
save_name: str = "generated_predictions.jsonl",
temperature: float = 0.95,
......@@ -46,6 +46,7 @@ def vllm_infer(
top_k: int = 50,
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
seed: Optional[int] = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
......@@ -97,19 +98,21 @@ def vllm_infer(
multi_modal_data = None
inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
labels.append(
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
tokenizer.decode(
list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
)
)
sampling_params = SamplingParams(
repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0
temperature=generating_args.temperature,
top_p=generating_args.top_p or 1.0, # top_p must > 0
top_k=generating_args.top_k,
top_k=generating_args.top_k or -1, # top_k must > 0
stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
max_tokens=generating_args.max_new_tokens,
skip_special_tokens=False,
skip_special_tokens=skip_special_tokens,
seed=seed,
)
if model_args.adapter_name_or_path is not None:
......@@ -121,6 +124,7 @@ def vllm_infer(
"model": model_args.model_name_or_path,
"trust_remote_code": True,
"dtype": model_args.infer_dtype,
"max_model_len": cutoff_len + max_new_tokens,
"tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
"pipeline_parallel_size": pipeline_parallel_size,
"disable_log_stats": True,
......
......@@ -46,7 +46,7 @@ extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0,<=0.16.2"],
"deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
"liger-kernel": ["liger-kernel"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"hqq": ["hqq"],
......
......@@ -21,6 +21,7 @@ from typing import Optional
from typing_extensions import Annotated
from ..chat import ChatModel
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
......@@ -60,7 +61,7 @@ async def sweeper() -> None:
@asynccontextmanager
async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
if chat_model.engine_type == "huggingface":
if chat_model.engine.name == EngineName.HF:
asyncio.create_task(sweeper())
yield
......@@ -106,7 +107,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if request.stream:
generate = create_stream_chat_completion_response(request, chat_model)
return EventSourceResponse(generate, media_type="text/event-stream")
return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
else:
return await create_chat_completion_response(request, chat_model)
......
......@@ -23,6 +23,7 @@ if TYPE_CHECKING:
from ..data import Template
from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
from ..extras.constants import EngineName
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
......@@ -41,6 +42,7 @@ class BaseEngine(ABC):
Must implements async methods: chat(), stream_chat() and get_scores().
"""
name: "EngineName"
model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer"
can_generate: bool
......
......@@ -20,6 +20,7 @@ import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
......@@ -47,10 +48,9 @@ class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
self.engine_type = model_args.infer_backend
if model_args.infer_backend == "huggingface":
if model_args.infer_backend == EngineName.HF:
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == "vllm":
elif model_args.infer_backend == EngineName.VLLM:
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
......
......@@ -24,7 +24,7 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
......@@ -50,6 +50,7 @@ class HuggingfaceEngine(BaseEngine):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.HF
self.can_generate = finetuning_args.stage == "sft"
tokenizer_module = load_tokenizer(model_args)
self.tokenizer = tokenizer_module["tokenizer"]
......
......@@ -19,7 +19,7 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
......@@ -49,6 +49,7 @@ class VllmEngine(BaseEngine):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.name = EngineName.VLLM
self.model_args = model_args
config = load_config(model_args) # may download model from ms hub
if getattr(config, "quantization_config", None): # gptq models should use float16
......@@ -169,7 +170,7 @@ class VllmEngine(BaseEngine):
or 1.0, # repetition_penalty must > 0
temperature=temperature if temperature is not None else self.generating_args["temperature"],
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
top_k=(top_k if top_k is not None else self.generating_args["top_k"]) or -1, # top_k must > 0
stop=stop,
stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
max_tokens=max_tokens,
......
......@@ -88,18 +88,24 @@ def main():
elif command == Command.TRAIN:
force_torchrun = is_env_enabled("FORCE_TORCHRUN")
if force_torchrun or (get_device_count() > 1 and not use_ray()):
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
if int(nnodes) > 1:
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
process = subprocess.run(
(
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
)
.format(
nnodes=os.getenv("NNODES", "1"),
node_rank=os.getenv("NODE_RANK", "0"),
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
nnodes=nnodes,
node_rank=node_rank,
nproc_per_node=nproc_per_node,
master_addr=master_addr,
master_port=master_port,
file_name=launcher.__file__,
......@@ -119,7 +125,7 @@ def main():
elif command == Command.HELP:
print(USAGE)
else:
raise NotImplementedError(f"Unknown command: {command}.")
print(f"Unknown command: {command}.\n{USAGE}")
if __name__ == "__main__":
......
......@@ -43,7 +43,7 @@ class Role(str, Enum):
class DatasetModule(TypedDict):
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
def merge_dataset(
......@@ -54,11 +54,13 @@ def merge_dataset(
"""
if len(all_datasets) == 1:
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
......@@ -69,24 +71,75 @@ def merge_dataset(
seed=seed,
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
)
else:
raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")
def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
dataset: Optional[Union["Dataset", "IterableDataset"]],
eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
data_args: "DataArguments",
seed: int,
) -> "DatasetDict":
r"""
Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
"""
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
return DatasetDict({"train": train_set, "validation": val_set})
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
if eval_dataset is not None and data_args.val_size > 1e-6:
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
dataset_dict = {}
if dataset is not None:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
if data_args.val_size > 1e-6:
if data_args.streaming:
dataset_dict["validation"] = dataset.take(int(data_args.val_size))
dataset_dict["train"] = dataset.skip(int(data_args.val_size))
else:
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
dataset_dict = dataset.train_test_split(test_size=val_size, seed=seed)
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
dataset_dict = {"train": dataset["train"], "validation": dataset["test"]}
else:
dataset_dict["train"] = dataset
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
else:
if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
dataset_dict["validation"] = eval_dataset
return DatasetDict(dataset_dict)
def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
r"""
Converts dataset or dataset dict to dataset module.
"""
dataset_module: "DatasetModule" = {}
if isinstance(dataset, DatasetDict): # dataset dict
if "train" in dataset:
dataset_module["train_dataset"] = dataset["train"]
if "validation" in dataset:
dataset_module["eval_dataset"] = dataset["validation"]
else:
eval_dataset = {}
for key in dataset.keys():
if key.startswith("validation_"):
eval_dataset[key[len("validation_") :]] = dataset[key]
if len(eval_dataset):
dataset_module["eval_dataset"] = eval_dataset
else: # single dataset
dataset_module["train_dataset"] = dataset
return dataset_module
......@@ -121,7 +121,7 @@ class FunctionFormatter(StringFormatter):
function_str = self.tool_utils.function_formatter(functions)
if thought:
function_str = thought.group(1) + function_str
function_str = thought.group(0) + function_str
return super().apply(content=function_str)
......
......@@ -13,17 +13,16 @@
# limitations under the License.
import os
import sys
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
from datasets import load_dataset, load_from_disk
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import check_version, has_tokenized_data
from .converter import align_dataset
from .data_utils import merge_dataset, split_dataset
from .data_utils import get_dataset_module, merge_dataset, split_dataset
from .parser import get_dataset_list
from .processor import (
FeedbackDatasetProcessor,
......@@ -292,23 +291,12 @@ def get_dataset(
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {}
if isinstance(tokenized_data, DatasetDict):
if "train" in tokenized_data:
dataset_module["train_dataset"] = tokenized_data["train"]
if "validation" in tokenized_data:
dataset_module["eval_dataset"] = tokenized_data["validation"]
else: # single dataset
dataset_module["train_dataset"] = tokenized_data
tokenized_data = load_from_disk(data_args.tokenized_path)
dataset_module = get_dataset_module(tokenized_data)
if data_args.streaming:
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
dataset_module["train_dataset"] = dataset_module["train_dataset"].to_iterable_dataset()
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
return dataset_module
if data_args.streaming:
......@@ -335,48 +323,11 @@ def get_dataset(
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
)
if data_args.val_size > 1e-6:
dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
else:
dataset_dict = {}
if dataset is not None:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
dataset_dict["train"] = dataset
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
dataset_dict.update({f"validation_{name}": data for name, data in eval_dataset.items()})
else:
if data_args.streaming:
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
dataset_dict["validation"] = eval_dataset
dataset_dict = DatasetDict(dataset_dict)
if data_args.tokenized_path is not None: # save tokenized dataset to disk and exit
dataset_dict = split_dataset(dataset, eval_dataset, data_args, seed=training_args.seed)
if data_args.tokenized_path is not None: # save tokenized dataset to disk
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0)
dataset_module = {}
if "train" in dataset_dict:
dataset_module["train_dataset"] = dataset_dict["train"]
if "validation" in dataset_dict:
dataset_module["eval_dataset"] = dataset_dict["validation"]
else:
eval_dataset = {}
for key in dataset_dict.keys():
if key.startswith("validation_"):
eval_dataset[key[len("validation_") :]] = dataset_dict[key]
if len(eval_dataset):
dataset_module["eval_dataset"] = eval_dataset
logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.")
return dataset_module
return get_dataset_module(dataset_dict)
......@@ -96,12 +96,31 @@ V_HEAD_WEIGHTS_NAME = "value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
class AttentionFunction(str, Enum):
AUTO = "auto"
DISABLED = "disabled"
SDPA = "sdpa"
FA2 = "fa2"
class EngineName(str, Enum):
HF = "huggingface"
VLLM = "vllm"
class DownloadSource(str, Enum):
DEFAULT = "hf"
MODELSCOPE = "ms"
OPENMIND = "om"
class RopeScaling(str, Enum):
LINEAR = "linear"
DYNAMIC = "dynamic"
YARN = "yarn"
LLAMA3 = "llama3"
def register_model_group(
models: Dict[str, Dict[DownloadSource, str]],
template: Optional[str] = None,
......
......@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
VERSION = "0.9.2.dev0"
VERSION = "0.9.2"
def print_env() -> None:
......@@ -74,4 +74,13 @@ def print_env() -> None:
except Exception:
pass
try:
import subprocess
commit_info = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True)
commit_hash = commit_info.stdout.strip()
info["Git commit"] = commit_hash
except Exception:
pass
print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
......@@ -363,15 +363,15 @@ class SwanLabArguments:
default=False,
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
)
swanlab_project: str = field(
swanlab_project: Optional[str] = field(
default="llamafactory",
metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: str = field(
swanlab_workspace: Optional[str] = field(
default=None,
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_run_name: str = field(
swanlab_run_name: Optional[str] = field(
default=None,
metadata={"help": "The experiment name in SwanLab."},
)
......@@ -379,15 +379,19 @@ class SwanLabArguments:
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: str = field(
swanlab_api_key: Optional[str] = field(
default=None,
metadata={"help": "The API key for SwanLab."},
)
swanlab_logdir: Optional[str] = field(
default=None,
metadata={"help": "The log directory for SwanLab."},
)
@dataclass
class FinetuningArguments(
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
......@@ -415,15 +419,15 @@ class FinetuningArguments(
)
freeze_vision_tower: bool = field(
default=True,
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
)
freeze_multi_modal_projector: bool = field(
default=True,
metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
)
train_mm_proj_only: bool = field(
freeze_language_model: bool = field(
default=False,
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
metadata={"help": "Whether or not to freeze the language model in MLLM training."},
)
compute_accuracy: bool = field(
default=False,
......@@ -455,8 +459,6 @@ class FinetuningArguments(
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
......@@ -484,9 +486,6 @@ class FinetuningArguments(
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.")
if self.train_mm_proj_only and self.finetuning_type != "full":
raise ValueError("`train_mm_proj_only` is only valid for full training.")
if self.finetuning_type != "lora":
if self.loraplus_lr_ratio is not None:
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
......
......@@ -23,6 +23,164 @@ import torch
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
from ..extras.constants import AttentionFunction, EngineName, RopeScaling
@dataclass
class BaseModelArguments:
r"""
Arguments pertaining to the model.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
)
adapter_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to the adapter weight or identifier from huggingface.co/models. "
"Use commas to separate multiple adapters."
)
},
)
adapter_folder: Optional[str] = field(
default=None,
metadata={"help": "The folder containing the adapter weights to load."},
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: bool = field(
default=False,
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
)
split_special_tokens: bool = field(
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
new_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
low_cpu_mem_usage: bool = field(
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
rope_scaling: Optional[RopeScaling] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: AttentionFunction = field(
default=AttentionFunction.AUTO,
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
use_unsloth_gc: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's gradient checkpointing (no need to install unsloth)."},
)
enable_liger_kernel: bool = field(
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
)
disable_gradient_checkpointing: bool = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
)
use_reentrant_gc: bool = field(
default=True,
metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
)
upcast_layernorm: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
)
upcast_lmhead_output: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
)
train_from_scratch: bool = field(
default=False,
metadata={"help": "Whether or not to randomly initialize the model weights."},
)
infer_backend: EngineName = field(
default=EngineName.HF,
metadata={"help": "Backend engine used at inference."},
)
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
)
use_cache: bool = field(
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
default="auto",
metadata={"help": "Data type for model weights and activations at inference."},
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
om_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
print_param_status: bool = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
)
def __post_init__(self):
if self.model_name_or_path is None:
raise ValueError("Please provide `model_name_or_path`.")
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
@dataclass
class QuantizationArguments:
......@@ -127,6 +285,10 @@ class ExportArguments:
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
def __post_init__(self):
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
@dataclass
class VllmArguments:
......@@ -155,148 +317,19 @@ class VllmArguments:
metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
)
def __post_init__(self):
if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
@dataclass
class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, VllmArguments):
class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
The class on the most right will be displayed first.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
)
adapter_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to the adapter weight or identifier from huggingface.co/models. "
"Use commas to separate multiple adapters."
)
},
)
adapter_folder: Optional[str] = field(
default=None,
metadata={"help": "The folder containing the adapter weights to load."},
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: bool = field(
default=False,
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
)
split_special_tokens: bool = field(
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
new_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
low_cpu_mem_usage: bool = field(
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
use_unsloth_gc: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
)
enable_liger_kernel: bool = field(
default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
)
disable_gradient_checkpointing: bool = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
)
use_reentrant_gc: bool = field(
default=True,
metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
)
upcast_layernorm: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
)
upcast_lmhead_output: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
)
train_from_scratch: bool = field(
default=False,
metadata={"help": "Whether or not to randomly initialize the model weights."},
)
infer_backend: Literal["huggingface", "vllm"] = field(
default="huggingface",
metadata={"help": "Backend engine used at inference."},
)
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
)
use_cache: bool = field(
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
default="auto",
metadata={"help": "Data type for model weights and activations at inference."},
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
om_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Modelers Hub."},
)
print_param_status: bool = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
init=False,
......@@ -319,23 +352,9 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
)
def __post_init__(self):
if self.model_name_or_path is None:
raise ValueError("Please provide `model_name_or_path`.")
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
BaseModelArguments.__post_init__(self)
ExportArguments.__post_init__(self)
VllmArguments.__post_init__(self)
@classmethod
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
......
......@@ -382,10 +382,10 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
# Log on each process the small summary
logger.info(
"Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
training_args.local_rank,
"Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
training_args.process_index,
training_args.world_size,
training_args.device,
training_args.n_gpu,
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
......@@ -418,7 +418,8 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
if model_args.export_dir is not None and model_args.export_device == "cpu":
model_args.device_map = {"": torch.device("cpu")}
model_args.model_max_length = data_args.cutoff_len
if data_args.cutoff_len != DataArguments().cutoff_len: # override cutoff_len if it is not default
model_args.model_max_length = data_args.cutoff_len
else:
model_args.device_map = "auto"
......
......@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from ...extras import logging
from ...extras.constants import AttentionFunction
from ...extras.misc import check_version
......@@ -33,34 +34,34 @@ def configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
if is_flash_attn_2_available():
check_version("transformers>=4.42.4")
check_version("flash_attn>=2.6.3")
if model_args.flash_attn != "fa2":
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
if model_args.flash_attn != AttentionFunction.FA2:
logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = AttentionFunction.FA2
else:
logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa":
model_args.flash_attn = AttentionFunction.DISABLED
elif model_args.flash_attn == AttentionFunction.SDPA:
logger.warning_rank0(
"Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
)
if model_args.flash_attn == "auto":
if model_args.flash_attn == AttentionFunction.AUTO:
return
elif model_args.flash_attn == "disabled":
elif model_args.flash_attn == AttentionFunction.DISABLED:
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":
elif model_args.flash_attn == AttentionFunction.SDPA:
if not is_torch_sdpa_available():
logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
elif model_args.flash_attn == AttentionFunction.FA2:
if not is_flash_attn_2_available():
logger.warning_rank0("FlashAttention-2 is not installed.")
return
......
......@@ -20,6 +20,7 @@ import math
from typing import TYPE_CHECKING
from ...extras import logging
from ...extras.constants import RopeScaling
if TYPE_CHECKING:
......@@ -39,33 +40,32 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
logger.warning_rank0("Current model does not support RoPE scaling.")
return
rope_kwargs = {}
rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)} # handle enum
if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == "dynamic":
if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
logger.warning_rank0(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
setattr(config, "max_position_embeddings", model_args.model_max_length)
rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
rope_kwargs["factor"] = 1.0
if (not current_max_length) or model_args.model_max_length <= current_max_length:
logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
return
if model_args.rope_scaling == "dynamic":
logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
setattr(config, "max_position_embeddings", model_args.model_max_length)
rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
if model_args.rope_scaling == RopeScaling.DYNAMIC:
rope_kwargs["original_max_position_embeddings"] = current_max_length
elif model_args.rope_scaling == "llama3":
elif model_args.rope_scaling == RopeScaling.LLAMA3:
rope_kwargs["original_max_position_embeddings"] = current_max_length
rope_kwargs["low_freq_factor"] = 1.0
rope_kwargs["high_freq_factor"] = 4.0
else:
rope_kwargs["factor"] = 2.0
setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs})
setattr(config, "rope_scaling", rope_kwargs)
logger.info_rank0(
f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
f"Using {rope_kwargs['rope_type']} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
)
......@@ -166,7 +166,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
forbidden_modules.add(projector_key)
if finetuning_args.train_mm_proj_only:
if finetuning_args.freeze_language_model:
language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
logger.info_rank0(f"Set language model not trainable: {language_model_keys}.")
forbidden_modules.update(language_model_keys)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment