Unverified Commit c21b11ed authored by Fazzie-Maqianli, committed by GitHub

change nn to models (#3032)

parent 4269196c
@@ -41,7 +41,8 @@ Simplest usage:
 ```python
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy
-from chatgpt.nn import GPTActor, GPTCritic, RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.base import RewardModel
 from copy import deepcopy
 from colossalai.nn.optimizer import HybridAdam
...
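For orientation, here is a minimal sketch of how the updated imports fit into the README's "simplest usage" flow. Only the import lines are taken from this diff; the strategy/trainer wiring and all constructor arguments below are assumptions and may differ from the actual README.

```python
from copy import deepcopy

from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy
from chatgpt.models.gpt import GPTActor, GPTCritic   # previously: from chatgpt.nn import ...
from chatgpt.models.base import RewardModel          # RewardModel now lives in models.base
from colossalai.nn.optimizer import HybridAdam

strategy = ColossalAIStrategy()              # assumed: default construction

with strategy.model_init_context():          # assumed: context manager for sharded init
    actor = GPTActor().cuda()                # policy model optimized by PPO
    critic = GPTCritic().cuda()              # value model
    initial_model = deepcopy(actor)          # frozen reference policy for the KL term
    reward_model = RewardModel(deepcopy(critic.model))  # assumed: RM wraps the critic backbone

actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)

# PPOTrainer's signature is not shown in this hunk; the argument order is illustrative.
trainer = PPOTrainer(strategy, actor, critic, reward_model, initial_model,
                     actor_optim, critic_optim)
```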
@@ -4,7 +4,8 @@ from copy import deepcopy
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from chatgpt.nn import GPTActor, GPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
...
@@ -4,7 +4,8 @@ from copy import deepcopy
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from chatgpt.nn import OPTActor, OPTCritic, RewardModel
+from chatgpt.models.base import RewardModel
+from chatgpt.models.opt import OPTActor, OPTCritic
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
...
@@ -4,7 +4,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from chatgpt.nn.actor import Actor
+from chatgpt.models.base import Actor
 @dataclass
...
 import torch
-from chatgpt.nn.utils import compute_reward, normalize
+from chatgpt.models.utils import compute_reward, normalize
 from .base import Experience, ExperienceMaker
...
from .base import Actor, Critic, RewardModel
from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss']
from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel
__all__ = ['Actor', 'Critic', 'RewardModel']
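With the re-exports above, callers can reach all three base abstractions from one module, which is what the benchmark and experience-maker hunks earlier in this diff rely on:

```python
# Single import point for the base classes re-exported above.
from chatgpt.models.base import Actor, Critic, RewardModel
```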
@@ -4,9 +4,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .generation import generate
-from .lora import LoRAModule
-from .utils import log_probs_from_logits
+from ..generation import generate
+from ..lora import LoRAModule
+from ..utils import log_probs_from_logits
 class Actor(LoRAModule):
...
@@ -3,8 +3,8 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from .lora import LoRAModule
-from .utils import masked_mean
+from ..lora import LoRAModule
+from ..utils import masked_mean
 class Critic(LoRAModule):
...
@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from .lora import LoRAModule
+from ..lora import LoRAModule
 class RewardModel(LoRAModule):
...
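The switch from `.` to `..` in the three hunks above matches the base classes moving one package level deeper, so shared helpers are now imported from the parent package. A sketch of the layout this implies (inferred from the import paths only; the exact file names are assumptions):

```python
# Implied layout (illustrative, not part of the diff):
#   chatgpt/models/generation.py        -> generate
#   chatgpt/models/lora.py              -> LoRAModule
#   chatgpt/models/utils.py             -> masked_mean, log_probs_from_logits
#   chatgpt/models/base/actor.py        -> class Actor(LoRAModule)
#   chatgpt/models/base/critic.py       -> class Critic(LoRAModule)
#   chatgpt/models/base/reward_model.py -> class RewardModel(LoRAModule)
```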
from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
-from .actor import Actor
+from ..base import Actor
 class BLOOMActor(Actor):
...
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
-from .critic import Critic
+from ..base import Critic
 class BLOOMCritic(Critic):
...
@@ -3,7 +3,7 @@ from typing import Optional
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, BloomModel
-from .reward_model import RewardModel
+from ..base import RewardModel
 class BLOOMRM(RewardModel):
...
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
@@ -3,7 +3,7 @@ from typing import Optional
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-from .actor import Actor
+from ..base import Actor
 class GPTActor(Actor):
...
@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
-from .critic import Critic
+from ..base import Critic
 class GPTCritic(Critic):
...
@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
-from .reward_model import RewardModel
+from ..base import RewardModel
 class GPTRM(RewardModel):
...
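Taken together, the hunks above replace every `chatgpt.nn` import. A quick sketch that exercises only the new-style paths appearing in this diff (assuming the package is installed):

```python
# Import paths exactly as they appear in the hunks of this commit; nothing else is assumed.
from chatgpt.models.base import Actor, Critic, RewardModel
from chatgpt.models.gpt import GPTActor, GPTCritic, GPTRM
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic, BLOOMRM
from chatgpt.models.opt import OPTActor, OPTCritic
from chatgpt.models.utils import compute_reward, normalize

print("new-style chatgpt.models imports resolve")
```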