Unverified Commit bbac6760 authored by Fazzie-Maqianli's avatar Fazzie-Maqianli Committed by GitHub
Browse files

fix torch version (#3225)

parent fa97a9ca
...@@ -9,6 +9,10 @@ from chatgpt.models.base import Actor ...@@ -9,6 +9,10 @@ from chatgpt.models.base import Actor
from chatgpt.models.lora import LoraLinear from chatgpt.models.lora import LoraLinear
from torch.optim import Optimizer from torch.optim import Optimizer
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import colossalai import colossalai
from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
...@@ -143,7 +147,7 @@ class ColossalAIStrategy(DDPStrategy): ...@@ -143,7 +147,7 @@ class ColossalAIStrategy(DDPStrategy):
return model.module return model.module
return model return model
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None: def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
unwrapped_model = self._unwrap_model(model) unwrapped_model = self._unwrap_model(model)
# TODO : better way to get torch model from gemini model # TODO : better way to get torch model from gemini model
# to get torch model from gemini model # to get torch model from gemini model
...@@ -159,10 +163,16 @@ class ColossalAIStrategy(DDPStrategy): ...@@ -159,10 +163,16 @@ class ColossalAIStrategy(DDPStrategy):
module.merge_weights=True module.merge_weights=True
module.eval() module.eval()
# get state_dict and save # get state_dict and save
state_dict = unwrapped_model.state_dict()
if only_rank0 and dist.get_rank() != 0: if not isinstance(self.model, PreTrainedModel):
return state_dict = unwrapped_model.state_dict()
torch.save(state_dict, path) if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
else:
self.model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0: if only_rank0:
......
...@@ -3,5 +3,5 @@ tqdm ...@@ -3,5 +3,5 @@ tqdm
datasets datasets
loralib loralib
colossalai>=0.2.4 colossalai>=0.2.4
torch torch==1.12.1
langchain langchain
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment