Unverified Commit 773955ab authored by Yuanchen's avatar Yuanchen Committed by GitHub
Browse files

fix save_model in naive and ddp strategy (#3436)


Co-authored-by: default avatarYuanchen Xu <yuanchen.xu00@gmail.com>
parent 1beb85cc
from typing import Optional
import os import os
import random import random
...@@ -5,12 +7,13 @@ import numpy as np ...@@ -5,12 +7,13 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from coati.models.base import Actor from coati.models.base import LM, Actor, RewardModel
from coati.models.lora import LoraLinear from coati.models.lora import LoraLinear
from coati.replay_buffer import ReplayBuffer from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .base import Strategy from .base import Strategy
from .naive import NaiveStrategy from .naive import NaiveStrategy
...@@ -72,16 +75,31 @@ class DDPStrategy(NaiveStrategy): ...@@ -72,16 +75,31 @@ class DDPStrategy(NaiveStrategy):
model: DDP = Strategy._unwrap_actor(actor) model: DDP = Strategy._unwrap_actor(actor)
return model.module return model.module
def save_model(self,
               model: nn.Module,
               path: str,
               only_rank0: bool = False,
               tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    """Save a model checkpoint from a DDP run.

    Args:
        model: the (possibly LoRA-augmented) model to save.
        path: destination — a file path for ``torch.save`` state_dicts, or a
            directory for HuggingFace ``save_pretrained`` output.
        only_rank0: when True, only the rank-0 process writes; other ranks
            return immediately without touching the filesystem.
        tokenizer: optional tokenizer saved alongside a ``save_pretrained``
            model so the checkpoint directory is self-contained.
    """
    # Non-rank-0 processes bail out up front; no further rank checks are
    # needed below (the pre-fix code repeated this test redundantly).
    if only_rank0 and dist.get_rank() != 0:
        return None
    # Merge LoRA low-rank deltas back into the base weights so the saved
    # checkpoint is usable without the LoRA wrapper.
    for module in model.modules():
        if isinstance(module, LoraLinear):
            module.merge_weights = True
            module.eval()
    if isinstance(model, RewardModel):
        # Reward models have no save_pretrained; persist the raw state_dict.
        torch.save(model.state_dict(), path)
    else:
        if isinstance(model, LM):
            model = model.model
        # Keep the try minimal: only the attribute lookup may legitimately
        # raise AttributeError. A broad try here would swallow errors from
        # inside save_pretrained and then clobber `path` with a state_dict.
        try:
            save_pretrained = model.save_pretrained
        except AttributeError:
            # Model lacks the HuggingFace API; fall back to a plain state_dict.
            torch.save(model.state_dict(), path)
        else:
            save_pretrained(path)
            if tokenizer is not None:
                tokenizer.save_pretrained(path)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
......
from typing import Any from typing import Any, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from coati.replay_buffer import ReplayBuffer from coati.replay_buffer import ReplayBuffer
from coati.models.base import LM, RewardModel
from coati.models.lora import LoraLinear
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .base import Strategy from .base import Strategy
...@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy): ...@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy):
pin_memory=pin_memory, pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn) collate_fn=replay_buffer.collate_fn)
def save_model(self,
               model: nn.Module,
               path: str,
               only_rank0: bool = False,
               tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    """Save a model checkpoint (single-process naive strategy).

    Args:
        model: the (possibly LoRA-augmented) model to save.
        path: destination — a file path for ``torch.save`` state_dicts, or a
            directory for HuggingFace ``save_pretrained`` output.
        only_rank0: accepted for interface compatibility with distributed
            strategies; ignored here since there is a single process.
        tokenizer: optional tokenizer saved alongside a ``save_pretrained``
            model so the checkpoint directory is self-contained.
    """
    # Merge LoRA low-rank deltas back into the base weights so the saved
    # checkpoint is usable without the LoRA wrapper.
    for module in model.modules():
        if isinstance(module, LoraLinear):
            module.merge_weights = True
            module.eval()
    if isinstance(model, RewardModel):
        # Reward models have no save_pretrained; persist the raw state_dict.
        torch.save(model.state_dict(), path)
    else:
        if isinstance(model, LM):
            model = model.model
        # Keep the try minimal: only the attribute lookup may legitimately
        # raise AttributeError. The pre-fix broad try swallowed errors raised
        # *inside* save_pretrained (or by tokenizer.save_pretrained) and then
        # overwrote `path` with a state_dict after a partial HF save.
        try:
            save_pretrained = model.save_pretrained
        except AttributeError:
            # Model lacks the HuggingFace API; fall back to a plain state_dict.
            torch.save(model.state_dict(), path)
        else:
            save_pretrained(path)
            if tokenizer is not None:
                tokenizer.save_pretrained(path)
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None: def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
unwrapped_model = self._unwrap_model(model) unwrapped_model = self._unwrap_model(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment