# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
import re
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D
from ..utils import PeftConfig, PeftType, transpose
def is_bnb_available():
return importlib.util.find_spec("bitsandbytes") is not None
if is_bnb_available():
import bitsandbytes as bnb
@dataclass
class LoraConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a [`~peft.Lora`].
Args:
r (`int`): Lora attention dimension.
target_modules (`Union[List[str], str]`): The names of the modules to apply Lora to.
lora_alpha (`float`): The alpha parameter for Lora scaling.
lora_dropout (`float`): The dropout probability for Lora layers.
merge_weights (`bool`):
Whether to merge the weights of the Lora layers with the base transformer model in `eval` mode.
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out).
enable_lora (`List[bool]`): Used with `lora.MergedLinear`.
bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'.
modules_to_save (`List[str]`): List of modules apart from LoRA layers to be set as trainable
and saved in the final checkpoint.
"""
r: int = field(default=8, metadata={"help": "Lora attention dimension"})
target_modules: Optional[Union[List[str], str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to replace with Lora."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)
lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"})
lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"})
merge_weights: bool = field(
default=False, metadata={"help": "Merge weights of the original model and the Lora model"}
)
fan_in_fan_out: bool = field(
default=False,
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
)
enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."})
bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
modules_to_save: Optional[List[str]] = field(
default=None,
metadata={
"help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. "
"For example, in Sequence Classification or Token Classification tasks, "
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
},
)
def __post_init__(self):
self.peft_type = PeftType.LORA
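# Illustrative sketch (not in the original source): `target_modules` accepts either a list of
# module-name suffixes or a single regex, e.g.
#   LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, target_modules=["q", "v"])
#   LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
#              target_modules=".*decoder.*(SelfAttention|EncDecAttention).*(q|v)$")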
class LoraModel(torch.nn.Module):
"""
Creates Low Rank Adapter (Lora) model from a pretrained transformers model.
Args:
model ([`transformers.PreTrainedModel`]): The model to be adapted.
config ([`LoraConfig`]): The configuration of the Lora model.
Returns:
`torch.nn.Module`: The Lora model.
Example::

    >>> from transformers import AutoModelForSeq2SeqLM
    >>> from peft import LoraModel, LoraConfig
    >>> config = LoraConfig(
    ...     peft_type="LORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32,
    ...     target_modules=["q", "v"], lora_dropout=0.01,
    ... )
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
    >>> lora_model = LoraModel(config, model)
**Attributes**:
- **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`LoraConfig`]) -- The configuration of the Lora model.
"""
def __init__(self, config, model):
super().__init__()
self.peft_config = config
self.model = model
self._find_and_replace()
mark_only_lora_as_trainable(self.model, self.peft_config.bias)
self.forward = self.model.forward
def _find_and_replace(self):
loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
if (loaded_in_4bit or loaded_in_8bit) and not is_bnb_available():
raise ImportError(
"To use Lora with 8-bit or 4-bit quantization, please install the `bitsandbytes` package. "
"You can install it with `pip install bitsandbytes`."
)
is_target_modules_in_base_model = False
is_hf_device_map_available = hasattr(self.model, "hf_device_map")
kwargs = {
"r": self.peft_config.r,
"lora_alpha": self.peft_config.lora_alpha,
"lora_dropout": self.peft_config.lora_dropout,
"fan_in_fan_out": self.peft_config.fan_in_fan_out,
"merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
and not is_hf_device_map_available,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
if isinstance(self.peft_config.target_modules, str):
target_module_found = re.fullmatch(self.peft_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = self._get_submodules(key)
bias = target.bias is not None
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
"memory_efficient_backward": target.state.memory_efficient_backward,
"threshold": target.state.threshold,
"index": target.index,
}
)
if self.peft_config.enable_lora is None:
new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
else:
kwargs.update({"enable_lora": self.peft_config.enable_lora})
new_module = MergedLinear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
elif loaded_in_4bit and isinstance(target, bnb.nn.Linear4bit):
kwargs.update(
{
"compute_dtype": target.compute_dtype,
"compress_statistics": target.weight.compress_statistics,
"quant_type": target.weight.quant_type,
}
)
if self.peft_config.enable_lora is None:
new_module = Linear4bit(target.in_features, target.out_features, bias=bias, **kwargs)
else:
kwargs.update({"enable_lora": self.peft_config.enable_lora})
new_module = MergedLinear4bit(target.in_features, target.out_features, bias=bias, **kwargs)
elif isinstance(target, torch.nn.Linear) and self.peft_config.enable_lora is None:
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
elif self.peft_config.enable_lora is not None:
kwargs.update({"enable_lora": self.peft_config.enable_lora})
if isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
else:
in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is not a Conv1D. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = self.peft_config.fan_in_fan_out = False
new_module = MergedLinear(in_features, out_features, bias=bias, **kwargs)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {self.peft_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)
def _get_submodules(self, key):
parent = self.model.get_submodule(".".join(key.split(".")[:-1]))
target_name = key.split(".")[-1]
target = self.model.get_submodule(key)
return parent, target, target_name
def _replace_module(self, parent_module, child_name, new_module, old_module):
setattr(parent_module, child_name, new_module)
new_module.weight = old_module.weight
if old_module.bias is not None:
new_module.bias = old_module.bias
if getattr(old_module, "state", None) is not None:
new_module.state = old_module.state
new_module.to(old_module.weight.device)
# dispatch to correct device
for name, module in new_module.named_modules():
if "lora_" in name:
module.to(old_module.weight.device)
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.model, name)
@property
def modules_to_save(self):
return None
def get_peft_config_as_dict(self, inference: bool = False):
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
if inference:
config["inference_mode"] = True
return config
def _set_adapter_layers(self, enabled=True):
for module in self.model.modules():
if isinstance(module, LoraLayer):
module.disable_adapters = not enabled
def enable_adapter_layers(self):
self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)
# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
# had to adapt it for `lora_only` to work
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
for n, p in model.named_parameters():
if "lora_" not in n:
p.requires_grad = False
if bias == "none":
return
elif bias == "all":
for n, p in model.named_parameters():
if "bias" in n:
p.requires_grad = True
elif bias == "lora_only":
for m in model.modules():
if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError
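# Usage sketch (hypothetical model; illustrative only): freeze everything except the LoRA
# matrices and, with bias="lora_only", the biases of the LoraLayer modules themselves.
#   mark_only_lora_as_trainable(lora_model.model, bias="lora_only")
#   trainable = [n for n, p in lora_model.model.named_parameters() if p.requires_grad]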
class LoraLayer:
def __init__(
self,
r: int,
lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights
self.disable_adapters = False
class Linear(nn.Linear, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
**kwargs,
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
self.lora_B = nn.Linear(r, out_features, bias=False)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def train(self, mode: bool = True):
nn.Linear.train(self, mode)
self.lora_A.train(mode)
self.lora_B.train(mode)
if not mode and self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += (
transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = True
elif self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= (
transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = False
def eval(self):
nn.Linear.eval(self)
self.lora_A.eval()
self.lora_B.eval()
def forward(self, x: torch.Tensor):
# previous_dtype = x.dtype
if self.disable_adapters:
if self.r > 0 and self.merged:
self.weight.data -= (
transpose(self.lora_B.weight @ self.lora_A.weight, self.fan_in_fan_out) * self.scaling
)
self.merged = False
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r > 0 and not self.merged:
# x.to(torch.float32)
# x = x.float()
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.r > 0:
result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
else:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
return result
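# Sanity sketch (hypothetical shapes; kept as a comment so nothing runs at import): merging
# folds the low-rank update into the dense weight, so merged and unmerged paths agree.
#   layer = Linear(16, 32, r=4, lora_alpha=8, merge_weights=True)
#   x = torch.randn(2, 16)
#   y_train = layer(x)   # W x + (lora_alpha / r) * B(A(x))
#   layer.eval()         # train(False) merges: W += (lora_alpha / r) * B A
#   assert torch.allclose(y_train, layer(x), atol=1e-5)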
class MergedLinear(nn.Linear, LoraLayer):
# Lora implemented in a dense layer
def __init__( # pylint: disable=W0102
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
enable_lora: List[bool] = [False],
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs,
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
if out_features % len(enable_lora) != 0:
raise ValueError("The length of enable_lora must divide out_features")
self.enable_lora = enable_lora
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0 and any(enable_lora):
self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
self.lora_B = nn.Conv1d(
r * sum(enable_lora),
out_features // len(enable_lora) * sum(enable_lora),
kernel_size=1,
groups=2,
bias=False,
)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def zero_pad(self, x):
result = x.new_zeros((*x.shape[:-1], self.out_features))
result = result.view(-1, self.out_features)
result[:, self.lora_ind] = x.reshape(-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora))
return result.view((*x.shape[:-1], self.out_features))
def train(self, mode: bool = True):
nn.Linear.train(self, mode)
self.lora_A.train(mode)
self.lora_B.train(mode)
if not mode and self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0 and any(self.enable_lora):
delta_w = (
F.conv1d(
self.lora_A.weight.data.unsqueeze(0),
self.lora_B.weight.data,
groups=sum(self.enable_lora),
)
.squeeze(0)
.transpose(-2, -1)
)
self.weight.data += transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
self.merged = True
elif self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0 and any(self.enable_lora):
delta_w = (
F.conv1d(
self.lora_A.weight.data.unsqueeze(0),
self.lora_B.weight.data,
groups=sum(self.enable_lora),
)
.squeeze(0)
.transpose(-2, -1)
)
self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
self.merged = False
def eval(self):
nn.Linear.eval(self)
self.lora_A.eval()
self.lora_B.eval()
def forward(self, x: torch.Tensor):
if self.disable_adapters:
if self.r > 0 and self.merged and any(self.enable_lora):
delta_w = (
F.conv1d(
self.lora_A.weight.data.unsqueeze(0),
self.lora_B.weight.data,
groups=sum(self.enable_lora),
)
.squeeze(0)
.transpose(-2, -1)
)
self.weight.data -= transpose(self.zero_pad(delta_w * self.scaling), not self.fan_in_fan_out)
self.merged = False
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.merged:
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
else:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.r > 0:
after_A = self.lora_A(self.lora_dropout(x))
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
result += self.zero_pad(after_B) * self.scaling
return result
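# Sketch of `enable_lora` (hypothetical fused QKV projection; illustrative only): a dense
# layer holding Q, K and V stacked along out_features can adapt only Q and V; the Conv1d
# update is computed for the enabled groups and zero-padded back into their output slices.
#   qkv = MergedLinear(768, 3 * 768, r=8, lora_alpha=16, enable_lora=[True, False, True])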
if is_bnb_available():
class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
in_features,
out_features,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
bnb.nn.Linear8bitLt.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
self.lora_B = nn.Linear(r, out_features, bias=False)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters:
return result
elif self.r > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
x = x.to(self.lora_A.weight.dtype)
output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
result += output
else:
output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
result += output
return result
class MergedLinear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
# Lora implemented in a dense layer
def __init__( # pylint: disable=W0102
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
enable_lora: List[bool] = [False],
**kwargs,
):
bnb.nn.Linear8bitLt.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
if out_features % len(enable_lora) != 0:
raise ValueError("The length of enable_lora must divide out_features")
self.enable_lora = enable_lora
# Actual trainable parameters
if r > 0 and any(enable_lora):
self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
self.lora_B = nn.Conv1d(
r * sum(enable_lora),
out_features // len(enable_lora) * sum(enable_lora),
kernel_size=1,
groups=2,
bias=False,
)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def zero_pad(self, x):
result = x.new_zeros((*x.shape[:-1], self.out_features))
result = result.view(-1, self.out_features)
result[:, self.lora_ind] = x.reshape(
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
)
return result.view((*x.shape[:-1], self.out_features))
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters:
return result
elif self.r > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
x = x.to(self.lora_A.weight.dtype)
after_A = self.lora_A(self.lora_dropout(x))
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
output = self.zero_pad(after_B).to(expected_dtype) * self.scaling
result += output
else:
after_A = self.lora_A(self.lora_dropout(x))
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
output = self.zero_pad(after_B) * self.scaling
result += output
return result
class Linear4bit(bnb.nn.Linear4bit, LoraLayer):
# Lora implemented in a dense layer
def __init__(
self,
in_features,
out_features,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
bnb.nn.Linear4bit.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
compute_dtype=kwargs.get("compute_dtype", torch.float32),
compress_statistics=kwargs.get("compress_statistics", True),
quant_type=kwargs.get("quant_type", "nf4"),
)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Linear(in_features, r, bias=False)
self.lora_B = nn.Linear(r, out_features, bias=False)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters:
return result
elif self.r > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
x = x.to(self.lora_A.weight.dtype)
output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
result += output
else:
output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
result += output
return result
class MergedLinear4bit(bnb.nn.Linear4bit, LoraLayer):
# Lora implemented in a dense layer
def __init__( # pylint: disable=W0102
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
enable_lora: List[bool] = [False],
**kwargs,
):
bnb.nn.Linear4bit.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
compute_dtype=kwargs.get("compute_dtype", torch.float32),
compress_statistics=kwargs.get("compress_statistics", True),
quant_type=kwargs.get("quant_type", "nf4"),
)
LoraLayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
if out_features % len(enable_lora) != 0:
raise ValueError("The length of enable_lora must divide out_features")
self.enable_lora = enable_lora
# Actual trainable parameters
if r > 0 and any(enable_lora):
self.lora_A = nn.Linear(in_features, r * sum(enable_lora), bias=False)
self.lora_B = nn.Conv1d(
r * sum(enable_lora),
out_features // len(enable_lora) * sum(enable_lora),
kernel_size=1,
groups=2,
bias=False,
)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# Compute the indices
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(len(enable_lora), -1)
self.lora_ind[enable_lora, :] = True
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
def zero_pad(self, x):
result = x.new_zeros((*x.shape[:-1], self.out_features))
result = result.view(-1, self.out_features)
result[:, self.lora_ind] = x.reshape(
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
)
return result.view((*x.shape[:-1], self.out_features))
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters:
return result
elif self.r > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
x = x.to(self.lora_A.weight.dtype)
after_A = self.lora_A(self.lora_dropout(x))
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
output = self.zero_pad(after_B).to(expected_dtype) * self.scaling
result += output
else:
after_A = self.lora_A(self.lora_dropout(x))
after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)
output = self.zero_pad(after_B) * self.scaling
result += output
return result
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import warnings
from dataclasses import dataclass, field
from typing import Union
import torch
from ..utils import PeftType, PromptLearningConfig
class PromptEncoderReparameterizationType(str, enum.Enum):
MLP = "MLP"
LSTM = "LSTM"
@dataclass
class PromptEncoderConfig(PromptLearningConfig):
"""
This is the configuration class to store the configuration of a [`~peft.PromptEncoder`].
Args:
encoder_reparameterization_type (Union[[`PromptEncoderReparameterizationType`], `str`]):
The type of reparameterization to use.
encoder_hidden_size (`int`): The hidden size of the prompt encoder.
encoder_num_layers (`int`): The number of layers of the prompt encoder.
encoder_dropout (`float`): The dropout probability of the prompt encoder.
"""
encoder_reparameterization_type: Union[str, PromptEncoderReparameterizationType] = field(
default=PromptEncoderReparameterizationType.MLP,
metadata={"help": "How to reparameterize the prompt encoder"},
)
encoder_hidden_size: int = field(
default=None,
metadata={"help": "The hidden size of the prompt encoder"},
)
encoder_num_layers: int = field(
default=2,
metadata={"help": "The number of layers of the prompt encoder"},
)
encoder_dropout: float = field(
default=0.0,
metadata={"help": "The dropout of the prompt encoder"},
)
def __post_init__(self):
self.peft_type = PeftType.P_TUNING
# Based on https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/prompt_encoder.py
# with some refactor
class PromptEncoder(torch.nn.Module):
"""
The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.
Args:
config ([`PromptEncoderConfig`]): The configuration of the prompt encoder.
Example::

    >>> from peft import PromptEncoder, PromptEncoderConfig
    >>> config = PromptEncoderConfig(
    ...     peft_type="P_TUNING", task_type="SEQ_2_SEQ_LM", num_virtual_tokens=20, token_dim=768,
    ...     num_transformer_submodules=1, num_attention_heads=12, num_layers=12,
    ...     encoder_reparameterization_type="MLP", encoder_hidden_size=768,
    ... )
    >>> prompt_encoder = PromptEncoder(config)
**Attributes**:
- **embedding** ([`~torch.nn.Embedding`]) -- The embedding layer of the prompt encoder.
- **mlp_head** ([`~torch.nn.Sequential`]) -- The MLP head of the prompt encoder if `inference_mode=False`.
- **lstm_head** ([`~torch.nn.LSTM`]) -- The LSTM head of the prompt encoder if `inference_mode=False` and
`encoder_reparameterization_type="LSTM"`.
- **token_dim** (`int`) -- The hidden embedding dimension of the base transformer model.
- **input_size** (`int`) -- The input size of the prompt encoder.
- **output_size** (`int`) -- The output size of the prompt encoder.
- **hidden_size** (`int`) -- The hidden size of the prompt encoder.
- **total_virtual_tokens** (`int`) -- The total number of virtual tokens of the prompt encoder.
- **encoder_type** (Union[[`PromptEncoderReparameterizationType`], `str`]) -- The encoder type of the
prompt encoder.
Input shape: (batch_size, total_virtual_tokens)
Output shape: (batch_size, total_virtual_tokens, token_dim)
"""
def __init__(self, config):
super().__init__()
self.token_dim = config.token_dim
self.input_size = self.token_dim
self.output_size = self.token_dim
self.hidden_size = config.encoder_hidden_size
self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
self.encoder_type = config.encoder_reparameterization_type
# embedding
self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim)
if not config.inference_mode:
if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
lstm_dropout = config.encoder_dropout
num_layers = config.encoder_num_layers
# LSTM
self.lstm_head = torch.nn.LSTM(
input_size=self.input_size,
hidden_size=self.hidden_size,
num_layers=num_layers,
dropout=lstm_dropout,
bidirectional=True,
batch_first=True,
)
self.mlp_head = torch.nn.Sequential(
torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2),
torch.nn.ReLU(),
torch.nn.Linear(self.hidden_size * 2, self.output_size),
)
elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
warnings.warn(
f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
)
layers = [
torch.nn.Linear(self.input_size, self.hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(self.hidden_size, self.hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(self.hidden_size, self.output_size),
]
self.mlp_head = torch.nn.Sequential(*layers)
else:
raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
def forward(self, indices):
input_embeds = self.embedding(indices)
if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0])
elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
output_embeds = self.mlp_head(input_embeds)
else:
raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
return output_embeds
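# Usage sketch (values illustrative): map virtual-token indices to reparameterized embeddings.
#   config = PromptEncoderConfig(num_virtual_tokens=20, token_dim=768,
#                                num_transformer_submodules=1, encoder_hidden_size=768)
#   encoder = PromptEncoder(config)
#   embeds = encoder(torch.arange(20).unsqueeze(0))  # (1, 20, 768)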
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
import torch
from ..utils import PeftType, PromptLearningConfig
@dataclass
class PrefixTuningConfig(PromptLearningConfig):
"""
This is the configuration class to store the configuration of a [`~peft.PrefixEncoder`].
Args:
encoder_hidden_size (`int`): The hidden size of the prompt encoder.
prefix_projection (`bool`): Whether to project the prefix embeddings.
"""
encoder_hidden_size: int = field(
default=None,
metadata={"help": "The hidden size of the encoder"},
)
prefix_projection: bool = field(
default=False,
metadata={"help": "Whether to project the prefix tokens"},
)
def __post_init__(self):
self.peft_type = PeftType.PREFIX_TUNING
# Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
# with some refactor
class PrefixEncoder(torch.nn.Module):
r"""
The `torch.nn` model to encode the prefix.
Args:
config ([`PrefixTuningConfig`]): The configuration of the prefix encoder.
Example::

    >>> from peft import PrefixEncoder, PrefixTuningConfig
    >>> config = PrefixTuningConfig(
    ...     peft_type="PREFIX_TUNING", task_type="SEQ_2_SEQ_LM", num_virtual_tokens=20, token_dim=768,
    ...     num_transformer_submodules=1, num_attention_heads=12, num_layers=12, encoder_hidden_size=768,
    ... )
    >>> prefix_encoder = PrefixEncoder(config)
**Attributes**:
- **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder.
- **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if
`prefix_projection` is `True`.
- **prefix_projection** (`bool`) -- Whether to project the prefix embeddings.
Input shape: (batch_size, num_virtual_tokens)
Output shape: (batch_size, num_virtual_tokens, 2*layers*hidden)
"""
def __init__(self, config):
super().__init__()
self.prefix_projection = config.prefix_projection
token_dim = config.token_dim
num_layers = config.num_layers
encoder_hidden_size = config.encoder_hidden_size
num_virtual_tokens = config.num_virtual_tokens
if self.prefix_projection and not config.inference_mode:
# Use a two-layer MLP to encode the prefix
self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
self.transform = torch.nn.Sequential(
torch.nn.Linear(token_dim, encoder_hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
)
else:
self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)
def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.transform(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values
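# Shape sketch (illustrative): with prefix_projection=False the encoder is a plain lookup
# producing the per-layer key/value prefixes in one flat vector.
#   config = PrefixTuningConfig(num_virtual_tokens=20, token_dim=768, num_layers=12)
#   pkv = PrefixEncoder(config)(torch.arange(20).unsqueeze(0))  # (1, 20, 2 * 12 * 768)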
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import math
from dataclasses import dataclass, field
from typing import Optional, Union
import torch
from ..utils import PeftType, PromptLearningConfig
class PromptTuningInit(str, enum.Enum):
TEXT = "TEXT"
RANDOM = "RANDOM"
@dataclass
class PromptTuningConfig(PromptLearningConfig):
"""
This is the configuration class to store the configuration of a [`~peft.PromptEmbedding`].
Args:
prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding.
prompt_tuning_init_text (Optional[`str`]): The text to initialize the prompt embedding.
Only used if `prompt_tuning_init` is `TEXT`.
tokenizer_name_or_path (Optional[`str`]): The name or path of the tokenizer.
Only used if `prompt_tuning_init` is `TEXT`.
"""
prompt_tuning_init: Union[PromptTuningInit, str] = field(
default=PromptTuningInit.RANDOM,
metadata={"help": "How to initialize the prompt tuning parameters"},
)
prompt_tuning_init_text: Optional[str] = field(
default=None,
metadata={
"help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`"
},
)
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`"
},
)
def __post_init__(self):
self.peft_type = PeftType.PROMPT_TUNING
class PromptEmbedding(torch.nn.Module):
"""
The model to encode virtual tokens into prompt embeddings.
Args:
config ([`PromptTuningConfig`]): The configuration of the prompt embedding.
word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model.
**Attributes**:
**embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding.
Example::

    >>> from peft import PromptEmbedding, PromptTuningConfig
    >>> config = PromptTuningConfig(
    ...     peft_type="PROMPT_TUNING", task_type="SEQ_2_SEQ_LM", num_virtual_tokens=20, token_dim=768,
    ...     num_transformer_submodules=1, num_attention_heads=12, num_layers=12, prompt_tuning_init="TEXT",
    ...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
    ...     tokenizer_name_or_path="t5-base",
    ... )
    >>> # t5_model.shared is the word embeddings of the base model
    >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
Input Shape: (batch_size, total_virtual_tokens)
Output Shape: (batch_size, total_virtual_tokens, token_dim)
"""
def __init__(self, config, word_embeddings):
super().__init__()
total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
if config.prompt_tuning_init == PromptTuningInit.TEXT:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
init_text = config.prompt_tuning_init_text
init_token_ids = tokenizer(init_text)["input_ids"]
# Trim or repeat the init token ids until num_text_tokens matches total_virtual_tokens
num_text_tokens = len(init_token_ids)
if num_text_tokens > total_virtual_tokens:
init_token_ids = init_token_ids[:total_virtual_tokens]
elif num_text_tokens < total_virtual_tokens:
num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
init_token_ids = init_token_ids * num_reps
init_token_ids = init_token_ids[:total_virtual_tokens]
word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
word_embedding_weights = word_embedding_weights.to(torch.float32)
self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
def forward(self, indices):
# Just get embeddings
prompt_embeddings = self.embedding(indices)
return prompt_embeddings
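# Worked example of the TEXT-init trim/repeat logic above (numbers illustrative): asking for
# 20 virtual tokens with a 7-token init text repeats the ids ceil(20 / 7) = 3 times (21 ids)
# and truncates to the first 20 before copying them into the embedding weight.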
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .adapters_utils import CONFIG_NAME, WEIGHTS_NAME
from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
from .other import (
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
_set_trainable,
bloom_model_postprocess_past_key_value,
# prepare_model_for_int8_training,
shift_tokens_right,
transpose,
)
from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
WEIGHTS_NAME = "adapter_model.bin"
CONFIG_NAME = "adapter_config.json"
# TODO: add automapping and superclass here?
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import json
import os
from dataclasses import asdict, dataclass, field
from typing import Optional, Union
from huggingface_hub import hf_hub_download
from transformers.utils import PushToHubMixin
from .adapters_utils import CONFIG_NAME
class PeftType(str, enum.Enum):
PROMPT_TUNING = "PROMPT_TUNING"
P_TUNING = "P_TUNING"
PREFIX_TUNING = "PREFIX_TUNING"
LORA = "LORA"
class TaskType(str, enum.Enum):
SEQ_CLS = "SEQ_CLS"
SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
CAUSAL_LM = "CAUSAL_LM"
@dataclass
class PeftConfigMixin(PushToHubMixin):
r"""
This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
PEFT adapter models. This class inherits from `transformers.utils.PushToHubMixin` which contains the methods to
push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.
Args:
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
"""
peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
@property
def __dict__(self):
return asdict(self)
def to_dict(self):
return self.__dict__
def save_pretrained(self, save_directory, **kwargs):
r"""
This method saves the configuration of your adapter model in a directory.
Args:
save_directory (`str`):
The directory where the configuration will be saved.
**kwargs:
Additional keyword arguments passed along to the `transformers.utils.PushToHubMixin.push_to_hub`
method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
output_dict = self.__dict__
output_path = os.path.join(save_directory, CONFIG_NAME)
# save it
with open(output_path, "w") as writer:
writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
This method loads the configuration of your adapter model from a directory.
Args:
pretrained_model_name_or_path (`str`):
The directory or the hub-id where the configuration is saved.
**kwargs:
Additional keyword arguments passed along to the child class initialization.
"""
if os.path.isfile(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else:
try:
config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME)
except Exception:
raise ValueError(f"Can't find config.json at '{pretrained_model_name_or_path}'")
loaded_attributes = cls.from_json_file(config_file)
config = cls(**kwargs)
for key, value in loaded_attributes.items():
if hasattr(config, key):
setattr(config, key, value)
return config
@classmethod
def from_json_file(cls, path_json_file, **kwargs):
r"""
Loads a configuration file from a json file.
Args:
path_json_file (`str`):
The path to the json file.
"""
with open(path_json_file, "r") as file:
json_object = json.load(file)
return json_object
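# Round-trip sketch (hypothetical directory; illustrative only):
#   config = PeftConfig(peft_type="LORA", task_type="CAUSAL_LM")
#   config.save_pretrained("/tmp/adapter")              # writes adapter_config.json
#   reloaded = PeftConfig.from_pretrained("/tmp/adapter")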
@dataclass
class PeftConfig(PeftConfigMixin):
"""
This is the base configuration class to store the configuration of a :class:`~peft.PeftModel`.
Args:
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
"""
base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."})
peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"})
task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"})
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
@dataclass
class PromptLearningConfig(PeftConfig):
"""
This is the base configuration class to store the configuration of a Union[[`~peft.PrefixTuning`],
[`~peft.PromptEncoder`], [`~peft.PromptTuning`]].
Args:
num_virtual_tokens (`int`): The number of virtual tokens to use.
token_dim (`int`): The hidden embedding dimension of the base transformer model.
num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model.
num_attention_heads (`int`): The number of attention heads in the base transformer model.
num_layers (`int`): The number of layers in the base transformer model.
"""
num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"})
token_dim: int = field(
default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"}
)
num_transformer_submodules: Optional[int] = field(
default=None, metadata={"help": "Number of transformer submodules"}
)
num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
# needed for prefix-tuning of bloom model
def bloom_model_postprocess_past_key_value(past_key_values):
past_key_values = torch.cat(past_key_values)
total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
keys = past_key_values[: total_layers // 2]
keys = keys.transpose(2, 3).reshape(
total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
)
values = past_key_values[total_layers // 2 :]
values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)
return tuple(zip(keys, values))
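# Shape walk-through (illustrative): given 2*L stacked layers of shape
# (batch, heads, num_virtual_tokens, head_dim), the first L entries become keys of shape
# (L, batch * heads, head_dim, num_virtual_tokens) and the last L become values of shape
# (L, batch * heads, num_virtual_tokens, head_dim), zipped into per-layer (key, value) pairs.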
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
"bloom": bloom_model_postprocess_past_key_value,
}
# copied from transformers.models.bart.modeling_bart
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
pad_token_id (`int`): The id of the `padding` token.
decoder_start_token_id (`int`): The id of the `start` token.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
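# Worked example (illustrative): with decoder_start_token_id=0 and pad_token_id=1,
#   input_ids = [[5, 6, -100, -100]]  ->  shifted = [[0, 5, 6, 1]]
# (start token prepended; the copied -100 label is replaced by the pad id).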
def _set_trainable(model):
if model.modules_to_save is not None:
for name, param in model.named_parameters():
if any(module_name in name for module_name in model.modules_to_save):
param.requires_grad = True
def fsdp_auto_wrap_policy(model):
import functools
import os
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
def lambda_policy_fn(module):
if (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
):
return True
return False
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
FullyShardedDataParallelPlugin.get_module_class_from_name(
model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
),
),
)
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
return auto_wrap_policy
def transpose(weight, fan_in_fan_out):
return weight.T if fan_in_fan_out else weight
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import PeftType
def get_peft_model_state_dict(model, state_dict=None):
"""
Get the state dict of the Peft model.
Args:
model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
the model should be the underlying model/unwrapped model (i.e. model.module).
state_dict (`dict`, *optional*, defaults to `None`):
The state dict of the model. If not provided, the state dict of the model will be used.
"""
if state_dict is None:
state_dict = model.state_dict()
if model.peft_config.peft_type == PeftType.LORA:
# to_return = lora_state_dict(model, bias=model.peft_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
# to work directly with the state dict, which is necessary when using DeepSpeed or FSDP
bias = model.peft_config.bias
if bias == "none":
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
elif bias == "all":
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
for k in state_dict:
if "lora_" in k:
to_return[k] = state_dict[k]
bias_name = k.split("lora_")[0] + "bias"
if bias_name in state_dict:
to_return[bias_name] = state_dict[bias_name]
else:
raise NotImplementedError
else:
to_return = {}
if model.peft_config.inference_mode:
prompt_embeddings = model.prompt_encoder.embedding.weight
else:
prompt_embeddings = model.get_prompt_embedding_to_save()
to_return["prompt_embeddings"] = prompt_embeddings
if model.modules_to_save is not None:
for key, value in state_dict.items():
if any(module_name in key for module_name in model.modules_to_save):
to_return[key] = value
return to_return
def set_peft_model_state_dict(model, peft_model_state_dict):
"""
Set the state dict of the Peft model.
Args:
model ([`PeftModel`]): The Peft model.
peft_model_state_dict (`dict`): The state dict of the Peft model.
"""
model.load_state_dict(peft_model_state_dict, strict=False)
if model.peft_config.peft_type != PeftType.LORA:
model.prompt_encoder.embedding.load_state_dict(
{"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
)
return model
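# Round-trip sketch (illustrative only): persist just the adapter weights, then restore them.
#   adapter_sd = get_peft_model_state_dict(peft_model)
#   torch.save(adapter_sd, "adapter_model.bin")
#   set_peft_model_state_dict(peft_model, torch.load("adapter_model.bin"))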
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
import logging
import numpy as np
import math
import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional, List, Dict, Any, Mapping
from pathlib import Path
import datasets
import torch
from datasets import load_dataset, concatenate_datasets
import transformers
from transformers import (
CONFIG_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
AutoConfig,
AutoModelForCausalLM,
LlamaForCausalLM,
LlamaTokenizer,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
is_torch_tpu_available,
set_seed,
BitsAndBytesConfig
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
from sklearn.metrics import accuracy_score
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, get_peft_model_state_dict
from peft.tuners.lora import LoraLayer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
class SavePeftModelCallback(transformers.TrainerCallback):
def save_model(self, args, state, kwargs):
if state.best_model_checkpoint is not None:
checkpoint_folder = os.path.join(state.best_model_checkpoint, "pt_lora_model")
else:
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
peft_model_path = os.path.join(checkpoint_folder, "pt_lora_model")
kwargs["model"].save_pretrained(peft_model_path)
kwargs["tokenizer"].save_pretrained(peft_model_path)
def on_save(self, args, state, control, **kwargs):
self.save_model(args, state, kwargs)
return control
def on_train_end(self, args, state, control, **kwargs):
peft_model_path = os.path.join(args.output_dir, "pt_lora_model")
kwargs["model"].save_pretrained(peft_model_path)
kwargs["tokenizer"].save_pretrained(peft_model_path)
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
r"""
This method wraps the entire protocol for preparing a model before running a training. This includes:
1- casting the layernorm in fp32
2- making the output embedding layer require grads
3- adding the upcasting of the lm head to fp32
Args:
model, (`transformers.PreTrainedModel`):
The loaded model from `transformers`
"""
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
for name, param in model.named_parameters():
# freeze base model's layers
param.requires_grad = False
# cast all non INT8/INT4 parameters to fp32
for param in model.parameters():
if ((param.dtype == torch.float16) or (param.dtype == torch.bfloat16)) and loaded_in_kbit:
param.data = param.data.to(torch.float32)
for name, module in model.named_modules():
if 'norm' in name:
module = module.to(torch.float32)
if loaded_in_kbit and use_gradient_checkpointing:
# For backward compatibility
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, _input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()
return model
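# Usage sketch (hypothetical checkpoint name and target modules; illustrative only):
# prepare a quantized model before attaching LoRA adapters.
#   model = AutoModelForCausalLM.from_pretrained("some-model", load_in_8bit=True)
#   model = prepare_model_for_kbit_training(model)
#   model = get_peft_model(model, LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32,
#                                            lora_dropout=0.05, target_modules=["q_proj", "v_proj"]))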
def accuracy(predictions, references, normalize=True, sample_weight=None):
return {
"accuracy": float(
accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
)
}
def compute_metrics(eval_preds):
preds, labels = eval_preds
# preds have the same shape as the labels, after the argmax(-1) has been calculated
# by preprocess_logits_for_metrics but we need to shift the labels
labels = labels[:, 1:].reshape(-1)
preds = preds[:, :-1].reshape(-1)
return accuracy(predictions=preds, references=labels)
def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
# Depending on the model and config, logits may contain extra tensors,
# like past_key_values, but logits always come first
logits = logits[0]
return logits.argmax(dim=-1)
def fault_tolerance_data_collator(features: List) -> Dict[str, Any]:
if not isinstance(features[0], Mapping):
features = [vars(f) for f in features]
first = features[0]
batch = {}
# Special handling for labels.
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if "label" in first and first["label"] is not None:
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
dtype = torch.long if isinstance(label, int) else torch.float
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
elif "label_ids" in first and first["label_ids"] is not None:
if isinstance(first["label_ids"], torch.Tensor):
batch["labels"] = torch.stack([f["label_ids"] for f in features])
else:
dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
# Handling of all other possible keys.
# Again, we will use the first element to figure out which key/values are not None for this model.
try:
for k, v in first.items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
elif isinstance(v, np.ndarray):
batch[k] = torch.tensor(np.stack([f[k] for f in features]))
else:
batch[k] = torch.tensor([f[k] for f in features])
    except ValueError:  # quick fix: simply replicate the first example
for k, v in first.items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([features[0][k]] * len(features))
elif isinstance(v, np.ndarray):
batch[k] = torch.tensor(np.stack([features[0][k]] * len(features)))
else:
batch[k] = torch.tensor([features[0][k]] * len(features))
return batch
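# Illustrative call (hypothetical feature dicts, shaped like the grouped dataset
# built in main() below):
#   features = [{"input_ids": [1, 2, 3], "labels": [1, 2, 3]}] * 4
#   batch = fault_tolerance_data_collator(features)
#   # -> {"input_ids": LongTensor[4, 3], "labels": LongTensor[4, 3]}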
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
)
},
)
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
)
},
)
model_type: Optional[str] = field(
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
config_overrides: Optional[str] = field(
default=None,
metadata={
"help": (
"Override some existing default config settings when a model is trained from scratch. Example: "
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
)
},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
)
},
)
torch_dtype: Optional[str] = field(
default=None,
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
),
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
raise ValueError(
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_dir: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
block_size: Optional[int] = field(
default=None,
metadata={
"help": (
"Optional input sequence length after tokenization. "
"The training dataset will be truncated in block of this size for training. "
"Default to the model max input length for single sentence inputs (take into account special tokens)."
)
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
validation_split_percentage: Optional[float] = field(
default=0.05,
metadata={
"help": "The percentage of the train set used as validation set in case there's no validation split"
},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
keep_linebreaks: bool = field(
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
)
    data_cache_dir: Optional[str] = field(default="./", metadata={"help": "Directory where the processed datasets are stored"})
def __post_init__(self):
if self.streaming:
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
@dataclass
class MyTrainingArguments(TrainingArguments):
    trainable: Optional[str] = field(default="q_proj,v_proj")
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_alpha: Optional[float] = field(default=32.0)
    modules_to_save: Optional[str] = field(default=None)
    debug_mode: Optional[bool] = field(default=False)
    peft_path: Optional[str] = field(default=None)
    use_flash_attention_2: Optional[bool] = field(default=False)
    double_quant: Optional[bool] = field(default=True)
    quant_type: Optional[str] = field(default="nf4")
    load_in_kbits: Optional[int] = field(default=16)
    full_finetuning: Optional[bool] = field(default=False)
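# Illustrative launch (hypothetical script name and paths; the flags map 1:1 to
# the dataclass fields above plus the standard TrainingArguments):
#   torchrun --nproc_per_node=1 this_script.py \
#     --model_name_or_path /path/to/llama-2 --tokenizer_name_or_path /path/to/tokenizer \
#     --dataset_dir /path/to/txt_files --output_dir ./output --do_train \
#     --trainable q_proj,v_proj --lora_rank 8 --lora_alpha 32 --load_in_kbits 4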
logger = logging.getLogger(__name__)
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, MyTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_clm", model_args, data_args)
# Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,  # if training_args.local_rank in [-1, 0] else logging.WARN,
        handlers=[logging.StreamHandler(sys.stdout)],
    )
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# transformers.tokenization_utils.logging.set_verbosity_warning()
# Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(training_args.seed)
config_kwargs = {
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
}
if model_args.config_name:
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
elif model_args.model_name_or_path:
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
if model_args.config_overrides is not None:
logger.info(f"Overriding config: {model_args.config_overrides}")
config.update_from_string(model_args.config_overrides)
logger.info(f"New config: {config}")
tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
"use_fast": model_args.use_fast_tokenizer,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
}
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
elif model_args.tokenizer_name_or_path:
tokenizer = LlamaTokenizer.from_pretrained(model_args.tokenizer_name_or_path, **tokenizer_kwargs)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
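    # Appending EOS to every tokenized text keeps document boundaries visible
    # after the texts are concatenated by group_texts below.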
tokenizer.add_eos_token = True
# Preprocessing the datasets.
# First we tokenize all the texts.
    # tokenize_function will be pickled by the datasets Hasher, so force the logger
    # to load here first to avoid a _LazyModule error inside the Hasher.
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
def tokenize_function(examples):
with CaptureLogger(tok_logger) as cl:
output = tokenizer(examples["text"])
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
" before being passed to the model."
)
return output
if data_args.block_size is None:
block_size = tokenizer.model_max_length
if block_size > 1024:
logger.warning(
"The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
" of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
" override this default with `--block_size xxx`."
)
block_size = 1024
else:
if data_args.block_size > tokenizer.model_max_length:
logger.warning(
f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
)
block_size = min(data_args.block_size, tokenizer.model_max_length)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= block_size:
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result
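    # Illustrative: with block_size=4 and concatenated input_ids [1..10], group_texts
    # yields [1, 2, 3, 4] and [5, 6, 7, 8]; the remainder [9, 10] is dropped. "labels"
    # is a straight copy because the model shifts labels internally for the causal LM loss.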
with training_args.main_process_first(desc="dataset map tokenization and grouping"):
lm_datasets = []
path = Path(data_args.dataset_dir)
files = [file.name for file in path.glob("*.txt")]
        if training_args.debug_mode:
files = [files[0]]
for idx, file in enumerate(files):
data_file = os.path.join(path, file)
filename = ''.join(file.split(".")[:-1])
cache_path = os.path.join(data_args.data_cache_dir, filename+f"_{block_size}")
os.makedirs(cache_path, exist_ok=True)
try:
processed_dataset = datasets.load_from_disk(cache_path, keep_in_memory=False)
                logger.info(f"training dataset {filename} has been loaded from disk")
except Exception:
cache_dir = os.path.join(data_args.data_cache_dir, filename+f"_text_{block_size}")
os.makedirs(cache_dir, exist_ok=True)
raw_dataset = load_dataset("text", data_files=data_file, cache_dir=cache_dir, keep_in_memory=False)
logger.info(f"{file} has been loaded")
tokenized_dataset = raw_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns="text",
load_from_cache_file=True,
keep_in_memory=False,
                    cache_file_names={k: os.path.join(cache_dir, 'tokenized.arrow') for k in raw_dataset},
desc="Running tokenizer on dataset",
)
grouped_datasets = tokenized_dataset.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=True,
keep_in_memory=False,
                    cache_file_names={k: os.path.join(cache_dir, 'grouped.arrow') for k in tokenized_dataset},
desc=f"Grouping texts in chunks of {block_size}",
)
processed_dataset = grouped_datasets
processed_dataset.save_to_disk(cache_path)
if idx == 0:
lm_datasets = processed_dataset['train']
else:
assert lm_datasets.features.type == processed_dataset["train"].features.type
lm_datasets = concatenate_datasets([lm_datasets, processed_dataset["train"]])
        lm_datasets = lm_datasets.train_test_split(test_size=data_args.validation_split_percentage)
if training_args.do_train:
train_dataset = lm_datasets['train']
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
logger.info(f"Num train_samples {len(train_dataset)}")
logger.info("Training example:")
logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
if training_args.do_eval:
eval_dataset = lm_datasets["test"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
logger.info(f"Num eval_samples {len(eval_dataset)}")
logger.info("Evaluation example:")
logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
if training_args.load_in_kbits in [4, 8]:
load_in_4bit = training_args.load_in_kbits == 4
load_in_8bit = training_args.load_in_kbits == 8
if training_args.modules_to_save is not None:
load_in_8bit_skip_modules = training_args.modules_to_save.split(',')
else:
load_in_8bit_skip_modules = None
quantization_config = BitsAndBytesConfig(
load_in_4bit=training_args.load_in_kbits == 4,
load_in_8bit=training_args.load_in_kbits == 8,
llm_int8_threshold=6.0,
load_in_8bit_skip_modules=load_in_8bit_skip_modules,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
)
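        # NF4 plus double quantization follows the QLoRA recipe: weights are stored
        # as 4-bit NormalFloat, the quantization constants are themselves quantized,
        # and matmuls run in compute_dtype (fp16/bf16/fp32 per the flags above).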
else:
load_in_4bit = False
load_in_8bit = False
quantization_config = None
if quantization_config is not None:
logger.info(f"quantization_config:{quantization_config.to_dict()}")
if model_args.model_name_or_path:
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
device_map = {"":int(os.environ.get("LOCAL_RANK") or 0)}
model = LlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
device_map=device_map,
load_in_4bit=load_in_4bit,
load_in_8bit=load_in_8bit,
quantization_config=quantization_config,
use_flash_attention_2=training_args.use_flash_attention_2
)
else:
model = AutoModelForCausalLM.from_config(config)
n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
if training_args.load_in_kbits in [4, 8]:
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
model.config.use_cache = False
model_vocab_size = model.get_output_embeddings().weight.size(0)
tokenizer_vocab_size = len(tokenizer)
logger.info(f"Model vocab size: {model_vocab_size}")
logger.info(f"Tokenizer vocab size: {tokenizer_vocab_size}")
if tokenizer_vocab_size != 55296:
raise ValueError(f"The vocab size of tokenizer is {tokenizer_vocab_size}, not 55296. Please use Chinese-LLaMA-2 tokenizer.")
if model_vocab_size != tokenizer_vocab_size:
logger.info(f"Resize model vocab size to {tokenizer_vocab_size}")
model.resize_token_embeddings(len(tokenizer))
if not training_args.full_finetuning:
if training_args.peft_path is not None:
logger.info("Peft from pre-trained model")
model = PeftModel.from_pretrained(model, training_args.peft_path, device_map=device_map)
else:
logger.info("Init new peft model")
target_modules = training_args.trainable.split(',')
modules_to_save = training_args.modules_to_save
if modules_to_save is not None:
modules_to_save = modules_to_save.split(',')
lora_rank = training_args.lora_rank
lora_dropout = training_args.lora_dropout
lora_alpha = training_args.lora_alpha
logger.info(f"target_modules: {target_modules}")
logger.info(f"lora_rank: {lora_rank}")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=target_modules,
inference_mode=False,
r=lora_rank, lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
modules_to_save=modules_to_save)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
logger.info(f"model.modules_to_save: {model.modules_to_save}")
old_state_dict = model.state_dict
model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
if not training_args.full_finetuning and training_args.gradient_checkpointing and \
(not model.modules_to_save or 'embed_tokens' not in model.modules_to_save):
        # enable requires_grad on the input embeddings to avoid an exception in the
        # backward pass when gradient checkpointing is used without tuning embed_tokens.
if hasattr(model.base_model, "enable_input_require_grads"):
model.base_model.enable_input_require_grads()
elif hasattr(model.base_model, "get_input_embeddings"):
def make_inputs_require_grad(_module, _input, _output):
_output.requires_grad_(True)
model.base_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=fault_tolerance_data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
)
trainer.add_callback(SavePeftModelCallback)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluation
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if __name__ == "__main__":
main()