"...git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "1b34701027f4654bd5c543330b8969c0b001c68c"
Unverified Commit 901ab1ee authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[chat]: add lora merge weights config (#4766)

* feat: modify lora merge weights fn

* feat: add lora merge weights config
parent 493a5efe
+import dataclasses
 import math
+import warnings
 from typing import Optional

 import loralib as lora
@@ -7,6 +9,14 @@ import torch.nn as nn
 import torch.nn.functional as F


+@dataclasses.dataclass
+class LoRAManager:
+    merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
 class LoraLinear(lora.LoRALayer, nn.Module):
     """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
@@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
-        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        merge_weights: bool = True,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
     ):
         nn.Module.__init__(self)
-        lora.LoRALayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
+        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
         self.weight = weight
         self.bias = bias
@@ -53,8 +61,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         def T(w):
             return w.T if self.fan_in_fan_out else w

-        nn.Module.train(self, mode)
-        if self.merge_weights and self.merged:
+        self.training = mode
+        if LORA_MANAGER.merge_weights:
+            if mode and self.merged:
+                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+                raise NotImplementedError("LoRA unmerge is not tested.")
                 # Make sure that the weights are not merged
                 if self.r > 0:
                     if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
@@ -65,13 +76,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
                 else:
                     self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                 self.merged = False
-
-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Module.eval(self)
-        if self.merge_weights and not self.merged:
+            elif not mode and not self.merged:
+                warnings.warn("Invoke module.eval() would merge LoRA weights.")
                 # Merge the weights and mark it
                 if self.r > 0:
                     self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
@@ -79,6 +85,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
                     delattr(self, "lora_B")
                 self.merged = True

+        return self
+
     def forward(self, x: torch.Tensor):
         def T(w):
             return w.T if self.fan_in_fan_out else w
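
The eval-time merge relies on the identity x @ W.T + scaling * x @ (B @ A).T = x @ (W + scaling * B @ A).T, so folding the low-rank factors into the base weight once leaves outputs unchanged while removing the extra matmuls at inference. A standalone check of that arithmetic (shapes and scaling below are arbitrary, not taken from coati):

import torch

torch.manual_seed(0)
out_f, in_f, r, scaling = 8, 16, 4, 0.5
W = torch.randn(out_f, in_f)
A = torch.randn(r, in_f)    # stands in for lora_A
B = torch.randn(out_f, r)   # stands in for lora_B
x = torch.randn(2, in_f)

# Unmerged path: base output plus the scaled low-rank correction.
y_unmerged = x @ W.T + (x @ A.T @ B.T) * scaling
# Merged path: fold the update into the weight once, as the merge branch does.
W_merged = W + (B @ A) * scaling
y_merged = x @ W_merged.T

assert torch.allclose(y_unmerged, y_merged, atol=1e-5)
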
@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
     assert (
         lora_rank <= linear.in_features
     ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
-    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
     return lora_linear
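
Downstream code now opts in by flipping the global flag and switching the model to eval mode right before saving. A sketch of the sequence the training scripts below follow; model, strategy, and save_path are stand-ins for objects those scripts already have:

from coati.models.lora import LORA_MANAGER

LORA_MANAGER.merge_weights = True   # opt in to merging
model.eval()                        # dispatches to train(False), which folds the update in
strategy.save_model(model, save_path, only_rank0=True)  # checkpoint now holds plain weights
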
...
@@ -192,6 +192,12 @@ def main(args):
         use_wandb=args.use_wandb,
     )

+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        actor.eval()
+
     # save model checkpoint after fitting
     strategy.save_model(actor, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -227,6 +233,7 @@ if __name__ == "__main__":
     parser.add_argument("--ptx_batch_size", type=int, default=1)
     parser.add_argument("--experience_batch_size", type=int, default=8)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=1e-7)
     parser.add_argument("--kl_coef", type=float, default=0.1)
     parser.add_argument("--ptx_coef", type=float, default=0.9)
...
@@ -157,6 +157,13 @@ def train(args):
         log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )

+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
+
     # save model checkpoint after fitting on only rank0
     strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -186,6 +193,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=9e-6)
     parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
     parser.add_argument("--log_dir", default="logs", type=str)
...
@@ -177,6 +177,12 @@ def train(args):
         use_wandb=args.use_wandb,
     )

+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
+
     # save model checkpoint after fitting on only rank0
     strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
@@ -204,6 +210,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default="logs", type=str)
...