Unverified Commit c9dd0365 authored by BlueRum's avatar BlueRum Committed by GitHub
Browse files

[chatgpt] fix lora save bug (#3099)

* fix colo-strategy

* polish

* fix lora

* fix ddp

* polish

* polish
parent 018936a3
......@@ -74,6 +74,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, 'lora_A')
delattr(self, 'lora_B')
self.merged = True
def forward(self, x: torch.Tensor):
......@@ -125,3 +127,4 @@ class LoRAModule(nn.Module):
return
convert_to_lora_recursively(self, self.lora_rank)
lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
......@@ -6,11 +6,13 @@ import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from chatgpt.models.base import Actor
from chatgpt.models.lora import LoraLinear
from torch.optim import Optimizer
import colossalai
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
......@@ -143,6 +145,20 @@ class ColossalAIStrategy(DDPStrategy):
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
unwrapped_model = self._unwrap_model(model)
# TODO : better way to get torch model from gemini model
# to get torch model from gemini model
if isinstance(unwrapped_model, ZeroDDP):
state_dict = unwrapped_model.state_dict()
unwrapped_model = get_static_torch_model(unwrapped_model)
if only_rank0 and dist.get_rank() != 0:
return
unwrapped_model.load_state_dict(state_dict)
# merge lora_weights into weights
for module in unwrapped_model.modules():
if isinstance(module, LoraLinear):
module.merge_weights=True
module.eval()
# get state_dict and save
state_dict = unwrapped_model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
......
......@@ -6,6 +6,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from chatgpt.models.base import Actor
from chatgpt.models.lora import LoraLinear
from chatgpt.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
......@@ -72,10 +73,17 @@ class DDPStrategy(NaiveStrategy):
return model.module
# Save `model`'s state dict to `path`, first flagging every LoraLinear layer for weight merging.
# NOTE(review): this text is a flattened diff view - indentation is lost and removed/added lines
# carry no +/- markers. The `super().save_model(...)` call and the explicit `torch.save(...)`
# path below are probably the before/after alternatives of one change - confirm against the real file.
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
# Walk every submodule; for each LoraLinear, set merge_weights and switch it to eval mode.
# Presumably LoraLinear folds the low-rank update into `weight` when leaving training mode
# with merge_weights set - confirm in LoraLinear.
for module in model.modules():
if isinstance(module, LoraLinear):
module.merge_weights=True
module.eval()
# When only rank 0 should write the checkpoint, every other rank stops here.
if only_rank0 and dist.get_rank() != 0:
return
super().save_model(model, path, only_rank0)
# assumes `model` exposes `.model`, itself wrapped by an object with `.module` (e.g. DDP) - TODO confirm
model = model.model.module
# Serialize the unwrapped module's parameters and buffers to `path`.
state_dict = model.state_dict()
torch.save(state_dict, path)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0 and dist.get_rank() != 0:
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment