Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -56,8 +56,7 @@ class ChatGLMConfig(PretrainedConfig): ...@@ -56,8 +56,7 @@ class ChatGLMConfig(PretrainedConfig):
>>> # Accessing the model configuration >>> # Accessing the model configuration
>>> configuration = model.config >>> configuration = model.config
``` ```"""
"""
model_type = "chatglm" model_type = "chatglm"
def __init__( def __init__(
...@@ -79,7 +78,7 @@ class ChatGLMConfig(PretrainedConfig): ...@@ -79,7 +78,7 @@ class ChatGLMConfig(PretrainedConfig):
quantization_bit=0, quantization_bit=0,
pre_seq_len=None, pre_seq_len=None,
prefix_projection=False, prefix_projection=False,
**kwargs **kwargs,
): ):
self.num_layers = num_layers self.num_layers = num_layers
self.vocab_size = vocab_size self.vocab_size = vocab_size
...@@ -99,9 +98,4 @@ class ChatGLMConfig(PretrainedConfig): ...@@ -99,9 +98,4 @@ class ChatGLMConfig(PretrainedConfig):
self.pre_seq_len = pre_seq_len self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection self.prefix_projection = prefix_projection
super().__init__( super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs
)
\ No newline at end of file
...@@ -16,9 +16,9 @@ except ImportError: ...@@ -16,9 +16,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def _prepare_logits_processor(top_k: Optional[int] = None, def _prepare_logits_processor(
top_p: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
temperature: Optional[float] = None) -> LogitsProcessorList: ) -> LogitsProcessorList:
processor_list = LogitsProcessorList() processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0: if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature)) processor_list.append(TemperatureLogitsWarper(temperature))
...@@ -37,7 +37,8 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: ...@@ -37,7 +37,8 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0 return unfinished_sequences.max() == 0
def _sample(model: Actor, def _sample(
model: Actor,
input_ids: torch.Tensor, input_ids: torch.Tensor,
max_length: int, max_length: int,
early_stopping: bool = False, early_stopping: bool = False,
...@@ -48,7 +49,8 @@ def _sample(model: Actor, ...@@ -48,7 +49,8 @@ def _sample(model: Actor,
temperature: Optional[float] = None, temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor: **model_kwargs,
) -> torch.Tensor:
if input_ids.size(1) >= max_length: if input_ids.size(1) >= max_length:
return input_ids return input_ids
...@@ -56,11 +58,12 @@ def _sample(model: Actor, ...@@ -56,11 +58,12 @@ def _sample(model: Actor,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length): for _ in range(input_ids.size(1), max_length):
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \ model_inputs = (
if prepare_inputs_fn is not None else {'input_ids': input_ids} prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
)
outputs = model(**model_inputs) outputs = model(**model_inputs)
next_token_logits = outputs['logits'][:, -1, :] next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution # pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits) next_token_logits = logits_processor(input_ids, next_token_logits)
# sample # sample
...@@ -90,7 +93,8 @@ def _sample(model: Actor, ...@@ -90,7 +93,8 @@ def _sample(model: Actor,
@torch.no_grad() @torch.no_grad()
def generate(model: Actor, def generate(
model: Actor,
input_ids: torch.Tensor, input_ids: torch.Tensor,
max_length: int, max_length: int,
num_beams: int = 1, num_beams: int = 1,
...@@ -103,7 +107,8 @@ def generate(model: Actor, ...@@ -103,7 +107,8 @@ def generate(model: Actor,
temperature: Optional[float] = None, temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor: **model_kwargs,
) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens. """Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args: Args:
...@@ -121,15 +126,16 @@ def generate(model: Actor, ...@@ -121,15 +126,16 @@ def generate(model: Actor,
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
""" """
is_greedy_gen_mode = ((num_beams == 1) and do_sample is False) is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = ((num_beams == 1) and do_sample is True) is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = ((num_beams > 1) and do_sample is False) is_beam_gen_mode = (num_beams > 1) and do_sample is False
if is_greedy_gen_mode: if is_greedy_gen_mode:
# run greedy search # run greedy search
raise NotImplementedError raise NotImplementedError
elif is_sample_gen_mode: elif is_sample_gen_mode:
# run sample # run sample
return _sample(model, return _sample(
model,
input_ids, input_ids,
max_length, max_length,
early_stopping=early_stopping, early_stopping=early_stopping,
...@@ -140,7 +146,8 @@ def generate(model: Actor, ...@@ -140,7 +146,8 @@ def generate(model: Actor,
temperature=temperature, temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn, prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn, update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs) **model_kwargs,
)
elif is_beam_gen_mode: elif is_beam_gen_mode:
raise NotImplementedError raise NotImplementedError
else: else:
......
...@@ -2,4 +2,4 @@ from .gpt_actor import GPTActor ...@@ -2,4 +2,4 @@ from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM from .gpt_rm import GPTRM
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM'] __all__ = ["GPTActor", "GPTCritic", "GPTRM"]
...@@ -18,13 +18,15 @@ class GPTActor(Actor): ...@@ -18,13 +18,15 @@ class GPTActor(Actor):
lora_train_bias (str): Bias training strategy for the LoRa layer. lora_train_bias (str): Bias training strategy for the LoRa layer.
""" """
def __init__(self, def __init__(
self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None, config: Optional[GPT2Config] = None,
checkpoint: bool = False, checkpoint: bool = False,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none', lora_train_bias: str = "none",
**kwargs) -> None: **kwargs,
) -> None:
if pretrained is not None: if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained) model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -18,12 +18,14 @@ class GPTCritic(Critic): ...@@ -18,12 +18,14 @@ class GPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None, config: Optional[GPT2Config] = None,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none', lora_train_bias: str = "none",
**kwargs) -> None: **kwargs,
) -> None:
if pretrained is not None: if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained) model = GPT2Model.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -18,11 +18,13 @@ class GPTRM(RewardModel): ...@@ -18,11 +18,13 @@ class GPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None, config: Optional[GPT2Config] = None,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none') -> None: lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained) model = GPT2Model.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -2,4 +2,4 @@ from .llama_actor import LlamaActor ...@@ -2,4 +2,4 @@ from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM from .llama_rm import LlamaRM
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM'] __all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"]
from typing import Optional from typing import Optional
import torch from transformers import LlamaConfig, LlamaForCausalLM
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from ..base import Actor from ..base import Actor
...@@ -18,13 +17,14 @@ class LlamaActor(Actor): ...@@ -18,13 +17,14 @@ class LlamaActor(Actor):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None, config: Optional[LlamaConfig] = None,
checkpoint: bool = False, checkpoint: bool = False,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none') -> None: lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained) model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -17,13 +17,14 @@ class LlamaCritic(Critic): ...@@ -17,13 +17,14 @@ class LlamaCritic(Critic):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None, config: Optional[LlamaConfig] = None,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none', lora_train_bias: str = "none",
**kwargs) -> None: **kwargs,
) -> None:
if pretrained is not None: if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained) model = LlamaModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
from typing import Optional from typing import Optional
import torch.nn as nn import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel from transformers import LlamaConfig, LlamaModel
from ..base import RewardModel from ..base import RewardModel
...@@ -17,12 +17,13 @@ class LlamaRM(RewardModel): ...@@ -17,12 +17,13 @@ class LlamaRM(RewardModel):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None, config: Optional[LlamaConfig] = None,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none') -> None: lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained) model = LlamaModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -8,8 +8,7 @@ import torch.nn.functional as F ...@@ -8,8 +8,7 @@ import torch.nn.functional as F
class LoraLinear(lora.LoRALayer, nn.Module): class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear. """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
"""
def __init__( def __init__(
self, self,
...@@ -17,16 +16,14 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -17,16 +16,14 @@ class LoraLinear(lora.LoRALayer, nn.Module):
bias: Optional[nn.Parameter], bias: Optional[nn.Parameter],
r: int = 0, r: int = 0,
lora_alpha: int = 1, lora_alpha: int = 1,
lora_dropout: float = 0., lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True, merge_weights: bool = True,
): ):
nn.Module.__init__(self) nn.Module.__init__(self)
lora.LoRALayer.__init__(self, lora.LoRALayer.__init__(
r=r, self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
lora_alpha=lora_alpha, )
lora_dropout=lora_dropout,
merge_weights=merge_weights)
self.weight = weight self.weight = weight
self.bias = bias self.bias = bias
...@@ -47,13 +44,12 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -47,13 +44,12 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.weight.data = self.weight.data.T self.weight.data = self.weight.data.T
def reset_parameters(self): def reset_parameters(self):
if hasattr(self, 'lora_A'): if hasattr(self, "lora_A"):
# Initialize A with the default values for nn.Linear and set B to zero. # Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B) nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True): def train(self, mode: bool = True):
def T(w): def T(w):
return w.T if self.fan_in_fan_out else w return w.T if self.fan_in_fan_out else w
...@@ -71,7 +67,6 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -71,7 +67,6 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.merged = False self.merged = False
def eval(self): def eval(self):
def T(w): def T(w):
return w.T if self.fan_in_fan_out else w return w.T if self.fan_in_fan_out else w
...@@ -80,12 +75,11 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -80,12 +75,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
# Merge the weights and mark it # Merge the weights and mark it
if self.r > 0: if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, 'lora_A') delattr(self, "lora_A")
delattr(self, 'lora_B') delattr(self, "lora_B")
self.merged = True self.merged = True
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
def T(w): def T(w):
return w.T if self.fan_in_fan_out else w return w.T if self.fan_in_fan_out else w
...@@ -99,7 +93,9 @@ class LoraLinear(lora.LoRALayer, nn.Module): ...@@ -99,7 +93,9 @@ class LoraLinear(lora.LoRALayer, nn.Module):
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})' assert (
lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear return lora_linear
...@@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: ...@@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
_convert_to_lora_recursively(child, lora_rank) _convert_to_lora_recursively(child, lora_rank)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module: def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module. """Convert a torch.nn.Module to a LoRA module.
Args: Args:
...@@ -140,7 +136,7 @@ class LoRAModule(nn.Module): ...@@ -140,7 +136,7 @@ class LoRAModule(nn.Module):
Defaults to 'none'. Defaults to 'none'.
""" """
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__() super().__init__()
self.lora_rank = lora_rank self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias self.lora_train_bias = lora_train_bias
......
...@@ -31,11 +31,13 @@ class PolicyLoss(nn.Module): ...@@ -31,11 +31,13 @@ class PolicyLoss(nn.Module):
super().__init__() super().__init__()
self.clip_eps = clip_eps self.clip_eps = clip_eps
def forward(self, def forward(
self,
log_probs: torch.Tensor, log_probs: torch.Tensor,
old_log_probs: torch.Tensor, old_log_probs: torch.Tensor,
advantages: torch.Tensor, advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp() ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
...@@ -55,14 +57,16 @@ class ValueLoss(nn.Module): ...@@ -55,14 +57,16 @@ class ValueLoss(nn.Module):
super().__init__() super().__init__()
self.clip_eps = clip_eps self.clip_eps = clip_eps
def forward(self, def forward(
self,
values: torch.Tensor, values: torch.Tensor,
old_values: torch.Tensor, old_values: torch.Tensor,
reward: torch.Tensor, reward: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - reward)**2 surr1 = (values_clipped - reward) ** 2
surr2 = (values - reward)**2 surr2 = (values - reward) ** 2
loss = torch.max(surr1, surr2) loss = torch.max(surr1, surr2)
loss = loss.mean() loss = loss.mean()
return 0.5 * loss return 0.5 * loss
......
...@@ -2,4 +2,4 @@ from .opt_actor import OPTActor ...@@ -2,4 +2,4 @@ from .opt_actor import OPTActor
from .opt_critic import OPTCritic from .opt_critic import OPTCritic
from .opt_rm import OPTRM from .opt_rm import OPTRM
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM'] __all__ = ["OPTActor", "OPTCritic", "OPTRM"]
...@@ -18,12 +18,14 @@ class OPTActor(Actor): ...@@ -18,12 +18,14 @@ class OPTActor(Actor):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None, config: Optional[OPTConfig] = None,
checkpoint: bool = False, checkpoint: bool = False,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none') -> None: lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained) model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -18,12 +18,14 @@ class OPTCritic(Critic): ...@@ -18,12 +18,14 @@ class OPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None, config: Optional[OPTConfig] = None,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none', lora_train_bias: str = "none",
**kwargs) -> None: **kwargs,
) -> None:
if pretrained is not None: if pretrained is not None:
model = OPTModel.from_pretrained(pretrained) model = OPTModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -17,11 +17,13 @@ class OPTRM(RewardModel): ...@@ -17,11 +17,13 @@ class OPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode. lora_train_bias (str): LoRA bias training mode.
""" """
def __init__(self, def __init__(
self,
pretrained: Optional[str] = None, pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None, config: Optional[OPTConfig] = None,
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none') -> None: lora_train_bias: str = "none",
) -> None:
if pretrained is not None: if pretrained is not None:
model = OPTModel.from_pretrained(pretrained) model = OPTModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
......
...@@ -4,9 +4,9 @@ import torch ...@@ -4,9 +4,9 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
def _compute_approx_kl(log_probs: torch.Tensor, def _compute_approx_kl(
log_probs_base: torch.Tensor, log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: ) -> torch.Tensor:
""" """
Compute the approximate KL divergence between two distributions. Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html Schulman blog: http://joschu.net/blog/kl-approx.html
...@@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor, ...@@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor,
return approx_kl return approx_kl
def compute_reward(r: Union[torch.Tensor, float], def compute_reward(
r: Union[torch.Tensor, float],
kl_coef: float, kl_coef: float,
log_probs: torch.Tensor, log_probs: torch.Tensor,
log_probs_base: torch.Tensor, log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if kl_coef <= 0.0: if kl_coef <= 0.0:
return r return r
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
...@@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num ...@@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
Returns: Returns:
torch.Tensor: Action log probs. torch.Tensor: Action log probs.
""" """
logits = output['logits'] logits = output["logits"]
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:] return log_probs[:, -num_actions:]
......
...@@ -2,6 +2,6 @@ from .llama_gptq import load_quant as llama_load_quant ...@@ -2,6 +2,6 @@ from .llama_gptq import load_quant as llama_load_quant
from .utils import low_resource_init from .utils import low_resource_init
__all__ = [ __all__ = [
'llama_load_quant', "llama_load_quant",
'low_resource_init', "low_resource_init",
] ]
from .loader import load_quant from .loader import load_quant
__all__ = [ __all__ = [
'load_quant', "load_quant",
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment