Commit 7bc5a8e3 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents e6748d82 0f785cb1
from .deberta_critic import DebertaCritic
from .deberta_rm import DebertaRM
__all__ = ['DebertaCritic', 'DebertaRM']
from typing import Optional
import torch.nn as nn
from transformers import DebertaV2Config, DebertaV2Model
from ..base import Critic
class DebertaCritic(Critic):
    """
    Critic built on a DeBERTa-v2 backbone with a scalar value head.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (DebertaV2Config): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing on the backbone.
        lora_rank (int): Rank of the LoRA decomposition, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            backbone = DebertaV2Model(DebertaV2Config() if config is None else config)
        else:
            backbone = DebertaV2Model.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        # One scalar output per hidden-state vector.
        head = nn.Linear(backbone.config.hidden_size, 1)
        super().__init__(backbone, head, lora_rank, lora_train_bias)
from typing import Optional
import torch.nn as nn
from transformers import DebertaV2Config, DebertaV2Model
from ..base import RewardModel
class DebertaRM(RewardModel):
    """
    Reward model built on a DeBERTa-v2 backbone with a scalar value head.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (DebertaV2Config): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing on the backbone.
        lora_rank (int): Rank of the LoRA decomposition, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[DebertaV2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = DebertaV2Model.from_pretrained(pretrained)
        elif config is not None:
            model = DebertaV2Model(config)
        else:
            model = DebertaV2Model(DebertaV2Config())

        if checkpoint:
            model.gradient_checkpointing_enable()

        hidden_size = model.config.hidden_size
        value_head = nn.Linear(hidden_size, 1)
        # Re-initialise the head weights from N(0, 1 / (hidden_size + 1)).
        value_head.weight.data.normal_(mean=0.0, std=1 / (hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
try:
from transformers.generation_logits_process import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def prepare_logits_processor(top_k: Optional[int] = None,
                             top_p: Optional[float] = None,
                             temperature: Optional[float] = None) -> LogitsProcessorList:
    """Assemble the logits warpers that correspond to the given sampling options.

    Warpers are appended in the order temperature, top-k, top-p. An option is
    skipped when it is None or set to its neutral value (temperature 1.0,
    top_k 0, top_p >= 1.0).

    Args:
        top_k (Optional[int]): number of highest-probability tokens to keep.
        top_p (Optional[float]): nucleus-sampling probability mass.
        temperature (Optional[float]): softmax temperature.

    Returns:
        LogitsProcessorList: the (possibly empty) list of warpers.
    """
    warpers = LogitsProcessorList()

    wants_temperature = temperature is not None and temperature != 1.0
    if wants_temperature:
        warpers.append(TemperatureLogitsWarper(temperature))

    wants_top_k = top_k is not None and top_k != 0
    if wants_top_k:
        warpers.append(TopKLogitsWarper(top_k))

    wants_top_p = top_p is not None and top_p < 1.0
    if wants_top_p:
        warpers.append(TopPLogitsWarper(top_p))

    return warpers
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
if dist.is_initialized() and dist.get_world_size() > 1:
# consider DP
unfinished_sequences = unfinished_sequences.clone()
dist.all_reduce(unfinished_sequences)
return unfinished_sequences.max() == 0
def sample(model: nn.Module,
           input_ids: torch.Tensor,
           max_length: int,
           early_stopping: bool = False,
           eos_token_id: Optional[int] = None,
           pad_token_id: Optional[int] = None,
           top_k: Optional[int] = None,
           top_p: Optional[float] = None,
           temperature: Optional[float] = None,
           prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
           update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
           **model_kwargs) -> torch.Tensor:
    """Extend ``input_ids`` token by token using multinomial sampling.

    Args:
        model (nn.Module): called as ``model(**inputs)``; its output is indexed with ``['logits']``.
        input_ids (torch.Tensor): prompt tokens; dimension 1 is the sequence dimension.
        max_length (int): maximum length of the returned sequence.
        early_stopping (bool): stop as soon as every sequence has produced ``eos_token_id``.
        eos_token_id (Optional[int]): end-of-sequence token id.
        pad_token_id (Optional[int]): token appended to already finished sequences.
        top_k, top_p, temperature: sampling options, see ``prepare_logits_processor``.
        prepare_inputs_fn: builds the model inputs from ``input_ids`` and ``model_kwargs``.
        update_model_kwargs_fn: refreshes ``model_kwargs`` from the latest model outputs.

    Returns:
        torch.Tensor: ``input_ids`` with the sampled tokens appended. The input is
        returned unchanged when it is already ``max_length`` long or longer.

    Raises:
        ValueError: if ``eos_token_id`` is given without ``pad_token_id``.
    """
    if input_ids.size(1) >= max_length:
        return input_ids

    # Loop-invariant validation: fail before spending a forward pass.
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")

    logits_processor = prepare_logits_processor(top_k, top_p, temperature)
    # 1 while a sequence is still generating, 0 once it has emitted eos.
    unfinished_sequences = input_ids.new_ones(input_ids.shape[0])

    for _ in range(input_ids.size(1), max_length):
        if prepare_inputs_fn is not None:
            model_inputs = prepare_inputs_fn(input_ids, **model_kwargs)
        else:
            model_inputs = {'input_ids': input_ids}
        outputs = model(**model_inputs)

        next_token_logits = outputs['logits'][:, -1, :]
        # pre-process distribution
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break

    return input_ids
def generate(model: nn.Module,
             input_ids: torch.Tensor,
             max_length: int,
             num_beams: int = 1,
             do_sample: bool = True,
             early_stopping: bool = False,
             eos_token_id: Optional[int] = None,
             pad_token_id: Optional[int] = None,
             top_k: Optional[int] = None,
             top_p: Optional[float] = None,
             temperature: Optional[float] = None,
             prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
             update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
             **model_kwargs) -> torch.Tensor:
    """Generate a token sequence; the result is input_ids followed by the generated tokens.

    Only sampling with a single beam is implemented. Greedy search and beam
    search raise NotImplementedError.

    Args:
        model (nn.Module): model
        input_ids (torch.Tensor): input sequence
        max_length (int): max length of the returned sequence
        num_beams (int, optional): number of beams. Defaults to 1.
        do_sample (bool, optional): whether to sample. Defaults to True.
        early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
        eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
        pad_token_id (Optional[int], optional): pad token id. Defaults to None.
        top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
        top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
        temperature (Optional[float], optional): The value used to modulate the next token probabilities. Defaults to None.
        prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
        update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.

    Raises:
        NotImplementedError: for greedy search and for beam search.
        ValueError: for any other combination of ``num_beams`` and ``do_sample``.
    """
    single_beam = num_beams == 1

    # greedy search: one beam, no sampling
    if single_beam and do_sample is False:
        raise NotImplementedError

    # multinomial sampling: one beam, sampling enabled
    if single_beam and do_sample is True:
        return sample(model,
                      input_ids,
                      max_length,
                      early_stopping=early_stopping,
                      eos_token_id=eos_token_id,
                      pad_token_id=pad_token_id,
                      top_k=top_k,
                      top_p=top_p,
                      temperature=temperature,
                      prepare_inputs_fn=prepare_inputs_fn,
                      update_model_kwargs_fn=update_model_kwargs_fn,
                      **model_kwargs)

    # beam search: several beams, no sampling
    if (num_beams > 1) and do_sample is False:
        raise NotImplementedError

    raise ValueError("Unsupported generation mode")
from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
from typing import Optional
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from ..base import Actor
class GPTActor(Actor):
    """
    Actor built on a GPT-2 language-model head.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (GPT2Config): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA layer, forwarded to the base class.
        lora_train_bias (str): Bias training strategy for the LoRA layer, forwarded to the base class.
        **kwargs: Extra keyword arguments forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            lm = GPT2LMHeadModel(GPT2Config() if config is None else config)
        else:
            lm = GPT2LMHeadModel.from_pretrained(pretrained)

        if checkpoint:
            lm.gradient_checkpointing_enable()

        super().__init__(lm, lora_rank, lora_train_bias, **kwargs)
from typing import Optional
import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from ..base import Critic
class GPTCritic(Critic):
    """
    Critic built on a GPT-2 backbone with a scalar value head.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (GPT2Config): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the LoRA decomposition, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
        **kwargs: Extra keyword arguments forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            backbone = GPT2Model(GPT2Config() if config is None else config)
        else:
            backbone = GPT2Model.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        # One scalar output per embedding vector.
        head = nn.Linear(backbone.config.n_embd, 1)
        super().__init__(backbone, head, lora_rank, lora_train_bias, **kwargs)
from typing import Optional
import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from ..base import RewardModel
class GPTRM(RewardModel):
    """
    Reward model built on a GPT-2 backbone with a scalar value head.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (GPT2Config): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            backbone = GPT2Model(GPT2Config() if config is None else config)
        else:
            backbone = GPT2Model.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        head = nn.Linear(backbone.config.n_embd, 1)
        # Re-initialise the head weights from N(0, 1 / (n_embd + 1)).
        head.weight.data.normal_(mean=0.0, std=1 / (backbone.config.n_embd + 1))
        super().__init__(backbone, head, lora_rank, lora_train_bias)
from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
from typing import Optional
import torch
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from ..base import Actor
class LlamaActor(Actor):
    """
    Actor built on a LLaMA causal language model.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (LlamaConfig): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            lm = LlamaForCausalLM(LlamaConfig() if config is None else config)
        else:
            lm = LlamaForCausalLM.from_pretrained(pretrained)

        if checkpoint:
            lm.gradient_checkpointing_enable()

        super().__init__(lm, lora_rank, lora_train_bias)
from typing import Optional
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel
from ..base import Critic
class LlamaCritic(Critic):
    """
    Critic built on a LLaMA backbone with a scalar value head.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (LlamaConfig): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
        **kwargs: Extra keyword arguments forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            backbone = LlamaModel(LlamaConfig() if config is None else config)
        else:
            backbone = LlamaModel.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        # One scalar output per hidden-state vector.
        head = nn.Linear(backbone.config.hidden_size, 1)
        super().__init__(backbone, head, lora_rank, lora_train_bias, **kwargs)
from typing import Optional
import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
from ..base import RewardModel
class LlamaRM(RewardModel):
    """
    Reward model built on a LLaMA backbone with a scalar value head.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (LlamaConfig): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            backbone = LlamaModel(LlamaConfig() if config is None else config)
        else:
            backbone = LlamaModel.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        head = nn.Linear(backbone.config.hidden_size, 1)
        # Re-initialise the head weights from N(0, 1 / (hidden_size + 1)).
        head.weight.data.normal_(mean=0.0, std=1 / (backbone.config.hidden_size + 1))
        super().__init__(backbone, head, lora_rank, lora_train_bias)
import math
from typing import Optional
import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoraLinear(lora.LoRALayer, nn.Module):
    """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.

    The layer reuses the ``weight``/``bias`` parameters it is given and, when
    ``r > 0``, adds the trainable low-rank factors ``lora_A`` (r x in_features)
    and ``lora_B`` (out_features x r) and freezes ``weight``.

    Args:
        weight (nn.Parameter): weight of the wrapped linear layer.
        bias (Optional[nn.Parameter]): bias of the wrapped linear layer.
        r (int): LoRA rank; 0 disables the low-rank branch.
        lora_alpha (int): scaling numerator; the update is scaled by ``lora_alpha / r``.
        lora_dropout (float): dropout applied to the input of the low-rank branch.
        fan_in_fan_out (bool): set to True if the replaced layer stores weight as (fan_in, fan_out).
        merge_weights (bool): fold the low-rank update into ``weight`` on ``eval()``.

    Note:
        With ``merge_weights=True``, ``eval()`` deletes ``lora_A``/``lora_B`` after
        merging, so switching such a layer back to training mode raises
        ``AttributeError`` (unchanged from the previous behaviour).
    """

    def __init__(
        self,
        weight: nn.Parameter,
        bias: Optional[nn.Parameter],
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,    # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
    ):
        nn.Module.__init__(self)
        lora.LoRALayer.__init__(self,
                                r=r,
                                lora_alpha=lora_alpha,
                                lora_dropout=lora_dropout,
                                merge_weights=merge_weights)
        self.weight = weight
        self.bias = bias

        out_features, in_features = weight.shape
        self.in_features = in_features
        self.out_features = out_features

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def _maybe_transpose(self, w: torch.Tensor) -> torch.Tensor:
        """Return ``w`` transposed when the layer stores weights as (fan_in, fan_out)."""
        return w.T if self.fan_in_fan_out else w

    def reset_parameters(self):
        """Initialise the LoRA factors (no-op when the layer has none)."""
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Set the training mode; un-merge the LoRA update when entering training.

        Returns:
            LoraLinear: ``self``, following the ``nn.Module.train`` contract.
        """
        nn.Module.train(self, mode)
        # Only un-merge when actually entering training mode. ``nn.Module.eval``
        # is implemented as ``self.train(False)``, so ignoring ``mode`` here would
        # undo the merge on every ``eval()`` call.
        if mode and self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0:
                self.weight.data -= self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
            self.merged = False
        return self

    def eval(self):
        """Switch to evaluation mode and, if enabled, merge the LoRA update into ``weight``.

        Returns:
            LoraLinear: ``self``, following the ``nn.Module.eval`` contract.
        """
        nn.Module.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                # The factors are dropped once folded into the weight.
                delattr(self, 'lora_A')
                delattr(self, 'lora_B')
            self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the linear map, adding the scaled low-rank branch while un-merged."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
        return result
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
    """Build a LoraLinear that reuses the weight and bias of ``linear``.

    Args:
        linear (nn.Linear): layer to wrap.
        lora_rank (int): LoRA rank; must not exceed ``linear.in_features``.

    Returns:
        LoraLinear: the wrapping layer, created with ``merge_weights=False``.
    """
    assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
    return LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
    """Replace, in place, every nn.Linear found below ``module`` by a LoRA wrapper.

    Args:
        module (nn.Module): root of the module tree to rewrite.
        lora_rank (int): LoRA rank passed to each wrapper.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Linear):
            # Not a linear layer: descend into its own children.
            convert_to_lora_recursively(submodule, lora_rank)
            continue
        setattr(module, child_name, lora_linear_wrapper(submodule, lora_rank))
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
    """Convert a torch.nn.Module to a LoRA module.

    The module is modified in place and also returned. Nothing happens when
    ``lora_rank`` is zero or negative.

    Args:
        module (nn.Module): The module to convert.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): bias mode passed to ``lora.mark_only_lora_as_trainable``.

    Returns:
        nn.Module: The converted module (the same object as ``module``).
    """
    if lora_rank > 0:
        convert_to_lora_recursively(module, lora_rank)
        lora.mark_only_lora_as_trainable(module, lora_train_bias)
    return module
class LoRAModule(nn.Module):
    """Base class for modules that can be converted to LoRA.

    Derived classes should call `convert_to_lora()` at the bottom of `__init__()`;
    that call converts all torch.nn.Linear layers to LoraLinear layers.

    Args:
        lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
        lora_train_bias (str, optional): Whether LoRA trains biases.
            'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
            Defaults to 'none'.
    """

    def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__()
        self.lora_train_bias = lora_train_bias
        self.lora_rank = lora_rank

    def convert_to_lora(self) -> None:
        """Apply the LoRA conversion to this module using the stored settings."""
        rank, bias_mode = self.lora_rank, self.lora_train_bias
        convert_to_lora_module(self, rank, bias_mode)
from typing import Optional
import torch
import torch.nn as nn
from .utils import masked_mean
class GPTLMLoss(nn.Module):
    """
    Next-token cross-entropy loss for GPT-style language models.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Position t predicts token t + 1: drop the last logit and the first label.
        predictions = logits[..., :-1, :]
        targets = labels[..., 1:]
        vocab_size = predictions.size(-1)
        # Flatten the tokens
        flat_predictions = predictions.reshape(-1, vocab_size)
        flat_targets = targets.reshape(-1)
        return self.loss(flat_predictions, flat_targets)
class PolicyLoss(nn.Module):
    """
    Clipped surrogate policy loss for PPO.

    Args:
        clip_eps (float): the probability ratio is clamped to ``[1 - clip_eps, 1 + clip_eps]``.
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        lower, upper = 1 - self.clip_eps, 1 + self.clip_eps
        ratio = torch.exp(log_probs - old_log_probs)

        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, lower, upper) * advantages
        # Keep the more pessimistic of the two objectives and turn it into a loss.
        loss = torch.min(unclipped, clipped).neg()

        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        return loss.mean()
class ValueLoss(nn.Module):
    """
    Clipped value loss for PPO.

    Args:
        clip_eps (float): maximum allowed change of the value estimate relative to ``old_values``.
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # NOTE: ``action_mask`` is accepted but not used by this loss.
        delta = torch.clamp(values - old_values, -self.clip_eps, self.clip_eps)
        clipped_error = (old_values + delta - reward)**2
        raw_error = (values - reward)**2
        # Penalise whichever of the clipped / unclipped squared errors is larger.
        worst_error = torch.max(clipped_error, raw_error)
        return 0.5 * worst_error.mean()
class PPOPtxActorLoss(nn.Module):
    """
    PPO-ptx actor loss: the PPO policy loss plus a weighted pretraining (LM) loss.

    Args:
        policy_clip_eps (float): clip range passed to ``PolicyLoss``.
        pretrain_coef (float): weight of the pretraining loss term.
        pretrain_loss_fn: loss called as ``fn(lm_logits, lm_input_ids)``. When None,
            a fresh ``GPTLMLoss`` is created for this instance.
    """

    def __init__(self,
                 policy_clip_eps: float = 0.2,
                 pretrain_coef: float = 0.0,
                 pretrain_loss_fn: Optional[nn.Module] = None) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        # Build the default per instance rather than as a default argument, which
        # would be evaluated once at definition time and shared by every instance.
        self.pretrain_loss_fn = GPTLMLoss() if pretrain_loss_fn is None else pretrain_loss_fn

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss
class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155

    Computes ``-mean(log(sigmoid(chosen_reward - reject_reward)))``.
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # logsigmoid is evaluated in a numerically stable way; log(sigmoid(x))
        # underflows to -inf for strongly negative margins.
        log_probs = nn.functional.logsigmoid(chosen_reward - reject_reward)
        return -log_probs.mean()
class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862

    Computes ``mean(log(1 + exp(reject_reward - chosen_reward)))``.
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # softplus(x) == log(1 + exp(x)) but does not overflow to inf for large x.
        return nn.functional.softplus(reject_reward - chosen_reward).mean()
from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
from typing import Optional
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from ..base import Actor
class OPTActor(Actor):
    """
    Actor built on an OPT causal language model.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (OPTConfig): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            lm = OPTForCausalLM(OPTConfig() if config is None else config)
        else:
            lm = OPTForCausalLM.from_pretrained(pretrained)

        if checkpoint:
            lm.gradient_checkpointing_enable()

        super().__init__(lm, lora_rank, lora_train_bias)
from typing import Optional
import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel
from ..base import Critic
class OPTCritic(Critic):
    """
    Critic built on an OPT backbone with a scalar value head.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (OPTConfig): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
        **kwargs: Extra keyword arguments forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            backbone = OPTModel(OPTConfig() if config is None else config)
        else:
            backbone = OPTModel.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        # The head input width is the config's word_embed_proj_dim.
        head = nn.Linear(backbone.config.word_embed_proj_dim, 1)
        super().__init__(backbone, head, lora_rank, lora_train_bias, **kwargs)
from typing import Optional
import torch.nn as nn
from transformers import OPTConfig, OPTModel
from ..base import RewardModel
class OPTRM(RewardModel):
    """
    Reward model built on an OPT backbone with a scalar value head.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (OPTConfig): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            backbone = OPTModel(OPTConfig() if config is None else config)
        else:
            backbone = OPTModel.from_pretrained(pretrained)

        if checkpoint:
            backbone.gradient_checkpointing_enable()

        head = nn.Linear(backbone.config.word_embed_proj_dim, 1)
        # Re-initialise the head weights from N(0, 1 / (word_embed_proj_dim + 1)).
        head.weight.data.normal_(mean=0.0, std=1 / (backbone.config.word_embed_proj_dim + 1))
        super().__init__(backbone, head, lora_rank, lora_train_bias)
from .roberta_actor import RoBERTaActor
from .roberta_critic import RoBERTaCritic
from .roberta_rm import RoBERTaRM
__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM']
\ No newline at end of file
from typing import Optional
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaForCausalLM
from ..base import Actor
class RoBERTaActor(Actor):
    """
    Actor built on a RoBERTa causal language model.

    Args:
        pretrained (str): Pretrained model name or path. Takes precedence over ``config``.
        config (RobertaConfig): Model config, used only when ``pretrained`` is None.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation, forwarded to the base class.
        lora_train_bias (str): LoRA bias training mode, forwarded to the base class.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[RobertaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is None:
            # No weights requested: build from the given config, or a default one.
            lm = RobertaForCausalLM(RobertaConfig() if config is None else config)
        else:
            lm = RobertaForCausalLM.from_pretrained(pretrained)

        if checkpoint:
            lm.gradient_checkpointing_enable()

        super().__init__(lm, lora_rank, lora_train_bias)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment