Commit 0938ae70 authored by zhaoying1

fix save method of adapter_model.bin

parent 1b73554f
@@ -11,7 +11,7 @@ from peft import (
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import load_trainable_params
from llmtuner.tuner.core.utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
@@ -52,9 +52,6 @@ def init_adapter(
else:
param.data = param.data.to(torch.float32)
if model_args.checkpoint_dir is not None:
assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
latest_checkpoint = None
@@ -81,13 +78,18 @@ def init_adapter(
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
if is_trainable and latest_checkpoint is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
else:
target_modules = finetuning_args.lora_target
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetuning_args.lora_rank,
lora_alpha=finetuning_args.lora_alpha,
lora_dropout=finetuning_args.lora_dropout,
target_modules=finetuning_args.lora_target
target_modules=target_modules
)
model = get_peft_model(model, lora_config)
......
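For orientation, a minimal self-contained sketch of the LoRA branch shown in the adapter hunk above: resolve the target modules (an explicit list, or every linear layer except the output head when `lora_target` is "all") and wrap the base model with a PEFT adapter. The base model name and hyperparameter values below are placeholders, not taken from this commit.

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder base model

# Either list the projection layers explicitly or discover every torch.nn.Linear,
# which is what the new `lora_target: all` path delegates to find_all_linear_modules.
target_modules = sorted({
    name.split(".")[-1]
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear) and "lm_head" not in name
})

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=target_modules,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()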
@@ -4,6 +4,7 @@ import torch
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
import transformers
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -15,14 +16,19 @@ from transformers import (
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.deepspeed import is_deepspeed_zero3_enabled
from trl import AutoModelForCausalLMWithValueHead
try:
from transformers.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
from transformers.integrations import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
from llmtuner.extras.misc import count_parameters
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
from llmtuner.tuner.core.utils import prepare_model_for_training
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -32,11 +38,11 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
# check_min_version("4.29.1")
# check_min_version("4.30.0")
# require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
# require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
# require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
# require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
# require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
# require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
def load_model_and_tokenizer(
@@ -64,18 +70,23 @@ def load_model_and_tokenizer(
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
padding_side=model_args.padding_side,
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
**config_kwargs
)
if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
# Fix tokenizer (for ChatGLM2)
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
# Fix config (for Qwen)
if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
if model_args.compute_dtype == torch.bfloat16:
setattr(config, "bf16", True)
else:
@@ -91,11 +102,12 @@ def load_model_and_tokenizer(
setattr(config, "use_logn_attn", True)
logger.info("Using dynamic NTK scaling.")
elif hasattr(config, "rope_scaling"): # for LLaMA models
elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
if is_trainable:
if model_args.rope_scaling == "dynamic":
assert not model_args.flash_attn, "Flash attention does not support dynamic rope scaling."
logger.warning(
"Dynamic NTK may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
@@ -118,6 +130,15 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support RoPE scaling.")
# Set flash attention
if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
from llmtuner.extras.models.flash_llama import LlamaForCausalLM
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
if not hasattr(config, "num_key_value_heads"):
setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
if getattr(config, "pretraining_tp", 1) != 1:
setattr(config, "pretraining_tp", 1)
# Quantization configurations (using bitsandbytes library).
is_mergeable = True
if model_args.quantization_bit is not None:
@@ -171,10 +192,12 @@ def load_model_and_tokenizer(
# Initialize adapters
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model._keys_to_ignore_on_save = None
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
......
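The RoPE-scaling branch in the loader hunk above builds on the `rope_scaling` config field introduced in transformers 4.31. As a rough illustration (the model name and scaling factor are placeholders, not part of this change), the field is simply a dict on the model config:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model
# "linear" interpolation is the safer choice for fine-tuning; "dynamic" enables NTK-aware scaling,
# which the loader above warns may not work well during training.
config.rope_scaling = {"type": "linear", "factor": 2.0}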
@@ -5,6 +5,7 @@ import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.utils.versions import require_version
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
@@ -95,6 +96,9 @@ def get_train_args(
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
data_args.init_for_training()
if general_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if general_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
@@ -110,9 +114,17 @@ def get_train_args(
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stages can only be performed at training.")
if general_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if general_args.stage == "ppo" and training_args.deepspeed is not None:
raise ValueError("PPO training is incompatible with DeepSpeed, use Accelerate instead.")
if general_args.stage == "ppo" and data_args.streaming:
raise ValueError("Streaming mode does not suppport PPO training currently.")
@@ -166,6 +178,7 @@ def get_train_args(
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
# require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
@@ -186,18 +199,6 @@ def get_train_args(
else:
model_args.compute_dtype = torch.float16
# transfer training stage to dataset stage
dataset_stage = general_args.stage
if general_args.stage == "ppo":
dataset_stage = "sft"
elif general_args.stage == "dpo":
dataset_stage = "rm"
for dataset_attr in data_args.dataset_list:
if dataset_attr.stage and dataset_attr.stage != dataset_stage:
raise ValueError("Dataset {} is not supported for the stage {}"
.format(dataset_attr.dataset_name, general_args.stage))
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
# Log on each process the small summary:
@@ -223,6 +224,9 @@ def get_infer_args(
]:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
......
import os
import torch
from typing import TYPE_CHECKING, Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel
from trl import PreTrainedModelWrapper
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__)
class PeftModelMixin:
r"""
Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
"""
def __init__(self) -> None: # for type checking
self.model: PreTrainedModel = None
self.tokenizer: "PreTrainedTokenizer" = None
self.args: "Seq2SeqTrainingArguments" = None
self.finetuning_args: "FinetuningArguments" = None
self.state: "TrainerState" = None
raise AssertionError("Mixin should not be initialized.")
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
Saves trainable parameters as model checkpoint.
This function will only be executed at the process zero.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
model = unwrap_model(self.model)
if isinstance(model, PreTrainedModelWrapper):
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
model_state_dict = state_dict or model.state_dict()
v_head_state_dict = {
name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
for name in model_state_dict.keys() if name.startswith("v_head.")
}
torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
model = model.pretrained_model
state_dict = state_dict or get_state_dict(model)
if isinstance(model, (PeftModel, PreTrainedModel)):
model.config.use_cache = True
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
model.config.use_cache = False
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
try:
self.tokenizer.save_pretrained(output_dir)
except:
logger.warning("Cannot save tokenizer, copy the files manually.")
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):
r"""
Loads trainable parameters from model checkpoint.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model)
if isinstance(model, PreTrainedModelWrapper):
model.v_head.load_state_dict(torch.load(
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
))
model = model.pretrained_model
if isinstance(model, PeftModel):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else: # freeze/full-tuning
load_trainable_params(model, self.state.best_model_checkpoint)
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
"""
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
Seq2SeqTrainer.__init__(self, **kwargs)
self.finetuning_args = finetuning_args
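The `_save` override above is what produces the adapter_model.bin mentioned in the commit message: for a PeftModel, `save_pretrained` stores the adapter weights under that name (peft 0.4.x), alongside the optional value head and the training arguments. A minimal sketch of reloading such a LoRA checkpoint afterwards, with the paths and base model as placeholders:

from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/base_model")   # placeholder
model = PeftModel.from_pretrained(base, "path/to/checkpoint_dir")   # reads adapter_config.json + adapter_model.bin
merged = model.merge_and_unload()                                   # optional: fold the LoRA weights into the base model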
import torch
from typing import TYPE_CHECKING, List, Optional
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None,
output_layer_name: Optional[str] = "lm_head"
) -> List[str]:
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
else:
linear_cls = torch.nn.Linear
module_names = set()
for name, module in model.named_modules():
if output_layer_name not in name and isinstance(module, linear_cls):
module_names.add(name.split(".")[-1])
if output_layer_name in module_names:
module_names.remove(output_layer_name)
return list(module_names)
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_type: str,
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) upcast the lm_head to fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
"""
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(torch.float32)
if use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
if finetuning_type != "full" and hasattr(model, output_layer_name):
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
input_dtype = output_layer.weight.dtype
class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.to(input_dtype)).to(torch.float32)
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
return model
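A short usage sketch of the two helpers above, mirroring how the loader combines them for a trainable LoRA run; the base model is a placeholder and is assumed to be unquantized:

from transformers import AutoModelForCausalLM
from llmtuner.tuner.core.utils import find_all_linear_modules, prepare_model_for_training

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder base model
model = prepare_model_for_training(model, finetuning_type="lora")

# With no quantization bit, every torch.nn.Linear except the lm_head is reported,
# e.g. q_proj / k_proj / v_proj / out_proj / fc1 / fc2 for an OPT-style model.
print(find_all_linear_modules(model))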