Unverified Commit c21b11ed authored by Fazzie-Maqianli's avatar Fazzie-Maqianli Committed by GitHub
Browse files

change nn to models (#3032)

parent 4269196c
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
# Public API of the OPT model subpackage: actor, critic, and reward model.
__all__ = [
    'OPTActor',
    'OPTCritic',
    'OPTRM',
]
...@@ -3,7 +3,7 @@ from typing import Optional ...@@ -3,7 +3,7 @@ from typing import Optional
from transformers.models.opt.configuration_opt import OPTConfig from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM
from .actor import Actor from ..base import Actor
class OPTActor(Actor): class OPTActor(Actor):
......
...@@ -4,7 +4,7 @@ import torch.nn as nn ...@@ -4,7 +4,7 @@ import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel from transformers.models.opt.modeling_opt import OPTModel
from .critic import Critic from ..base import Critic
class OPTCritic(Critic): class OPTCritic(Critic):
......
...@@ -3,7 +3,7 @@ from typing import Optional ...@@ -3,7 +3,7 @@ from typing import Optional
import torch.nn as nn import torch.nn as nn
from transformers import OPTConfig, OPTModel from transformers import OPTConfig, OPTModel
from .reward_model import RewardModel from ..base import RewardModel
class OPTRM(RewardModel): class OPTRM(RewardModel):
......
from .actor import Actor
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
from .critic import Critic
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
from .reward_model import RewardModel
# Public API of the package: base abstractions (Actor/Critic/RewardModel),
# the loss functions, and the per-backbone (GPT/BLOOM/OPT) implementations.
# NOTE: entry order is preserved from the import list above.
__all__ = [
    'Actor',
    'Critic',
    'RewardModel',
    'PolicyLoss',
    'ValueLoss',
    'PPOPtxActorLoss',
    'PairWiseLoss',
    'GPTActor',
    'GPTCritic',
    'GPTRM',
    'BLOOMActor',
    'BLOOMCritic',
    'BLOOMRM',
    'OPTActor',
    'OPTCritic',
    'OPTRM',
]
...@@ -2,8 +2,9 @@ from typing import Any, Callable, Dict, List, Optional ...@@ -2,8 +2,9 @@ from typing import Any, Callable, Dict, List, Optional
import torch.nn as nn import torch.nn as nn
from chatgpt.experience_maker import Experience, NaiveExperienceMaker from chatgpt.experience_maker import Experience, NaiveExperienceMaker
from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss from chatgpt.models.base import Actor, Critic
from chatgpt.nn.generation_utils import update_model_kwargs_fn from chatgpt.models.generation_utils import update_model_kwargs_fn
from chatgpt.models.loss import PolicyLoss, ValueLoss
from chatgpt.replay_buffer import NaiveReplayBuffer from chatgpt.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer from torch.optim import Optimizer
......
...@@ -3,7 +3,7 @@ from abc import ABC ...@@ -3,7 +3,7 @@ from abc import ABC
import loralib as lora import loralib as lora
import torch import torch
from chatgpt.dataset import RewardDataset from chatgpt.dataset import RewardDataset
from chatgpt.nn import PairWiseLoss from chatgpt.models.loss import PairWiseLoss
from torch.optim import Adam, Optimizer from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
......
...@@ -5,7 +5,7 @@ from typing import Any, List, Tuple, Union ...@@ -5,7 +5,7 @@ from typing import Any, List, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from chatgpt.nn import Actor from chatgpt.models.base import Actor, Critic, RewardModel
from chatgpt.replay_buffer import ReplayBuffer from chatgpt.replay_buffer import ReplayBuffer
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from chatgpt.nn import Actor from chatgpt.models.base import Actor
from torch.optim import Optimizer from torch.optim import Optimizer
import colossalai import colossalai
......
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from chatgpt.nn import Actor from chatgpt.models.base import Actor
from chatgpt.replay_buffer import ReplayBuffer from chatgpt.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
......
import argparse import argparse
import torch import torch
from chatgpt.nn import BLOOMActor, GPTActor, OPTActor from chatgpt.models.bloom import BLOOMActor
from chatgpt.models.gpt import GPTActor
from chatgpt.models.opt import OPTActor
from transformers import AutoTokenizer from transformers import AutoTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
......
...@@ -2,7 +2,10 @@ import argparse ...@@ -2,7 +2,10 @@ import argparse
from copy import deepcopy from copy import deepcopy
import torch import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
from chatgpt.models.gpt import GPTActor, GPTCritic
from chatgpt.models.opt import OPTActor, OPTCritic
from chatgpt.trainer import PPOTrainer from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.callbacks import SaveCheckpoint from chatgpt.trainer.callbacks import SaveCheckpoint
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
......
...@@ -3,7 +3,10 @@ from copy import deepcopy ...@@ -3,7 +3,10 @@ from copy import deepcopy
import pandas as pd import pandas as pd
import torch import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
from chatgpt.models.gpt import GPTActor, GPTCritic
from chatgpt.models.opt import OPTActor, OPTCritic
from chatgpt.trainer import PPOTrainer from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam from torch.optim import Adam
......
...@@ -3,7 +3,10 @@ import argparse ...@@ -3,7 +3,10 @@ import argparse
import loralib as lora import loralib as lora
import torch import torch
from chatgpt.dataset import RewardDataset from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM, GPTRM, OPTRM from chatgpt.models.base import RewardModel
from chatgpt.models.bloom import BLOOMRM
from chatgpt.models.gpt import GPTRM
from chatgpt.models.opt import OPTRM
from chatgpt.trainer import RewardModelTrainer from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset from datasets import load_dataset
......
...@@ -7,7 +7,7 @@ import pytest ...@@ -7,7 +7,7 @@ import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from chatgpt.nn import GPTActor from chatgpt.models.gpt import GPTActor
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.configuration_gpt2 import GPT2Config
......
...@@ -7,7 +7,8 @@ import torch ...@@ -7,7 +7,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from chatgpt.experience_maker import NaiveExperienceMaker from chatgpt.experience_maker import NaiveExperienceMaker
from chatgpt.nn import GPTActor, GPTCritic, RewardModel from chatgpt.models.base import RewardModel
from chatgpt.models.gpt import GPTActor, GPTCritic
from chatgpt.replay_buffer import NaiveReplayBuffer from chatgpt.replay_buffer import NaiveReplayBuffer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.configuration_gpt2 import GPT2Config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment