Commit d5878167 authored by mashun1

llava-next
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"steps_per_print": 1e5,
"wall_clock_breakdown": false
}
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"zero_quantized_weights": true,
"zero_hpz_partition_size": 16,
"zero_quantized_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
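
# --- Illustrative sketch (not part of the original configs) ---
# The four JSON files above are DeepSpeed configs: ZeRO-2 with CPU offload, ZeRO-3 without
# offload, ZeRO-3 with CPU offload, and ZeRO-3 with (ZeRO++-style) quantized weights/gradients.
# The "auto" fields are resolved by the Hugging Face Trainer at runtime. A minimal way to wire
# one of them into a training run, assuming it is saved as "ds_config_zero3.json" (placeholder
# file name):
def _example_deepspeed_training_args():
    from transformers import TrainingArguments

    return TrainingArguments(
        output_dir="outputs",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        bf16=True,
        deepspeed="ds_config_zero3.json",  # path to one of the configs above
    )
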
# flake8: noqa
__version__ = "0.7.11.dev0"
from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import (
is_bitsandbytes_available,
is_diffusers_available,
is_npu_available,
is_peft_available,
is_wandb_available,
is_xpu_available,
)
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
PreTrainedModelWrapper,
create_reference_model,
setup_chat_format,
)
from .trainer import (
DataCollatorForCompletionOnlyLM,
DPOTrainer,
IterativeSFTTrainer,
ModelConfig,
PPOConfig,
PPOTrainer,
RewardConfig,
RewardTrainer,
SFTTrainer,
)
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
if is_diffusers_available():
from .models import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)
from .trainer import DDPOConfig, DDPOTrainer
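
# Illustrative usage sketch (not part of the package): the names re-exported above form the
# public TRL API. "gpt2" is an assumed placeholder checkpoint.
def _example_trl_public_api():
    from trl import AutoModelForCausalLMWithValueHead, set_seed

    set_seed(0)
    # A causal LM with an extra value head, as used for PPO-style training.
    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    return model
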
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import warnings
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
# `top_k_top_p_filtering` is no longer importable from transformers, so it is re-implemented
# below using the logits warpers that replaced it.
from transformers import TopKLogitsWarper, TopPLogitsWarper
from .import_utils import is_npu_available, is_xpu_available
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
WANDB_PADDING = -1
def top_k_top_p_filtering(
logits: torch.FloatTensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
Args:
logits: logits distribution shape (batch size, vocabulary size)
top_k (`int`, *optional*, defaults to 0):
If > 0, only keep the top k tokens with highest probability (top-k filtering)
top_p (`float`, *optional*, defaults to 1.0):
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
if top_k > 0:
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(None, logits)
if 0 <= top_p <= 1.0:
logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(None, logits)
return logits
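
# Example sketch: filter a batch of logits with top-k and nucleus (top-p) filtering before
# sampling. Shapes and thresholds below are arbitrary illustration values.
def _example_top_k_top_p_filtering():
    logits = torch.randn(2, 50257)  # (batch_size, vocab_size)
    filtered = top_k_top_p_filtering(logits.clone(), top_k=50, top_p=0.95)
    probs = F.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # one sampled token id per row
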
def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
"""Flatten dictionary and concatenate nested keys with separator."""
def recurse(nest: Dict, prefix: str, into: Dict) -> None:
for k, v in nest.items():
if sep in k:
raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
if isinstance(v, Mapping):
recurse(v, prefix + k + sep, into)
else:
into[prefix + k] = v
flat = {}
recurse(nested, "", flat)
return flat
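
# Example sketch: nested stats are flattened into "section/metric"-style keys, which is the
# form expected by the logging helpers below.
def _example_flatten_dict():
    nested = {"ppo": {"loss": {"policy": 0.1, "value": 0.2}}, "objective": {"kl": 0.01}}
    return flatten_dict(nested)
    # -> {"ppo/loss/policy": 0.1, "ppo/loss/value": 0.2, "objective/kl": 0.01}
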
def convert_to_scalar(stats: Dict) -> Dict:
"""
    Converts single-element tensors and arrays in a flattened stats dict to Python scalars (for TensorBoard compatibility).
"""
tensorboard_stats = {}
for k, v in stats.items():
# for tensorboard compatibility - arrays and tensors are ignored with tensorboard
# therefore we convert single element tensors to scalars
if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)):
v = v.item()
tensorboard_stats[k] = v
return tensorboard_stats
def stack_dicts(stats_dicts: List[Dict]) -> Dict:
"""Stack the values of a dict."""
results = dict()
for k in stats_dicts[0]:
stats_list = [torch.flatten(d[k]) for d in stats_dicts]
results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
return results
def add_suffix(input_dict: Dict, suffix: str) -> Dict:
"""Add suffix to dict keys."""
return dict((k + suffix, v) for k, v in input_dict.items())
def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor:
"""Pad tensor to size."""
t_size = tensor.size()[dim]
if t_size == size:
return tensor
else:
return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding)
def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
"""
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
"""
logp = F.log_softmax(logits, dim=2)
if not gather:
return logp
logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
return logpy
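
# Example sketch: per-token log-probabilities of the observed labels, as used when computing
# policy log-probs for PPO. Shapes are arbitrary illustration values.
def _example_logprobs_from_logits():
    logits = torch.randn(2, 5, 100)          # (batch, seq_len, vocab)
    labels = torch.randint(0, 100, (2, 5))   # (batch, seq_len)
    return logprobs_from_logits(logits, labels)  # (batch, seq_len)
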
def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
"""Whiten values."""
mean, var = torch.mean(values), torch.var(values)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
    """Compute the mean of a tensor over the positions selected by a mask."""
if axis is not None:
return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
else:
return (values * mask).sum() / mask.sum()
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
"""Compute variance of tensor with masked values."""
mean = masked_mean(values, mask)
centered_values = values - mean
variance = masked_mean(centered_values**2, mask)
if unbiased:
mask_sum = mask.sum()
if mask_sum == 0:
            raise ValueError("The sum of the mask is zero, which can happen when `mini_batch_size=1`; try increasing the `mini_batch_size` or `gradient_accumulation_steps`.")
# note that if mask_sum == 1, then there is a division by zero issue
# to avoid it you just need to use a larger minibatch_size
bessel_correction = mask_sum / (mask_sum - 1)
variance = variance * bessel_correction
return variance
def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
"""Whiten values with masked values."""
mean, var = masked_mean(values, mask), masked_var(values, mask)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
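
# Example sketch: masked statistics over a padded batch, where the mask is 1 for real tokens
# and 0 for padding.
def _example_masked_stats():
    values = torch.tensor([[1.0, 2.0, 0.0], [3.0, 4.0, 5.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])
    return masked_mean(values, mask), masked_var(values, mask), masked_whiten(values, mask)
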
def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
"""
Tensor extension to torch.clamp
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
"""
clipped = torch.max(torch.min(x, tensor_max), tensor_min)
return clipped
def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
"""Calculate entropy from logits."""
pd = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
return entropy
def average_torch_dicts(list_of_dicts: List[Dict]) -> Dict:
"""Average values of a list of dicts with torch tensors."""
average_dict = dict()
for key in list_of_dicts[0].keys():
average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0)
return average_dict
def stats_to_np(stats_dict: Dict) -> Dict:
"""Cast all torch.tensors in dict to numpy arrays."""
new_dict = dict()
for k, v in stats_dict.items():
if isinstance(v, torch.Tensor):
new_dict[k] = v.detach().cpu()
if new_dict[k].dtype == torch.bfloat16:
new_dict[k] = new_dict[k].float()
new_dict[k] = new_dict[k].numpy()
else:
new_dict[k] = v
if np.isscalar(new_dict[k]):
new_dict[k] = float(new_dict[k])
return new_dict
def respond_to_batch(model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0) -> torch.LongTensor:
"""Sample text from language model."""
input_ids = queries
for i in range(txt_len):
# Get Logits
outputs = model(input_ids)
next_token_logits = outputs[0][:, -1, :]
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
# Sample
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
return input_ids[:, -txt_len:]
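
# Usage sketch ("gpt2" is an assumed placeholder checkpoint): sample a short continuation for
# a single prompt with respond_to_batch.
def _example_respond_to_batch():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    query = tokenizer("The weather today is", return_tensors="pt").input_ids
    response = respond_to_batch(model, query, txt_len=10, top_k=50, top_p=0.95)
    return tokenizer.decode(response[0])
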
def set_seed(seed: int) -> None:
"""
Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
Args:
seed (`int`): The seed to set.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if is_xpu_available():
torch.xpu.manual_seed_all(seed)
elif is_npu_available():
torch.npu.manual_seed_all(seed)
else:
torch.cuda.manual_seed_all(seed)
class LengthSampler:
"""
Samples a length
"""
def __init__(self, min_value: int, max_value: int):
self.values = list(range(min_value, max_value))
def __call__(self) -> int:
return np.random.choice(self.values)
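
# Example sketch: LengthSampler draws a generation length uniformly from [min_value, max_value).
def _example_length_sampler():
    sampler = LengthSampler(4, 16)
    return [sampler() for _ in range(3)]
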
class PPODecorators(object):
optimize_device_cache = False
@classmethod
@contextmanager
def empty_device_cache(cls):
yield
if cls.optimize_device_cache:
if is_xpu_available():
gc.collect()
torch.xpu.empty_cache()
gc.collect()
elif is_npu_available():
gc.collect()
torch.npu.empty_cache()
gc.collect()
elif torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
gc.collect()
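
# Example sketch: when optimize_device_cache is enabled, the context manager frees the device
# cache after the wrapped block finishes (it is a no-op otherwise).
def _example_empty_device_cache():
    PPODecorators.optimize_device_cache = True
    with PPODecorators.empty_device_cache():
        x = torch.randn(8, 8)
        y = x @ x
    return y
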
def randn_tensor(
shape: Union[Tuple, List],
generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
) -> torch.Tensor:
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch element individually. If CPU generators are passed, the tensor
is always created on the CPU.
"""
# device on which tensor is created defaults to device
rand_device = device
batch_size = shape[0]
layout = layout or torch.strided
device = device or torch.device("cpu")
if generator is not None:
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
if gen_device_type != device.type and gen_device_type == "cpu":
rand_device = "cpu"
if device != "mps":
                warnings.warn(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
elif gen_device_type != device.type and gen_device_type == "cuda":
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
# make sure generator list of length 1 is treated like a non-list
if isinstance(generator, list) and len(generator) == 1:
generator = generator[0]
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size)]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
return latents
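
# Example sketch: per-sample seeding with a list of CPU generators; each batch element is drawn
# from its own generator and the result is assembled on the target device.
def _example_randn_tensor():
    generators = [torch.Generator().manual_seed(i) for i in range(4)]
    return randn_tensor((4, 3, 8, 8), generator=generators, dtype=torch.float32)
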
# flake8: noqa
from .base_environment import TextEnvironment, TextHistory
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import warnings
import torch
from accelerate.utils import extract_model_from_parallel
from transformers import StoppingCriteria, StoppingCriteriaList
from ..import_utils import is_rich_available
if is_rich_available():
from rich import print
from rich.text import Text
class StringStoppingCriteria(StoppingCriteria):
"""Custom `StoppingCriteria` which checks if all generations in the batch are completed."""
def __init__(self, stop_strings, tokenizer):
self.stop_strings = stop_strings
self.tokenizer = tokenizer
self.first_call = True
def __call__(self, input_ids, scores, **kwargs):
"""Returns true if all generated sequences contain any of the stop strings."""
if self.first_call:
self.generated_tokens = [1 for _ in range(input_ids.shape[0])]
self.start_length = input_ids.shape[-1] - 1
self.first_call = False
decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
done = []
for i, decoded_generation in enumerate(decoded_generations):
sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings])
done.append(sequence_complete)
if not sequence_complete:
self.generated_tokens[i] += 1
if all(done):
self.first_call = True
return all(done)
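
# Usage sketch ("gpt2" is an assumed placeholder checkpoint): stop generation once every
# sequence in the batch has emitted a <call> or <submit> marker.
def _example_string_stopping_criteria():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("2 + 2 = <request><Calculator>2 + 2", return_tensors="pt")
    stopping_criteria = StoppingCriteriaList([StringStoppingCriteria(["<call>", "<submit>"], tokenizer)])
    return model.generate(**inputs, max_new_tokens=16, stopping_criteria=stopping_criteria)
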
class TextHistory:
"""The TextHistory class keeps track of the history of an interaction between the language model and the environment."""
def __init__(self, text, tokens, system=True):
"""
Initialize TextHistory.
        Args:
text (`str`): The text of the first segment.
tokens (`torch.LongTensor`): The tokens of the first segment.
system (`bool`, *optional*): Whether the first segment is a system or user segment.
"""
self.system_spans = []
self.text_spans = []
self.token_spans = []
self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device)
self.text = ""
self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device)
self.completed = False
self.truncated = False
self.reward = 0.0
self.prompt_color = "black on grey85"
self.system_color = "black on cyan3"
self.model_color = "black on deep_sky_blue1"
self.reward_color = "black on plum1"
self.append_segment(text, tokens, system=system)
def append_segment(self, text, tokens, system=True):
"""
Append a new segment to the history.
        Args:
text (`str`): The text of the new segment.
tokens (`torch.LongTensor`): The tokens of the new segment.
system (`bool`, *optional*): Whether the new segment is a system or user segment.
"""
if len(text) == 0 or len(tokens) == 0:
raise ValueError("Can't append empty text or token list to history.")
original_text_length = len(self.text)
self.text += text
self.text_spans.append((original_text_length, len(self.text)))
self.system_spans.append(system)
original_token_length = len(self.tokens)
self.tokens = torch.cat((self.tokens, tokens))
if system:
self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens)))
else:
self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens)))
self.token_spans.append((original_token_length, len(self.tokens)))
def complete(self, truncated=False):
"""
Mark the history as completed.
"""
self.completed = True
self.truncated = truncated
@property
def last_text_segment(self):
"""
Get the last text segment.
"""
start, end = self.text_spans[-1]
return self.text[start:end]
def split_query_response_tokens(self):
"""
Split the tokens into query and response tokens.
"""
split_index = self.token_spans[0][1]
query = self.tokens[:split_index]
response = self.tokens[split_index:]
mask = self.token_masks[split_index:]
return query, response, mask
def show_text(self, show_legend=False):
"""
Print the text history.
"""
if not is_rich_available():
warnings.warn("install rich to display text")
return
text = Text(self.text)
text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0])
for i, (start, end) in enumerate(self.text_spans[1:]):
if self.system_spans[i + 1]:
text.stylize(self.system_color, start, end)
else:
text.stylize(self.model_color, start, end)
text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
print(text)
if show_legend:
self.show_colour_legend()
def show_tokens(self, tokenizer, show_legend=False):
"""
Print the history tokens.
"""
if not is_rich_available():
warnings.warn("install rich to display tokens")
return
text = Text()
prompt_end = self.token_spans[0][1]
for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)):
if i < prompt_end:
text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color)
text.append(" ")
elif mask == 0:
text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color)
text.append(" ")
else:
text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color)
text.append(" ")
text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
print(text)
if show_legend:
self.show_colour_legend()
def show_colour_legend(self):
"""
Print the colour legend.
"""
if not is_rich_available():
warnings.warn("install rich to display colour legend")
return
text = Text("\n\n(Colour Legend: ")
text.append("Prompt", style=self.prompt_color)
text.append("|")
text.append("System", style=self.system_color)
text.append("|")
text.append("Model", style=self.model_color)
text.append("|")
text.append("Reward", style=self.reward_color)
text.append(")")
print(text)
class TextEnvironment:
"""
    The TextEnvironment enables interaction of an LLM with an environment using tools.
"""
def __init__(
self,
model=None,
tokenizer=None,
tools=None,
reward_fn=None,
prompt=None,
max_turns=4,
        max_tool_response=100,
max_length=None,
generation_kwargs=None,
):
"""
Initialize TextEnvironment.
Args:
model (`PreTrainedModelWrapper`): The model to use for generation.
tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation.
tools (list): A list of tools to use for interaction.
reward_fn (function): A function that takes a string and returns a reward.
prompt (str): The base prompt to use for generation. Is prepended to the tasks.
max_turns (Optional[int]): The maximum number of turns to allow.
max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response.
max_length (Optional[int]): The maximum number of tokens to allow in an episode.
generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method.
"""
self.model = model
self.tokenizer = tokenizer
self.prompt = prompt
if isinstance(tools, dict):
self.tools = tools
else:
self.tools = dict([(tool.__class__.__name__, tool) for tool in tools])
self.reward_fn = reward_fn
self.max_length = max_length
self.request_token = "<request>"
self.call_token = "<call>"
self.response_token = "<response>"
self.submit_token = "<submit>"
self.max_turns = max_turns
        self.max_tool_response = max_tool_response
if generation_kwargs is None:
self.generation_kwargs = dict()
else:
self.generation_kwargs = generation_kwargs
self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
self.current_device = extract_model_from_parallel(self.model).pretrained_model.device
def run(self, queries, **rewards_kwargs):
"""
Run the environment on a list of queries.
Args:
queries (list[str]): A list of queries to run the model in the environment on.
"""
turns = 0
queries = [self.prompt + task for task in queries]
queries_tokens = [self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device) for query in queries]
histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)]
while any([not history.completed for history in histories]) and turns < self.max_turns:
histories = self.generate(histories)
histories = self.tasks_end_check(histories)
# TODO: make this parallel rather than for-loop
for i in range(len(histories)):
histories[i] = self.step(histories[i])
histories = self.tasks_end_check(histories, model_turn=False)
turns += 1
self.compute_reward(histories, **rewards_kwargs)
# convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively
queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories]))
rewards = [history.reward for history in histories]
return queries, responses, masks, rewards, histories
def step(self, history):
"""
Step the environment forward one turn.
Args:
history (`TextHistory`): The history to step forward.
"""
truncated, ended = self.task_end_check(history)
if ended:
history.complete(truncated=truncated)
if history.completed:
return history
tool, query = self.parse_tool_call(history.last_text_segment)
if tool is None or query is None:
response = f"Unknown tool call: {history.last_text_segment}"
else:
if tool not in self.tools:
                response = f"Unknown tool {tool}."
            else:
                try:
                    response = self.tools[tool](query)
                except Exception as error:
                    response = f"Tool error: {str(error)}"
if len(response) > self.max_tool_response:
response = response[: (self.max_tool_response - 3)] + "..."
history.append_segment(
response + self.response_token,
self.tokenizer(response + self.response_token, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device),
system=True,
)
return history
def parse_tool_call(self, text):
"""
Parse request string. Expected format: <request><tool_name>query<call>
"""
result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL)
# if we can't find a <request>/<call> span we return none
if result is None:
return None, None
else:
extracted_text = result.group()
result = re.search(r"<(.*?)>", extracted_text)
# if we can't find a tool name we return none
if result is None:
return None, None
else:
tool = result.group(1)
# split off the tool name
query = ">".join(extracted_text.split(">")[1:])
return tool, query
def compute_reward(self, histories, **reward_kwargs):
"""
Compute the reward for a list of histories.
"""
rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs)
for history, reward in zip(histories, rewards):
history.reward = reward
return histories
def generate(self, histories):
"""
Generate responses for a list of histories.
"""
active_histories = [i for i, history in enumerate(histories) if not history.completed]
query_tensors = [histories[i].tokens for i in active_histories]
response_tensors = self._generate_batched(query_tensors)
response_texts = self.tokenizer.batch_decode(response_tensors)
for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors):
histories[i].append_segment(response_text, response_tensor, system=False)
return histories
def tasks_end_check(self, histories, model_turn=True):
"""
Check if the current generation sequences have finished.
"""
for history in histories:
if not history.completed:
truncated, ended = self.task_end_check(history, model_turn=model_turn)
if ended:
history.complete(truncated=truncated)
return histories
def task_end_check(self, history, model_turn=True):
"""
Check if the current generation sequence has finished.
"""
truncated = False
ended = False
if history.completed:
return truncated, ended
        if self.max_length is not None and len(self.tokenizer(history.text).input_ids) > self.max_length:
truncated = True
ended = True
elif self.tokenizer.eos_token in history.text:
ended = True
elif model_turn and not ((self.request_token in history.last_text_segment and self.call_token in history.last_text_segment) or self.submit_token in history.last_text_segment):
ended = True
elif self.submit_token in history.last_text_segment:
ended = True
return truncated, ended
def _generate_batched(
self,
query_tensors,
batch_size: int = 16,
pad_to_multiple_of: int = None,
):
"""
Generate responses for a list of query tensors.
        Args:
query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for.
batch_size (int): The batch size to use for generation.
pad_to_multiple_of (int): The padding length to use for generation.
"""
outputs = []
padding_side_default = self.tokenizer.padding_side
if not self.is_encoder_decoder:
self.tokenizer.padding_side = "left"
# in case we have fewer examples than bs
batch_size = min(len(query_tensors), batch_size)
for i in range(0, len(query_tensors), batch_size):
# prevent overflow if query tensors are not even multiple of bs
end_index = min(len(query_tensors), i + batch_size)
batch = query_tensors[i:end_index]
batch_mask = [torch.ones_like(element) for element in batch]
inputs = {"input_ids": batch, "attention_mask": batch_mask}
padded_inputs = self.tokenizer.pad(
inputs,
padding=True,
max_length=None,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
).to(self.current_device)
stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer)
self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria])
generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs)
for generation, mask, generated_tokens in zip(generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens):
if not self.is_encoder_decoder:
output = generation[(1 - mask).sum() :] # remove padding
else:
output = generation
if not self.is_encoder_decoder:
output = output[(mask).sum() :] # remove prompt
# remove chunk generated after stopping criteria in batch mode
outputs.append(output[:generated_tokens])
self.tokenizer.padding_side = padding_side_default
return outputs
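
# Usage sketch (model checkpoint, tool, prompt, and reward function are illustrative
# placeholders, not part of the library): one rollout of the environment. Tools are invoked by
# the model through "<request><ToolName>query<call>" segments, and the reward function scores
# the final model segment of each episode.
def _example_text_environment():
    from transformers import AutoTokenizer

    from trl import AutoModelForCausalLMWithValueHead

    class Calculator:
        """Toy tool for illustration only."""

        def __call__(self, query: str) -> str:
            try:
                return str(eval(query, {"__builtins__": {}}, {}))
            except Exception:
                return "error"

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    env = TextEnvironment(
        model=model,
        tokenizer=tokenizer,
        tools=[Calculator()],
        reward_fn=lambda texts: [1.0 if "4" in text else 0.0 for text in texts],
        prompt="Use <request><Calculator>1+1<call> to call the calculator.\n",
        max_turns=2,
        generation_kwargs={"max_new_tokens": 32, "pad_token_id": tokenizer.eos_token_id},
    )
    queries, responses, masks, rewards, histories = env.run(["What is 2 + 2?"])
    return rewards
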
# flake8: noqa
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .best_of_n_sampler import BestOfNSampler
from typing import Any, Callable, List, Optional, Union
import torch
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from ..core import set_seed
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
class BestOfNSampler(object):
def __init__(
self,
model: PreTrainedModelWrapper,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
queries_to_scores: Callable[[List[str]], List[float]],
length_sampler: Any,
sample_size: int = 4,
seed: Optional[int] = None,
n_candidates: int = 1,
generation_config: Optional[GenerationConfig] = None,
) -> None:
r"""
Initialize the sampler for best-of-n generation
Args:
model (`PreTrainedModelWrapper`):
The pretrained model to use for generation
tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`):
Tokenizer associated with the pretrained model
queries_to_scores (`Callable[[List[str]], List[float]]`):
Callable that takes a list of generated texts and returns the associated reward scores
length_sampler (`Any`):
Sampler used to sample the length of the generated text
sample_size (`int`):
Number of samples to generate for each query
seed (`int`, *optional*):
Random seed used to control generation
n_candidates (`int`):
Number of candidates to return for each query
generation_config (`GenerationConfig`, *optional*):
Generation config passed to the underlying model's `generate` method.
See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details
"""
if seed is not None:
set_seed(seed)
if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
raise ValueError(f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}")
if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}")
self.model = model
self.tokenizer = tokenizer
self.queries_to_scores = queries_to_scores
self.length_sampler = length_sampler
self.gen_config = generation_config
self.sample_size = sample_size
self.n_candidates = n_candidates
def generate(
self,
tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]],
skip_special_tokens: bool = True,
device: Optional[Union[str, torch.device]] = None,
**generation_kwargs,
) -> List[List[str]]:
r"""
Generate the best of n samples for input queries
Args:
            tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[List[int]]`):
represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers)
skip_special_tokens (`bool`):
Whether to remove the special tokens from the output
device (`str` or `torch.device`, *optional*):
The device on which the model will be loaded
**generation_kwargs (`dict`, *optional*):
Additional keyword arguments passed along to the underlying model's `generate` method.
This is used to override generation config
Returns:
List[List[str]]: A list of lists of generated texts
"""
queries = None
if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1:
queries = tokenized_query.unsqueeze(0)
elif isinstance(tokenized_query, List):
element_type = type(tokenized_query[0])
if element_type == int:
queries = torch.tensor(tokenized_query).unsqueeze(0)
elif element_type == torch.Tensor:
queries = [tensor.reshape((1, -1)) for tensor in tokenized_query]
else:
queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query]
result = []
for query in queries:
queries = query.repeat((self.sample_size, 1))
output = self.model.generate(
queries.to(device),
max_new_tokens=self.length_sampler(),
generation_config=self.gen_config,
**generation_kwargs,
).squeeze()
output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens)
scores = torch.tensor(self.queries_to_scores(output))
output = [output[i] for i in scores.topk(self.n_candidates).indices]
result.append(output)
return result
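
# Usage sketch ("gpt2" and the length-based scorer are illustrative placeholders): generate
# `sample_size` candidates per query and keep the `n_candidates` highest-scoring ones.
def _example_best_of_n():
    from transformers import AutoTokenizer

    from trl import AutoModelForCausalLMWithValueHead
    from trl.core import LengthSampler

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    sampler = BestOfNSampler(
        model,
        tokenizer,
        queries_to_scores=lambda texts: [float(len(text)) for text in texts],  # toy scorer
        length_sampler=LengthSampler(8, 16),
        sample_size=4,
        n_candidates=1,
    )
    query_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids[0]
    return sampler.generate(query_ids)  # -> [[best decoded sample]]
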
import logging
from typing import Callable, Literal, Optional, Union
from datasets import Dataset, Value
from transformers import AutoTokenizer
from ..trainer.utils import ConstantLengthDataset
FORMAT_MAPPING = {
"chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
"instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
}
def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
r"""
    Return a callable that formats examples from a "messages"/"conversations"-style dataset by
    applying the tokenizer's chat template.
"""
def format_dataset(examples):
if isinstance(examples[messages_field][0], list):
output_texts = []
for i in range(len(examples[messages_field])):
output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
return output_texts
else:
return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)
return format_dataset
def instructions_formatting_function(tokenizer: AutoTokenizer):
r"""
    Return a callable that formats examples from a prompt/completion-style ("instruction") dataset by
    applying the tokenizer's chat template.
"""
def format_dataset(examples):
if isinstance(examples["prompt"], list):
output_texts = []
for i in range(len(examples["prompt"])):
converted_sample = [
{"role": "user", "content": examples["prompt"][i]},
{"role": "assistant", "content": examples["completion"][i]},
]
output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
return output_texts
else:
converted_sample = [
{"role": "user", "content": examples["prompt"]},
{"role": "assistant", "content": examples["completion"]},
]
return tokenizer.apply_chat_template(converted_sample, tokenize=False)
return format_dataset
def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]:
r"""
Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
- `ChatML` with [{"role": str, "content": str}]
- `instruction` with [{"prompt": str, "completion": str}]
Args:
dataset (Dataset): User dataset
tokenizer (AutoTokenizer): Tokenizer used for formatting
Returns:
Callable: Formatting function if the dataset format is supported else None
"""
if isinstance(dataset, Dataset):
if "messages" in dataset.features:
if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "messages")
if "conversations" in dataset.features:
if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "conversations")
elif dataset.features == FORMAT_MAPPING["instruction"]:
logging.info("Formatting dataset with instruction format")
return instructions_formatting_function(tokenizer)
return None
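
# Example sketch ("HuggingFaceH4/zephyr-7b-beta" is an assumed placeholder checkpoint that
# ships a chat template): detect the ChatML-style schema and format one batch with the
# tokenizer's chat template.
def _example_get_formatting_func():
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    dataset = Dataset.from_dict(
        {
            "messages": [
                [
                    {"role": "user", "content": "Hello"},
                    {"role": "assistant", "content": "Hi! How can I help?"},
                ]
            ]
        }
    )
    formatting_func = get_formatting_func_from_dataset(dataset, tokenizer)
    assert formatting_func is not None  # the schema matched FORMAT_MAPPING["chatml"]
    return formatting_func(dataset[:1])
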
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import sys
if sys.version_info < (3, 8):
_is_python_greater_3_8 = False
else:
_is_python_greater_3_8 = True
def is_peft_available() -> bool:
return importlib.util.find_spec("peft") is not None
def is_unsloth_available() -> bool:
return importlib.util.find_spec("unsloth") is not None
def is_accelerate_greater_20_0() -> bool:
if _is_python_greater_3_8:
from importlib.metadata import version
accelerate_version = version("accelerate")
else:
import pkg_resources
accelerate_version = pkg_resources.get_distribution("accelerate").version
return accelerate_version >= "0.20.0"
def is_transformers_greater_than(version: str) -> bool:
_transformers_version = importlib.metadata.version("transformers")
return _transformers_version > version
def is_torch_greater_2_0() -> bool:
if _is_python_greater_3_8:
from importlib.metadata import version
torch_version = version("torch")
else:
import pkg_resources
torch_version = pkg_resources.get_distribution("torch").version
return torch_version >= "2.0"
def is_diffusers_available() -> bool:
return importlib.util.find_spec("diffusers") is not None
def is_bitsandbytes_available() -> bool:
import torch
# bnb can be imported without GPU but is not usable.
return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available()
def is_torchvision_available() -> bool:
return importlib.util.find_spec("torchvision") is not None
def is_rich_available() -> bool:
return importlib.util.find_spec("rich") is not None
def is_wandb_available() -> bool:
return importlib.util.find_spec("wandb") is not None
def is_xpu_available() -> bool:
if is_accelerate_greater_20_0():
import accelerate
return accelerate.utils.is_xpu_available()
else:
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
return False
try:
import torch
return hasattr(torch, "xpu") and torch.xpu.is_available()
except RuntimeError:
return False
def is_npu_available() -> bool:
"""Checks if `torch_npu` is installed and potentially if a NPU is in the environment"""
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
return False
import torch
import torch_npu # noqa: F401
return hasattr(torch, "npu") and torch.npu.is_available()
# flake8: noqa
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_base import PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import setup_chat_format
SUPPORTED_ARCHITECTURES = (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
)
from ..import_utils import is_diffusers_available
if is_diffusers_available():
from .modeling_sd_base import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from copy import deepcopy
import torch
import torch.nn as nn
from accelerate import PartialState
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
HFValidationError,
LocalEntryNotFoundError,
RepositoryNotFoundError,
)
from safetensors.torch import load_file as safe_load_file
from transformers import PreTrainedModel
from ..import_utils import is_npu_available, is_peft_available, is_transformers_greater_than, is_xpu_available
if is_peft_available():
from peft import (
PeftConfig,
PeftModel,
PeftModelForCausalLM,
PeftModelForSeq2SeqLM,
PromptLearningConfig,
get_peft_model,
prepare_model_for_kbit_training,
)
if is_transformers_greater_than("4.33.0"):
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
else:
from transformers.deepspeed import is_deepspeed_zero3_enabled
LAYER_PATTERNS = [
"transformer.h.{layer}",
"model.decoder.layers.{layer}",
"gpt_neox.layers.{layer}",
"model.layers.{layer}",
]
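
# Usage sketch for the wrapper defined below ("gpt2" is an assumed placeholder checkpoint):
# `from_pretrained` accepts either a model name/path or an already-instantiated model, plus an
# optional `peft_config` to wrap the base model in a fresh (e.g. LoRA) adapter.
def _example_wrapper_from_pretrained():
    from peft import LoraConfig

    from trl import AutoModelForCausalLMWithValueHead

    peft_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16, lora_dropout=0.05)
    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2", peft_config=peft_config)
    return model
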
class PreTrainedModelWrapper(nn.Module):
r"""
    A wrapper class around a (`transformers.PreTrainedModel`) that keeps some attributes and
    methods of the (`~transformers.PreTrainedModel`) class so the wrapped model stays compatible
    with it.
Attributes:
pretrained_model: (`transformers.PreTrainedModel`)
The model to be wrapped.
parent_class: (`transformers.PreTrainedModel`)
The parent class of the model to be wrapped.
supported_args: (`list`)
The list of arguments that are supported by the wrapper class.
"""
transformers_parent_class = None
supported_args = None
supported_modules = ("v_head",)
supported_rm_modules = ("score",)
supported_pretrained_model_architectures = (PreTrainedModel) if not is_peft_available() else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM)
def __init__(self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs):
super().__init__()
self.pretrained_model = pretrained_model
self.config = pretrained_model.config
self.prepare_inputs_for_generation = pretrained_model.prepare_inputs_for_generation
self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False)
self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False)
self.is_sequential_parallel = False
if hasattr(pretrained_model, "gradient_checkpointing_disable"):
self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable
if hasattr(pretrained_model, "gradient_checkpointing_enable"):
self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable
self.supports_rm_adapter = supports_rm_adapter
self.rm_adapter_name = rm_adapter_name
self.policy_adapter_name = "default"
if score_module is not None:
self.score = score_module
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Instantiates a new model from a pretrained model from `transformers`. The
pretrained model is loaded using the `from_pretrained` method of the
`transformers.PreTrainedModel` class. The arguments that are specific to the
`transformers.PreTrainedModel` class are passed along this method and filtered
out from the `kwargs` argument.
Args:
pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`):
The path to the pretrained model or its name.
            *model_args (`list`, *optional*):
Additional positional arguments passed along to the underlying model's
`from_pretrained` method.
**kwargs (`dict`, *optional*):
Additional keyword arguments passed along to the underlying model's
`from_pretrained` method. We also pre-process the kwargs to extract
the arguments that are specific to the `transformers.PreTrainedModel`
class and the arguments that are specific to trl models. The kwargs
also support `prepare_model_for_kbit_training` arguments from
`peft` library.
"""
if kwargs is not None:
peft_config = kwargs.pop("peft_config", None)
reward_adapter = kwargs.pop("reward_adapter", None)
reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter")
is_trainable = kwargs.pop("is_trainable", False)
trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs)
token = pretrained_kwargs.get("token", None)
else:
peft_config = None
is_trainable = False
trl_model_args = {}
pretrained_kwargs = {}
peft_quantization_kwargs = {}
token = None
if reward_adapter is not None and not isinstance(reward_adapter, str):
raise ValueError("The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter.")
is_peft_model = False
current_device = cls._get_current_device()
if isinstance(pretrained_model_name_or_path, str):
is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
else:
is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False)
is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False)
if (is_loaded_in_8bit or is_loaded_in_4bit) and "device_map" not in pretrained_kwargs:
# warn users
            logging.warning(
                "The `device_map` argument is not provided. We will override the `device_map`"
                " argument to set the entire model on the current device. If you want to set the"
                " model on multiple devices, please provide a custom `device_map` argument."
            )
pretrained_kwargs["device_map"] = {"": current_device}
if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig):
raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.")
# First, load the pre-trained model using the parent-class
# either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM`
if isinstance(pretrained_model_name_or_path, str):
if is_peft_available():
try:
# If there is a trained peft adapter in the hub, load its config.
remote_adapter_config = hf_hub_download(
pretrained_model_name_or_path,
"adapter_config.json",
token=token,
)
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
remote_adapter_config = None
else:
remote_adapter_config = None
local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json"))
if (local_adapter_present or remote_adapter_config is not None) and is_peft_available():
if peft_config is not None:
logging.warning("`peft_config` argument ignored since a peft config file was found in " f"{pretrained_model_name_or_path}")
# Load the trained peft adapter config
if local_adapter_present:
trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path)
else:
remote_adapter_dir = os.path.dirname(remote_adapter_config)
trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir)
# Load the pretrained base model
pretrained_model = cls.transformers_parent_class.from_pretrained(trained_adapter_config.base_model_name_or_path, *model_args, **pretrained_kwargs)
# Wrap the pretrained model with the trained peft adapter
pretrained_model = PeftModel.from_pretrained(pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable)
logging.info("Trained peft adapter loaded")
else:
pretrained_model = cls.transformers_parent_class.from_pretrained(pretrained_model_name_or_path, *model_args, **pretrained_kwargs)
if peft_config is not None:
# Initialize a new peft adapter with the given config
if is_loaded_in_8bit or is_loaded_in_4bit:
pretrained_model = prepare_model_for_kbit_training(
pretrained_model,
**peft_quantization_kwargs,
)
pretrained_model = get_peft_model(pretrained_model, peft_config)
logging.info("peft adapter initialised")
elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures):
pretrained_model = pretrained_model_name_or_path
if peft_config is not None and isinstance(pretrained_model, PreTrainedModel):
# Initialize a new peft adapter with the given config
if is_loaded_in_8bit or is_loaded_in_4bit:
pretrained_model = prepare_model_for_kbit_training(
pretrained_model,
**peft_quantization_kwargs,
)
pretrained_model = get_peft_model(pretrained_model, peft_config)
logging.info("peft adapter initialised")
else:
raise ValueError("pretrained_model_name_or_path should be a string or a PreTrainedModel, " f"but is {type(pretrained_model_name_or_path)}")
if is_peft_available():
if isinstance(pretrained_model, PeftModel):
is_peft_model = True
# for backward compatibility
if hasattr(pretrained_model, "active_peft_config") and isinstance(pretrained_model.active_peft_config, PromptLearningConfig):
raise ValueError("PromptLearningConfig is not supported for PPO training.")
# Add reward modeling adapter if specified
if not is_peft_model and reward_adapter is not None:
raise ValueError("reward_adapter can only be used with a PeftModel. ")
elif is_peft_model and reward_adapter is not None:
score_module = cls.add_and_load_reward_modeling_adapter(pretrained_model, reward_adapter, reward_adapter_name, token=token)
multi_adapter_args = {
"score_module": score_module,
"supports_rm_adapter": True,
"rm_adapter_name": reward_adapter_name,
}
else:
multi_adapter_args = {"supports_rm_adapter": False}
# Then, create the full model by instantiating the wrapper class
model = cls(pretrained_model, **multi_adapter_args, **trl_model_args)
# if resume_training, load the state_dict again - this is ok since the
# state_dict is removed from the model after loading it.
is_resuming_training = True
if isinstance(pretrained_model_name_or_path, str):
safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors")
filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
is_sharded = False
use_safe = os.path.exists(safe_filename)
if not (os.path.exists(filename) or os.path.exists(safe_filename)):
# Try with `pytorch_model.bin`
filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub(
pretrained_model,
pretrained_model_name_or_path,
sharded_index_filename,
token=token,
)
# Try with safetensors
if filename is None and files_to_download is None:
safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub(
pretrained_model,
pretrained_model_name_or_path,
safe_sharded_index_filename,
token=token,
model_name="model.safetensors",
model_index_name="model.safetensors.index.json",
)
use_safe = True
else:
use_safe = False
loading_func = safe_load_file if use_safe else torch.load
load_kwargs = {} if use_safe else {"map_location": "cpu"}
if is_resuming_training:
if is_sharded:
# download each file and add it to the state_dict
state_dict = {}
for shard_file in files_to_download:
filename = hf_hub_download(
pretrained_model_name_or_path,
shard_file,
token=token,
)
state_dict.update(loading_func(filename, **load_kwargs))
else:
state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs)
else:
state_dict = pretrained_model_name_or_path.state_dict()
model.is_peft_model = is_peft_model
model.current_device = current_device
if is_resuming_training:
model.post_init(state_dict=state_dict)
return model
@classmethod
def _get_checkpoint_from_hub(
cls,
pretrained_model,
pretrained_model_name_or_path,
index_filename,
token=None,
model_name="pytorch_model.bin",
model_index_name="pytorch_model.bin.index.json",
):
files_to_download = None
filename = None
is_resuming_training = True
is_sharded = False
try:
filename = hf_hub_download(
pretrained_model_name_or_path,
model_name,
token=token,
)
# sharded
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
if os.path.exists(index_filename):
index_file_name = index_filename
else:
try:
index_file_name = hf_hub_download(
pretrained_model_name_or_path,
model_index_name,
token=token,
)
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError):
# not continue training, do not have v_head weight
is_resuming_training = False
logging.warning(f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " f"and no v_head weight is found. This IS expected if you are not resuming PPO training.")
# load json
if is_resuming_training:
with open(index_file_name, "r") as f:
index = json.load(f)
# check filename with `v_head` or any known extra module:
files_to_download = set()
for k, v in index["weight_map"].items():
if any([module in k for module in cls.supported_modules]):
files_to_download.add(v)
is_sharded = True
return filename, files_to_download, is_sharded, is_resuming_training
@classmethod
def _get_current_device(cls):
r"""
Get the current device. For GPU, we return the local process index using the `accelerate.PartialState`
object to handle corner cases when running scripts in distributed environments.
Returns:
current_device (`Union[int, str]`):
The current device.
"""
state = PartialState()
if is_xpu_available():
return f"xpu:{state.local_process_index}"
elif is_npu_available():
return f"npu:{state.local_process_index}"
else:
return state.local_process_index if torch.cuda.is_available() else "cpu"
@classmethod
def _split_kwargs(cls, kwargs):
"""
Separate the kwargs from the arguments that we support inside
`supported_args` and the ones that we don't.
"""
check_peft_kwargs = False
if is_peft_available():
from peft import prepare_model_for_kbit_training
check_peft_kwargs = True
supported_kwargs = {}
unsupported_kwargs = {}
peft_kwargs = {}
for key, value in kwargs.items():
if key in cls.supported_args:
supported_kwargs[key] = value
else:
unsupported_kwargs[key] = value
if check_peft_kwargs:
if key in prepare_model_for_kbit_training.__code__.co_varnames:
peft_kwargs[key] = value
if key in unsupported_kwargs:
unsupported_kwargs.pop(key)
return supported_kwargs, unsupported_kwargs, peft_kwargs
@classmethod
def add_and_load_reward_modeling_adapter(cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None):
r"""
Add and load a reward modeling adapter. This method can only be used if the
model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id`
        argument, pointing to the id of the reward modeling adapter. That adapter also needs to contain the
score head in order to produce the reward.
"""
pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False)
pretrained_model.train()
filename = os.path.join(adapter_model_id, "adapter_model.bin")
safe_loading = False
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
adapter_model_id,
"adapter_model.bin",
token=token,
)
except: # noqa
filename = os.path.join(adapter_model_id, "adapter_model.safetensors")
safe_loading = True
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(
adapter_model_id,
"adapter_model.safetensors",
token=token,
)
except: # noqa
raise ValueError("Could not find adapter model in the Hub, make sure you have the correct adapter model id.")
else:
local_filename = filename
else:
local_filename = filename
loading_func = safe_load_file if safe_loading else torch.load
load_kwargs = {} if safe_loading else {"map_location": "cpu"}
adapter_state_dict = loading_func(local_filename, **load_kwargs)
for score_name_candidate in cls.supported_rm_modules:
if any([score_name_candidate in name for name in adapter_state_dict.keys()]):
score_name = score_name_candidate
# we have found the correct head name and can break
break
score_dict = {}
for name, param in adapter_state_dict.items():
if score_name in name:
key_name = ".".join(name.split(".")[-1:])
score_dict[key_name] = param.to(cls._get_current_device())
num_labels, hidden_dim = score_dict["weight"].shape
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])
score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
device=cls._get_current_device(),
dtype=pretrained_model.dtype,
)
score.load_state_dict(score_dict)
for param in score.parameters():
param.requires_grad = False
return score
def push_to_hub(self, *args, **kwargs):
r"""
Push the pretrained model to the hub. This method is a wrapper around
`transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation
of `transformers.PreTrainedModel.push_to_hub` for more information.
Args:
*args (`list`, *optional*):
Positional arguments passed along to the underlying model's
`push_to_hub` method.
**kwargs (`dict`, *optional*):
Keyword arguments passed along to the underlying model's
`push_to_hub` method.
"""
raise NotImplementedError
def save_pretrained(self, *args, **kwargs):
r"""
Save the pretrained model to a directory. This method is a wrapper around
`transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation
of `transformers.PreTrainedModel.save_pretrained` for more information.
Args:
*args (`list`, *optional*):
Positional arguments passed along to the underlying model's
`save_pretrained` method.
**kwargs (`dict`, *optional*):
Keyword arguments passed along to the underlying model's
`save_pretrained` method.
"""
state_dict = kwargs.get("state_dict")
if state_dict is None:
state_dict = self.state_dict()
kwargs["state_dict"] = state_dict
# if it is a peft model only save the `v_head` state_dict and
# pop the `state_dict` from the kwargs to avoid silent bugs with `peft`
if self.is_peft_model:
save_path = args[0]
save_path = os.path.join(save_path, "pytorch_model.bin")
torch.save(state_dict, save_path)
_ = kwargs.pop("state_dict", None)
return self.pretrained_model.save_pretrained(*args, **kwargs)
def state_dict(self, *args, **kwargs):
r"""
Return the state_dict of the pretrained model.
"""
raise NotImplementedError
def post_init(self, *args, **kwargs):
r"""
Post initialization method. This method is called after the model is
instantiated and loaded from a checkpoint. It can be used to perform
additional operations such as loading the state_dict.
"""
raise NotImplementedError
def compute_reward_score(self, input_ids, attention_mask=None, **kwargs):
r"""
Computes the reward score for a given input. The method has first to enable the adapter
and then compute the reward score. After that the model disables the reward modeling
adapter and enables the default ppo adapter again.
"""
if not self.supports_rm_adapter:
raise ValueError("This model does not support reward modeling adapter.")
# enable rm adapter
self.pretrained_model.set_adapter(self.rm_adapter_name)
self.pretrained_model.eval()
with torch.no_grad():
base_model_output = self.pretrained_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
**kwargs,
)
last_hidden_states = base_model_output.hidden_states[-1]
scores = self.score(last_hidden_states)
self.pretrained_model.set_adapter(self.policy_adapter_name)
self.pretrained_model.eval()
return scores
def create_reference_model(model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None) -> PreTrainedModelWrapper:
"""
Creates a static reference copy of a model. Note that model will be in `.eval()` mode.
Args:
model (`PreTrainedModelWrapper`): The model to be copied.
num_shared_layers (`int`, *optional*): The number of initial layers that are shared between both models and kept frozen.
pattern (`str`, *optional*): The shared layers are selected with a string pattern
(e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here.
Returns:
`PreTrainedModelWrapper`
"""
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoCausalLM.from_pretrained()`.")
parameter_names = [n for n, _ in model.named_parameters()]
ref_model = deepcopy(model)
# if no layers are shared, return copy of model
if num_shared_layers is None:
for param_name in parameter_names:
param = ref_model.get_parameter(param_name)
param.requires_grad = False
return ref_model.eval()
# identify layer name pattern
if pattern is not None:
pattern = pattern.format(layer=num_shared_layers)
else:
for pattern_candidate in LAYER_PATTERNS:
pattern_candidate = pattern_candidate.format(layer=num_shared_layers)
if any([pattern_candidate in name for name in parameter_names]):
pattern = pattern_candidate
break
if pattern is None:
raise ValueError("Layer pattern could not be matched.")
# divide parameters in shared and unshared parameter lists
shared_param_list = []
unshared_param_list = []
shared_parameter = True
for name, param in model.named_parameters():
if pattern in name:
shared_parameter = False
if shared_parameter:
shared_param_list.append(name)
else:
unshared_param_list.append(name)
# create reference of the original parameter if they are shared
for param_name in shared_param_list:
param = model.get_parameter(param_name)
param.requires_grad = False
ref_param = ref_model.get_parameter(param_name)
ref_param.data = param.data  # share the underlying storage so the shared layers stay tied
# for all other parameters just make sure they don't use gradients
for param_name in unshared_param_list:
param = ref_model.get_parameter(param_name)
param.requires_grad = False
if pattern is not None and len(unshared_param_list) == 0:
logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.")
return ref_model.eval()
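# --- Editor's illustration (not part of the original file) ---------------------------------
# A minimal sketch of how `create_reference_model` is typically used to build a frozen
# reference policy for PPO. The import path and checkpoint name are assumptions made for
# the example; any causal LM mapped by `AutoModelForCausalLM` would work the same way.
def _example_create_reference_model():
    from trl import AutoModelForCausalLMWithValueHead, create_reference_model  # assumed upstream package

    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    # A fully frozen copy of the policy; passing e.g. `num_shared_layers=6` would instead
    # share (and freeze) the first 6 transformer blocks between the two models.
    ref_model = create_reference_model(model)
    return model, ref_model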
# Copyright 2023 DDPO-pytorch authors (Kevin Black), The HuggingFace Team, metric-space. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.utils import convert_state_dict_to_diffusers
from ..core import randn_tensor
from ..import_utils import is_peft_available
if is_peft_available():
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
@dataclass
class DDPOPipelineOutput(object):
"""
Output class for the diffusers pipeline to be finetuned with the DDPO trainer
Args:
images (`torch.Tensor`):
The generated images.
latents (`List[torch.Tensor]`):
The latents used to generate the images.
log_probs (`List[torch.Tensor]`):
The log probabilities of the latents.
"""
images: torch.Tensor
latents: torch.Tensor
log_probs: torch.Tensor
@dataclass
class DDPOSchedulerOutput(object):
"""
Output class for the diffusers scheduler to be finetuned with the DDPO trainer
Args:
latents (`torch.Tensor`):
Predicted sample at the previous timestep. Shape: `(batch_size, num_channels, height, width)`
log_probs (`torch.Tensor`):
Log probability of the above mentioned sample. Shape: `(batch_size)`
"""
latents: torch.Tensor
log_probs: torch.Tensor
class DDPOStableDiffusionPipeline(object):
"""
Main class for the diffusers pipeline to be finetuned with the DDPO trainer
"""
def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
raise NotImplementedError
def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
raise NotImplementedError
@property
def unet(self):
"""
Returns the 2d U-Net model used for diffusion.
"""
raise NotImplementedError
@property
def vae(self):
"""
Returns the Variational Autoencoder model used for mapping images to and from the latent space
"""
raise NotImplementedError
@property
def tokenizer(self):
"""
Returns the tokenizer used for tokenizing text inputs
"""
raise NotImplementedError
@property
def scheduler(self):
"""
Returns the scheduler associated with the pipeline used for the diffusion process
"""
raise NotImplementedError
@property
def text_encoder(self):
"""
Returns the text encoder used for encoding text inputs
"""
raise NotImplementedError
@property
def autocast(self):
"""
Returns the autocast context manager
"""
raise NotImplementedError
def set_progress_bar_config(self, *args, **kwargs):
"""
Sets the progress bar config for the pipeline
"""
raise NotImplementedError
def save_pretrained(self, *args, **kwargs):
"""
Saves all of the model weights
"""
raise NotImplementedError
def get_trainable_layers(self, *args, **kwargs):
"""
Returns the trainable parameters of the pipeline
"""
raise NotImplementedError
def save_checkpoint(self, *args, **kwargs):
"""
Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state
"""
raise NotImplementedError
def load_checkpoint(self, *args, **kwargs):
"""
Light wrapper around accelerate's register_load_state_pre_hook which is run before loading state
"""
raise NotImplementedError
def _left_broadcast(input_tensor, shape):
"""
As opposed to the default direction of broadcasting (right to left), this function broadcasts
from left to right
Args:
input_tensor (`torch.FloatTensor`): is the tensor to broadcast
shape (`Tuple[int]`): is the shape to broadcast to
"""
input_ndim = input_tensor.ndim
if input_ndim > len(shape):
raise ValueError("The number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to")
return input_tensor.reshape(input_tensor.shape + (1,) * (len(shape) - input_ndim)).broadcast_to(shape)
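# Editor's illustration (not part of the original file): `_left_broadcast` appends new axes
# on the right of the input before broadcasting, so a per-batch scalar of shape (B,) can be
# expanded to a full sample shape (B, C, H, W). The shapes below are arbitrary.
def _example_left_broadcast():
    per_batch_std = torch.arange(2, dtype=torch.float32)       # shape (2,)
    expanded = _left_broadcast(per_batch_std, (2, 4, 8, 8))     # shape (2, 4, 8, 8)
    assert expanded.shape == (2, 4, 8, 8)
    return expanded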
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
alpha_prod_t_prev = torch.where(
prev_timestep.cpu() >= 0,
self.alphas_cumprod.gather(0, prev_timestep.cpu()),
self.final_alpha_cumprod,
).to(timestep.device)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def scheduler_step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
prev_sample: Optional[torch.FloatTensor] = None,
) -> DDPOSchedulerOutput:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have no effect.
generator: random number generator.
prev_sample (`torch.FloatTensor`, *optional*): a pre-computed previous sample. If provided, it is used
as the previous sample instead of drawing new variance noise, and its log probability under the
predicted Gaussian is returned; `generator` must then be left as `None`.
Returns:
`DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample
"""
if self.num_inference_steps is None:
raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper in detail to fully understand the steps below
# Notation (<variable name> -> <name in paper>)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# to prevent OOB on gather
prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
alpha_prod_t_prev = torch.where(
prev_timestep.cpu() >= 0,
self.alphas_cumprod.gather(0, prev_timestep.cpu()),
self.final_alpha_cumprod,
)
alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`")
# 4. Clip or threshold "predicted x_0"
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
elif self.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = _get_variance(self, timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
if use_clipped_model_output:
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if prev_sample is not None and generator is not None:
raise ValueError("Cannot pass both generator and prev_sample. Please make sure that either `generator` or" " `prev_sample` stays `None`.")
if prev_sample is None:
variance_noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
prev_sample = prev_sample_mean + std_dev_t * variance_noise
# log prob of prev_sample given prev_sample_mean and std_dev_t
log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)) - torch.log(std_dev_t) - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi)))
# mean along all but batch dimension
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
return DDPOSchedulerOutput(prev_sample.type(sample.dtype), log_prob)
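# Editor's illustration (not part of the original file): the per-element log-probability
# computed at the end of `scheduler_step` is exactly the Gaussian density
# N(prev_sample_mean, std_dev_t**2), as this small self-check shows for arbitrary tensors.
def _example_scheduler_log_prob_matches_normal():
    mean = torch.randn(2, 4, 8, 8)
    std = torch.full_like(mean, 0.3)
    sample = mean + std * torch.randn_like(mean)
    manual = -((sample - mean) ** 2) / (2 * std**2) - torch.log(std) - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi)))
    reference = torch.distributions.Normal(mean, std).log_prob(sample)
    assert torch.allclose(manual, reference, atol=1e-5)
    return manual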
# 1. The output type for call is different as the logprobs are now returned
# 2. An extra method called `scheduler_step` is added which is used to constrain the scheduler output
@torch.no_grad()
def pipeline_step(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
Examples:
Returns:
`DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
all_latents = [latents]
all_log_probs = []
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta)
latents = scheduler_output.latents
log_prob = scheduler_output.log_probs
all_latents.append(latents)
all_log_probs.append(log_prob)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
return DDPOPipelineOutput(image, all_latents, all_log_probs)
class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True):
self.sd_pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name, revision=pretrained_model_revision)
self.use_lora = use_lora
self.pretrained_model = pretrained_model_name
self.pretrained_revision = pretrained_model_revision
try:
self.sd_pipeline.load_lora_weights(
pretrained_model_name,
weight_name="pytorch_lora_weights.safetensors",
revision=pretrained_model_revision,
)
self.use_lora = True
except OSError:
if use_lora:
warnings.warn("If you are aware that the pretrained model has no lora weights to it, ignore this message. " "Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder.")
self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config)
self.sd_pipeline.safety_checker = None
# memory optimization
self.sd_pipeline.vae.requires_grad_(False)
self.sd_pipeline.text_encoder.requires_grad_(False)
self.sd_pipeline.unet.requires_grad_(not self.use_lora)
def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
return pipeline_step(self.sd_pipeline, *args, **kwargs)
def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs)
@property
def unet(self):
return self.sd_pipeline.unet
@property
def vae(self):
return self.sd_pipeline.vae
@property
def tokenizer(self):
return self.sd_pipeline.tokenizer
@property
def scheduler(self):
return self.sd_pipeline.scheduler
@property
def text_encoder(self):
return self.sd_pipeline.text_encoder
@property
def autocast(self):
return contextlib.nullcontext if self.use_lora else None
def save_pretrained(self, output_dir):
if self.use_lora:
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.sd_pipeline.unet))
self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
self.sd_pipeline.save_pretrained(output_dir)
def set_progress_bar_config(self, *args, **kwargs):
self.sd_pipeline.set_progress_bar_config(*args, **kwargs)
def get_trainable_layers(self):
if self.use_lora:
lora_config = LoraConfig(
r=4,
lora_alpha=4,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
self.sd_pipeline.unet.add_adapter(lora_config)
# To avoid accelerate unscaling problems in FP16.
for param in self.sd_pipeline.unet.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
return self.sd_pipeline.unet
else:
return self.sd_pipeline.unet
def save_checkpoint(self, models, weights, output_dir):
if len(models) != 1:
raise ValueError("Given how the trainable params were set, this should be of length 1")
if self.use_lora and hasattr(models[0], "peft_config") and getattr(models[0], "peft_config", None) is not None:
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(models[0]))
self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict)
elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
models[0].save_pretrained(os.path.join(output_dir, "unet"))
else:
raise ValueError(f"Unknown model type {type(models[0])}")
def load_checkpoint(self, models, input_dir):
if len(models) != 1:
raise ValueError("Given how the trainable params were set, this should be of length 1")
if self.use_lora:
lora_state_dict, network_alphas = self.sd_pipeline.lora_state_dict(input_dir, weight_name="pytorch_lora_weights.safetensors")
self.sd_pipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=models[0])
elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
models[0].register_to_config(**load_model.config)
models[0].load_state_dict(load_model.state_dict())
del load_model
else:
raise ValueError(f"Unknown model type {type(models[0])}")
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .modeling_base import PreTrainedModelWrapper
class ValueHead(nn.Module):
r"""
The ValueHead class implements a head for GPT2 that returns a scalar for each output token.
"""
def __init__(self, config, **kwargs):
super().__init__()
if not hasattr(config, "summary_dropout_prob"):
summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1)
else:
summary_dropout_prob = config.summary_dropout_prob
self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()
# some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m
if hasattr(config, "hidden_size"):
hidden_size = config.hidden_size
if hasattr(config, "word_embed_proj_dim"):
hidden_size = config.word_embed_proj_dim
elif hasattr(config, "is_encoder_decoder"):
if config.is_encoder_decoder and hasattr(config, "decoder"):
if hasattr(config.decoder, "hidden_size"):
hidden_size = config.decoder.hidden_size
self.summary = nn.Linear(hidden_size, 1)
self.flatten = nn.Flatten()
def forward(self, hidden_states):
output = self.dropout(hidden_states)
# For now force upcast in fp32 if needed. Let's keep the
# output in fp32 for numerical stability.
if output.dtype != self.summary.weight.dtype:
output = output.to(self.summary.weight.dtype)
output = self.summary(output)
return output
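# Editor's illustration (not part of the original file): the value head maps hidden states
# of shape (batch, seq_len, hidden_size) to one scalar per token, i.e. (batch, seq_len, 1).
# The stand-in config below is an assumption made only for this sketch.
def _example_value_head_shapes():
    from types import SimpleNamespace

    config = SimpleNamespace(summary_dropout_prob=0.1, hidden_size=768)
    head = ValueHead(config)
    hidden_states = torch.randn(2, 5, 768)
    values = head(hidden_states)
    assert values.shape == (2, 5, 1)
    return values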
class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
r"""
An autoregressive model with a value head in addition to the language model head.
This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
`transformers.PreTrainedModel` class. The wrapper class supports classic functions
such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped
model, simply manipulate the `pretrained_model` attribute of this class.
Class attributes:
- **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
should be set to `transformers.AutoModelForCausalLM` for this class.
- **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the
wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models
in the future
- **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
by the `ValueHead` class. Currently, the supported args are:
- **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the
`ValueHead` class.
- **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the
`ValueHead` if a specific initialization strategy is selected.
- **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the
`ValueHead`. Currently, the supported strategies are:
- **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default
strategy.
- **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.
"""
transformers_parent_class = AutoModelForCausalLM
lm_head_namings = ["lm_head", "embed_out"]
supported_args = (
"summary_dropout_prob",
"v_head_initializer_range",
"v_head_init_strategy",
)
def __init__(self, pretrained_model, **kwargs):
r"""
Initializes the model.
Args:
pretrained_model (`transformers.PreTrainedModel`):
The model to wrap. It should be a causal language model such as GPT2,
or any model mapped inside the `AutoModelForCausalLM` class.
kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the `ValueHead` class.
"""
super().__init__(pretrained_model, **kwargs)
v_head_kwargs, _, _ = self._split_kwargs(kwargs)
if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
raise ValueError("The model does not have a language model head, please use a model that has one.")
self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
self._init_weights(**v_head_kwargs)
def _init_weights(self, **kwargs):
r"""
Initializes the weights of the value head. The default initialization strategy is random.
Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument
when calling `.from_pretrained`. Supported strategies are:
- `normal`: initializes the weights with a normal distribution.
Args:
**kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the `ValueHead` class. These arguments
can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range`
argument.
"""
initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
# random init by default
init_strategy = kwargs.pop("v_head_init_strategy", None)
if init_strategy is None:
# do nothing
pass
elif init_strategy == "normal":
self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
self.v_head.summary.bias.data.zero_()
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
**kwargs,
):
r"""
Applies a forward pass to the wrapped model and returns the logits of the value head.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past_key_values` input) to speed up sequential decoding.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the wrapped model.
"""
kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples
kwargs["past_key_values"] = past_key_values
if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
kwargs.pop("past_key_values")
base_model_output = self.pretrained_model(
input_ids=input_ids,
attention_mask=attention_mask,
**kwargs,
)
last_hidden_state = base_model_output.hidden_states[-1]
lm_logits = base_model_output.logits
loss = base_model_output.loss
if last_hidden_state.device != self.v_head.summary.weight.device:
last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)
value = self.v_head(last_hidden_state).squeeze(-1)
# force upcast in fp32 if logits are in half-precision
if lm_logits.dtype != torch.float32:
lm_logits = lm_logits.float()
return (lm_logits, loss, value)
def generate(self, *args, **kwargs):
r"""
A simple wrapper around the `generate` method of the wrapped model.
Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils)
method of the wrapped model for more information about the supported arguments.
Args:
*args (`list`, *optional*):
Positional arguments passed to the `generate` method of the wrapped model.
**kwargs (`dict`, *optional*):
Keyword arguments passed to the `generate` method of the wrapped model.
"""
return self.pretrained_model.generate(*args, **kwargs)
def state_dict(self, *args, **kwargs):
r"""
Returns the state dictionary of the model. We add the state dictionary of the value head
to the state dictionary of the wrapped model by prepending the key with `v_head.`.
"""
if not self.is_peft_model:
pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
else:
# if it is a peft model, only save the v_head
pretrained_model_state_dict = {}
v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
for k, v in v_head_state_dict.items():
pretrained_model_state_dict[f"v_head.{k}"] = v
return pretrained_model_state_dict
def push_to_hub(self, *args, **kwargs):
setattr(self.pretrained_model, "v_head", self.v_head)
return self.pretrained_model.push_to_hub(*args, **kwargs)
def post_init(self, state_dict):
r"""
We add the state dictionary of the value head to the state dictionary of the wrapped model
by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
keys of the value head state dictionary.
"""
for k in list(state_dict.keys()):
if "v_head." in k:
state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
self.v_head.load_state_dict(state_dict, strict=False)
del state_dict
if hasattr(self.pretrained_model, "hf_device_map"):
if "cpu" in self.pretrained_model.hf_device_map.values() or "disk" in self.pretrained_model.hf_device_map.values():
raise ValueError("The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models.")
first_device = list(set(self.pretrained_model.hf_device_map.values()))[0]
self.v_head = self.v_head.to(first_device)
def set_device_hook(module, input, outputs):
new_output = ()
for output in outputs:
if isinstance(output, torch.Tensor):
new_output += (output.to(first_device),)
else:
new_output += (output,)
return new_output
self.register_forward_hook(set_device_hook)
self.is_sequential_parallel = True
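# Editor's illustration (not part of the original file): the value-head wrapper's forward
# pass returns a `(lm_logits, loss, value)` tuple with one value per token. "gpt2" is an
# arbitrary small checkpoint chosen only for this sketch.
def _example_causal_lm_with_value_head():
    from transformers import AutoTokenizer

    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    inputs = tokenizer("Reinforcement learning from human feedback", return_tensors="pt")
    lm_logits, loss, value = model(**inputs)
    # `value` has shape (batch_size, sequence_length); `loss` is None since no labels were passed
    return lm_logits.shape, value.shape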
class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
r"""
A seq2seq model with a value head in addition to the language model head.
This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
`transformers.PreTrainedModel` class. The wrapper class supports classic functions
such as `from_pretrained` and `push_to_hub` and also provides some additional
functionalities such as `generate`.
Args:
pretrained_model (`transformers.PreTrainedModel`):
The model to wrap. It should be a sequence-to-sequence model,
or any model mapped inside the `AutoModelForSeq2SeqLM` class.
kwargs:
Additional keyword arguments passed along to the `ValueHead` class.
"""
transformers_parent_class = AutoModelForSeq2SeqLM
lm_head_namings = ["lm_head", "embed_out", "output_projection"]
supported_args = (
"summary_dropout_prob",
"v_head_initializer_range",
"v_head_init_strategy",
)
def __init__(self, pretrained_model, **kwargs):
super().__init__(pretrained_model, **kwargs)
v_head_kwargs, _, _ = self._split_kwargs(kwargs)
self.is_encoder_decoder = True
if not self._has_lm_head():
raise ValueError("The model does not have a language model head, please use a model that has one.")
self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
self._init_weights(**v_head_kwargs)
def _has_lm_head(self):
# check module names of all modules inside `pretrained_model` to find the language model head
for name, module in self.pretrained_model.named_modules():
if any(attribute in name for attribute in self.lm_head_namings):
return True
return False
def post_init(self, state_dict):
r"""
We add the state dictionary of the value head to the state dictionary of the wrapped model
by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
keys of the value head state dictionary.
"""
for k in list(state_dict.keys()):
if "v_head." in k:
state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
self.v_head.load_state_dict(state_dict, strict=False)
del state_dict
if hasattr(self.pretrained_model, "hf_device_map"):
if "cpu" in self.pretrained_model.hf_device_map.values() or "disk" in self.pretrained_model.hf_device_map.values():
raise ValueError("The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models.")
# get the lm_head device
for name, module in self.pretrained_model.named_modules():
if any(attribute in name for attribute in self.lm_head_namings):
lm_head_device = module.weight.device
break
# put v_head on the same device as the lm_head to avoid issues
self.v_head = self.v_head.to(lm_head_device)
def set_device_hook(module, input, outputs):
r"""
A hook that sets the device of the output of the model to the device of the first
parameter of the model.
Args:
module (`nn.Module`):
The module to which the hook is attached.
input (`tuple`):
The input to the module.
outputs (`tuple`):
The output of the module.
"""
new_output = ()
for output in outputs:
if isinstance(output, torch.Tensor):
new_output += (output.to(lm_head_device),)
else:
new_output += (output,)
return new_output
self.register_forward_hook(set_device_hook)
self.is_sequential_parallel = True
def state_dict(self, *args, **kwargs):
r"""
Returns the state dictionary of the model. We add the state dictionary of the value head
to the state dictionary of the wrapped model by prepending the key with `v_head.`.
"""
if not self.is_peft_model:
pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
else:
# if it is a peft model, only save the v_head
pretrained_model_state_dict = {}
v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
for k, v in v_head_state_dict.items():
pretrained_model_state_dict[f"v_head.{k}"] = v
return pretrained_model_state_dict
def push_to_hub(self, *args, **kwargs):
setattr(self.pretrained_model, "v_head", self.v_head)
return self.pretrained_model.push_to_hub(*args, **kwargs)
def _init_weights(self, **kwargs):
r"""
We initialize the weights of the value head.
"""
initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
# random init by default
init_strategy = kwargs.pop("v_head_init_strategy", None)
if init_strategy is None:
# do nothing
pass
elif init_strategy == "normal":
self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
self.v_head.summary.bias.data.zero_()
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
**kwargs,
):
kwargs["past_key_values"] = past_key_values
if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
kwargs.pop("past_key_values")
base_model_output = self.pretrained_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True, # We force the model to output hidden states
**kwargs,
)
last_hidden_state = base_model_output.decoder_hidden_states[-1]
lm_logits = base_model_output.logits
loss = base_model_output.loss
value = self.v_head(last_hidden_state).squeeze(-1)
# force upcast in fp32 if logits are in half-precision
if lm_logits.dtype != torch.float32:
lm_logits = lm_logits.float()
return (lm_logits, loss, value)
def generate(self, *args, **kwargs):
r"""
We call `generate` on the wrapped model.
"""
return self.pretrained_model.generate(*args, **kwargs)
from dataclasses import dataclass
from typing import Literal, Optional, Tuple
from transformers import PreTrainedModel, PreTrainedTokenizer
# TODO: Add Abstract Base Class if more formats are added
@dataclass
class ChatMlSpecialTokens:
"""Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""
bos_token: str = "<|im_start|>"
eos_token: str = "<|im_end|>"
pad_token: str = "<|im_end|>"
@property
def system(self):
return f"{self.bos_token}system"
@property
def user(self):
return f"{self.bos_token}user"
@property
def assistant(self):
return f"{self.bos_token}assistant"
@property
def chat_template(self):
return (
"{% for message in messages %}"
f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
f"{{{{ '{self.assistant}\n' }}}}"
"{% endif %}"
)
FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
def setup_chat_format(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
format: Optional[Literal["chatml"]] = "chatml",
resize_to_multiple_of: Optional[int] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
"""
Set up the chat format by adding special tokens to the tokenizer, setting the correct chat template, and extending the embedding layer of the model based on the new special tokens.
Args:
model (`~transformers.PreTrainedModel`): The model to be modified.
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None.
Returns:
model (`~transformers.PreTrainedModel`): The modified model.
tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
"""
# check if format available and retrieve
if format not in FORMAT_MAPPING:
raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")
chat_format = FORMAT_MAPPING[format]()
# set special tokens and add them to the tokenizer
tokenizer.eos_token = chat_format.eos_token
tokenizer.pad_token = chat_format.pad_token
tokenizer.bos_token = chat_format.bos_token
tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
# set chat format for tokenizer
tokenizer.chat_template = chat_format.chat_template
# optionally resize the embedding layer to a multiple of `resize_to_multiple_of` (e.g. 64), see https://x.com/karpathy/status/1621578354024677377
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None)
# Make sure to update the generation config to use the new eos & bos token
if getattr(model, "generation_config", None) is not None:
model.generation_config.bos_token_id = tokenizer.bos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
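# Editor's illustration (not part of the original file): applying the ChatML format to a
# model/tokenizer pair and rendering a conversation with the resulting chat template. The
# checkpoint name is an arbitrary example.
def _example_setup_chat_format():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)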
# flake8: noqa
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# There is a circular import in the PPOTrainer if we let isort sort these
# isort: off
from .utils import (
AdaptiveKLController,
FixedKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RunningMoments,
disable_dropout_in_model,
peft_module_casting_to_bf16,
)
# isort: on
from ..import_utils import is_diffusers_available
from .base import BaseTrainer
from .ddpo_config import DDPOConfig
if is_diffusers_available():
from .ddpo_trainer import DDPOTrainer
from .dpo_trainer import DPOTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .model_config import ModelConfig
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer, compute_accuracy
from .sft_trainer import SFTTrainer
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub import PyTorchModelHubMixin
class BaseTrainer(PyTorchModelHubMixin):
r"""
Base class for all trainers - this base class implements the basic functions that we
need for a trainer.
The trainer needs to have the following functions:
- step: takes in a batch of data and performs a step of training
- loss: takes in a batch of data and returns the loss
- compute_rewards: takes in a batch of data and returns the rewards
- _build_models_and_tokenizer: builds the models and tokenizer
- _build_dataset: builds the dataset
Each user is expected to implement their own trainer class that inherits from this base
if they want to use a new training algorithm.
"""
def __init__(self, config):
self.config = config
def step(self, *args):
raise NotImplementedError("Not implemented")
def loss(self, *args):
raise NotImplementedError("Not implemented")
def compute_rewards(self, *args):
raise NotImplementedError("Not implemented")
def _save_pretrained(self, save_directory):
raise NotImplementedError("Not implemented")
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional
from ..core import flatten_dict
from ..import_utils import is_bitsandbytes_available, is_torchvision_available
@dataclass
class DDPOConfig:
"""
Configuration class for DDPOTrainer
"""
# common parameters
exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
"""the name of this experiment (by default is the file name without the extension name)"""
run_name: Optional[str] = ""
"""Run name for wandb logging and checkpoint saving."""
seed: int = 0
"""Seed value for random generations"""
log_with: Optional[Literal["wandb", "tensorboard"]] = None
"""Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
tracker_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the tracker (e.g. wandb_project)"""
accelerator_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the accelerator"""
project_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
tracker_project_name: str = "trl"
"""Name of project to use for tracking"""
logdir: str = "logs"
"""Top-level logging directory for checkpoint saving."""
# hyperparameters
num_epochs: int = 100
"""Number of epochs to train."""
save_freq: int = 1
"""Number of epochs between saving model checkpoints."""
num_checkpoint_limit: int = 5
"""Number of checkpoints to keep before overwriting old ones."""
mixed_precision: str = "fp16"
"""Mixed precision training."""
allow_tf32: bool = True
"""Allow tf32 on Ampere GPUs."""
resume_from: Optional[str] = ""
"""Resume training from a checkpoint."""
sample_num_steps: int = 50
"""Number of sampler inference steps."""
sample_eta: float = 1.0
"""Eta parameter for the DDIM sampler."""
sample_guidance_scale: float = 5.0
"""Classifier-free guidance weight."""
sample_batch_size: int = 1
"""Batch size (per GPU!) to use for sampling."""
sample_num_batches_per_epoch: int = 2
"""Number of batches to sample per epoch."""
train_batch_size: int = 1
"""Batch size (per GPU!) to use for training."""
train_use_8bit_adam: bool = False
"""Whether to use the 8bit Adam optimizer from bitsandbytes."""
train_learning_rate: float = 3e-4
"""Learning rate."""
train_adam_beta1: float = 0.9
"""Adam beta1."""
train_adam_beta2: float = 0.999
"""Adam beta2."""
train_adam_weight_decay: float = 1e-4
"""Adam weight decay."""
train_adam_epsilon: float = 1e-8
"""Adam epsilon."""
train_gradient_accumulation_steps: int = 1
"""Number of gradient accumulation steps."""
train_max_grad_norm: float = 1.0
"""Maximum gradient norm for gradient clipping."""
train_num_inner_epochs: int = 1
"""Number of inner epochs per outer epoch."""
train_cfg: bool = True
"""Whether or not to use classifier-free guidance during training."""
train_adv_clip_max: float = 5
"""Clip advantages to the range."""
train_clip_range: float = 1e-4
"""The PPO clip range."""
train_timestep_fraction: float = 1.0
"""The fraction of timesteps to train on."""
per_prompt_stat_tracking: bool = False
"""Whether to track statistics for each prompt separately."""
per_prompt_stat_tracking_buffer_size: int = 16
"""Number of reward values to store in the buffer for each prompt."""
per_prompt_stat_tracking_min_count: int = 16
"""The minimum number of reward values to store in the buffer."""
async_reward_computation: bool = False
"""Whether to compute rewards asynchronously."""
max_workers: int = 2
"""The maximum number of workers to use for async reward computation."""
negative_prompts: Optional[str] = ""
"""Comma-separated list of prompts to use as negative examples."""
def to_dict(self):
output_dict = {}
for key, value in self.__dict__.items():
output_dict[key] = value
return flatten_dict(output_dict)
def __post_init__(self):
if self.log_with not in ["wandb", "tensorboard"]:
warnings.warn(("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."))
if self.log_with == "wandb" and not is_torchvision_available():
warnings.warn("Wandb image logging requires torchvision to be installed")
if self.train_use_8bit_adam and not is_bitsandbytes_available():
raise ImportError("You need to install bitsandbytes to use 8bit Adam. " "You can install it with `pip install bitsandbytes`.")
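# Editor's illustration (not part of the original file): a minimal DDPOConfig overriding a
# few of the defaults documented above. All field values are illustrative only.
def _example_ddpo_config():
    config = DDPOConfig(
        num_epochs=10,
        sample_batch_size=2,
        train_batch_size=2,
        train_learning_rate=1e-4,
        log_with=None,  # tracking disabled in this sketch (a warning is emitted)
    )
    return config.to_dict()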