Unverified Commit 901ab1ee authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[chat]: add lora merge weights config (#4766)

* feat: modify lora merge weights fn

* feat: add lora merge weights config
parent 493a5efe
import dataclasses
import math import math
import warnings
from typing import Optional from typing import Optional
import loralib as lora import loralib as lora
...@@ -7,6 +9,14 @@ import torch.nn as nn ...@@ -7,6 +9,14 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@dataclasses.dataclass
class LoRAManager:
    """Process-wide configuration for LoRA layers.

    Attributes:
        merge_weights: when True, LoraLinear layers merge their low-rank
            A/B deltas into the base weight on the train/eval transition
            (training scripts flip this to True right before calling
            ``model.eval()`` so the saved checkpoint has merged weights).
    """

    merge_weights: bool = False


# Module-level singleton consulted by LoraLinear.train(); mutate
# ``LORA_MANAGER.merge_weights`` to control merging globally.
LORA_MANAGER = LoRAManager()
class LoraLinear(lora.LoRALayer, nn.Module): class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.""" """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
...@@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
r: int = 0, r: int = 0,
lora_alpha: int = 1, lora_alpha: int = 1,
lora_dropout: float = 0.0, lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True, fan_in_fan_out: bool = False,
): ):
nn.Module.__init__(self) nn.Module.__init__(self)
lora.LoRALayer.__init__( lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
)
self.weight = weight self.weight = weight
self.bias = bias self.bias = bias
...@@ -53,31 +61,31 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -53,31 +61,31 @@ class LoraLinear(lora.LoRALayer, nn.Module):
def T(w): def T(w):
return w.T if self.fan_in_fan_out else w return w.T if self.fan_in_fan_out else w
nn.Module.train(self, mode) self.training = mode
if self.merge_weights and self.merged: if LORA_MANAGER.merge_weights:
# Make sure that the weights are not merged if mode and self.merged:
if self.r > 0: warnings.warn("Invoke module.train() would unmerge LoRA weights.")
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"): raise NotImplementedError("LoRA unmerge is not tested.")
# FIXME(csric): temporary fix # Make sure that the weights are not merged
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features))) if self.r > 0:
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r))) if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
self.reset_parameters() # FIXME(csric): temporary fix
else: self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
self.merged = False self.reset_parameters()
else:
def eval(self): self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
def T(w): self.merged = False
return w.T if self.fan_in_fan_out else w elif not mode and not self.merged:
warnings.warn("Invoke module.eval() would merge LoRA weights.")
nn.Module.eval(self) # Merge the weights and mark it
if self.merge_weights and not self.merged: if self.r > 0:
# Merge the weights and mark it self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
if self.r > 0: delattr(self, "lora_A")
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling delattr(self, "lora_B")
delattr(self, "lora_A") self.merged = True
delattr(self, "lora_B")
self.merged = True return self
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
def T(w): def T(w):
...@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: ...@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert ( assert (
lora_rank <= linear.in_features lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})" ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
return lora_linear return lora_linear
......
...@@ -192,6 +192,12 @@ def main(args): ...@@ -192,6 +192,12 @@ def main(args):
use_wandb=args.use_wandb, use_wandb=args.use_wandb,
) )
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
actor.eval()
# save model checkpoint after fitting # save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True) strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
...@@ -227,6 +233,7 @@ if __name__ == "__main__": ...@@ -227,6 +233,7 @@ if __name__ == "__main__":
parser.add_argument("--ptx_batch_size", type=int, default=1) parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=1e-7) parser.add_argument("--lr", type=float, default=1e-7)
parser.add_argument("--kl_coef", type=float, default=0.1) parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9) parser.add_argument("--ptx_coef", type=float, default=0.9)
......
...@@ -157,6 +157,13 @@ def train(args): ...@@ -157,6 +157,13 @@ def train(args):
log_dir=args.log_dir, log_dir=args.log_dir,
use_wandb=args.use_wandb, use_wandb=args.use_wandb,
) )
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True) strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
...@@ -186,6 +193,7 @@ if __name__ == "__main__": ...@@ -186,6 +193,7 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=9e-6) parser.add_argument("--lr", type=float, default=9e-6)
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"]) parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
parser.add_argument("--log_dir", default="logs", type=str) parser.add_argument("--log_dir", default="logs", type=str)
......
...@@ -177,6 +177,12 @@ def train(args): ...@@ -177,6 +177,12 @@ def train(args):
use_wandb=args.use_wandb, use_wandb=args.use_wandb,
) )
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer) strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks # save optimizer checkpoint on all ranks
...@@ -204,6 +210,7 @@ if __name__ == "__main__": ...@@ -204,6 +210,7 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str) parser.add_argument("--log_dir", default="logs", type=str)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment