"examples/community/clip_guided_stable_diffusion_img2img.py" did not exist on "bd8df2da89d99f630e5aa2ddb8f8cb45456561f1"
Commit f92481f0 authored by chenych's avatar chenych
Browse files

First commit.

parent 7121d0b0
Pipeline #2435 failed with stages
in 0 seconds
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
HALF_LIST = [16, "16", "fp16", "float16"]
FLOAT_LIST = [32, "32", "fp32", "float32"]
BFLOAT_LIST = ["bf16", "bfloat16"]
class PrecisionType:
"""Type of precision used.
>>> PrecisionType.HALF == 16
True
>>> PrecisionType.HALF in (16, "16")
True
"""
HALF = "16"
FLOAT = "32"
FULL = "64"
BFLOAT = "bf16"
MIXED = "mixed"
@staticmethod
def is_fp16(precision):
return precision in HALF_LIST
@staticmethod
def is_fp32(precision):
return precision in FLOAT_LIST
@staticmethod
def is_bf16(precision):
return precision in BFLOAT_LIST
@staticmethod
def to_dtype(precision):
if precision in HALF_LIST:
return torch.float16
elif precision in FLOAT_LIST:
return torch.float32
elif precision in BFLOAT_LIST:
return torch.bfloat16
else:
raise RuntimeError(f"unexpected precision: {precision}")
@staticmethod
def to_str(precision):
if precision == torch.float16:
return "fp16"
elif precision == torch.float32:
return "fp32"
elif precision == torch.bfloat16:
return "bf16"
else:
raise RuntimeError(f"unexpected precision: {precision}")
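# Minimal usage sketch (illustrative, not part of the original API surface):
# round-trip between precision aliases and torch dtypes with the helpers above.
def _example_precision_roundtrip():
    assert PrecisionType.to_dtype("bf16") is torch.bfloat16
    assert PrecisionType.to_str(torch.bfloat16) == "bf16"
    # is_fp16 accepts any alias in HALF_LIST: 16, "16", "fp16" or "float16"
    assert PrecisionType.is_fp16(16) and PrecisionType.is_fp16("float16")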
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains small torch utilities
"""
import math
from typing import List, Literal, Union
import torch
import torch.distributed
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from transformers import PreTrainedTokenizer
try:
from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True
except ImportError:
FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False
def logprobs_from_logits(logits, labels):
"""
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
"""
if FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
batch_dim = logits.shape[:-1]
last_dim = logits.shape[-1]
logits = logits.reshape(-1, last_dim)
labels = labels.reshape(-1)
output = logprobs_from_logits_flash_attn(logits, labels)
output = output.view(*batch_dim)
else:
output = logprobs_from_logits_v2(logits, labels)
return output
def logprobs_from_logits_flash_attn(logits, labels):
output = cross_entropy_loss(logits, labels)
assert isinstance(output, tuple), (
"please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
)
return -output[0]
def logprobs_from_logits_v2(logits: torch.FloatTensor, labels):
"""
A memory efficient implementation of logprobs_from_logits
"""
if logits.dtype in [torch.float32, torch.float64]:
logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
# loop to reduce peak mem consumption
logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
logprobs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
else:
# logsumexp approach is unstable with bfloat16, fall back to a slightly less efficient approach
logprobs_labels = []
for row_logits, row_labels in zip(logits, labels): # loop to reduce peak mem consumption
row_logprobs = F.log_softmax(row_logits, dim=-1)
row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
logprobs_labels.append(row_logprobs_labels)
logprobs_labels = torch.stack(logprobs_labels)
return logprobs_labels
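# Illustrative sketch: both paths above compute log_softmax(logits) gathered at
# the label indices. A quick equivalence check against the direct formulation:
def _example_logprobs_equivalence():
    logits = torch.randn(2, 5, 11)  # (bsz, seqlen, vocab)
    labels = torch.randint(0, 11, (2, 5))  # (bsz, seqlen)
    expected = F.log_softmax(logits, dim=-1).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(logprobs_from_logits_v2(logits, labels), expected, atol=1e-5)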
def clip_by_value(x, tensor_min, tensor_max):
"""
Tensor extension of torch.clamp
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
"""
clipped = torch.max(torch.min(x, tensor_max), tensor_min)
return clipped
def entropy_from_logits(logits: torch.Tensor):
"""Calculate entropy from logits."""
pd = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
return entropy
def masked_mean(values, mask, axis=None) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
def masked_var(values, mask, unbiased=True):
"""Compute variance of tensor with masked values."""
mean = masked_mean(values, mask)
centered_values = values - mean
variance = masked_mean(centered_values**2, mask)
if unbiased:
mask_sum = mask.sum()
if mask_sum == 0:
raise ValueError("At least one element in the mask has to be 1.")
# note that if mask_sum == 1, then there is a division by zero issue
# to avoid it you just need to use a larger minibatch_size
if mask_sum == 1:
raise ValueError("The sum of the mask is one, which can cause a division by zero.")
bessel_correction = mask_sum / (mask_sum - 1)
variance = variance * bessel_correction
return variance
def masked_whiten(values, mask, shift_mean=True):
"""Whiten values with masked values."""
mean, var = masked_mean(values, mask), masked_var(values, mask)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
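# Illustrative sketch: on the masked-in entries, masked_whiten is plain
# standardization with the masked (Bessel-corrected) statistics.
def _example_masked_whiten():
    values = torch.randn(4, 8)
    mask = torch.ones(4, 8)
    mask[:, 6:] = 0  # pretend the last two positions are padding
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    manual = (values - mean) * torch.rsqrt(var + 1e-8)
    assert torch.allclose(masked_whiten(values, mask), manual)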
def get_eos_mask(response_ids: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
"""
The end-of-sentence token can be an int or a list of ints, e.g. 1 or [1, 2].
e.g. eos_token=1
response_ids: [0, 0, 2, 42, 3, 5, 1, 0, 0]
eos_mask: [1, 1, 1, 1, 1, 1, 1, 0, 0]
"""
if isinstance(eos_token, int):
eos_token = [eos_token]
eos_mask = torch.zeros_like(response_ids, dtype=torch.bool)
for token in eos_token:
eos_mask |= response_ids.eq(token)
eos_mask = eos_mask.long()
eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
eos_mask = torch.logical_not(eos_mask).to(dtype)
return eos_mask
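# Illustrative sketch reproducing the docstring example: positions up to and
# including the first eos token are kept, everything after it is masked out.
def _example_get_eos_mask():
    response_ids = torch.tensor([[0, 0, 2, 42, 3, 5, 1, 0, 0]])
    eos_mask = get_eos_mask(response_ids, eos_token=1)
    assert eos_mask.tolist() == [[1, 1, 1, 1, 1, 1, 1, 0, 0]]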
def pad_2d_list_to_length(response, pad_token_id, max_length=None) -> torch.Tensor:
"""
Pad a 2D list (e.g. responses, logprobs) to a 2D tensor.
"""
response_length = max(len(sub_list) for sub_list in response)
if max_length is not None and max_length > response_length:
target_length = max_length
else:
target_length = response_length
padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response]
tensor = torch.tensor(padded_response)
return tensor
def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
"""
Pad 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_len.
input shape: [bs, seq_length]
output shape: [bs, max_seq_len]
(0, max_seq_len - tensors.shape[-1]) pads on the right up to max_seq_len, with no left padding
"""
if tensors.shape[-1] >= max_seq_len:
return tensors
pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1])
return F.pad(tensors, pad_tuple, "constant", pad_token_id)
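# Illustrative sketch: right pad (default) vs. left pad to a target length.
def _example_pad_sequence_to_length():
    x = torch.tensor([[1, 2, 3]])
    assert pad_sequence_to_length(x, max_seq_len=5, pad_token_id=0).tolist() == [[1, 2, 3, 0, 0]]
    assert pad_sequence_to_length(x, max_seq_len=5, pad_token_id=0, left_pad=True).tolist() == [[0, 0, 1, 2, 3]]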
def tokenize_and_postprocess_data(
prompt: str,
tokenizer: PreTrainedTokenizer,
max_length: int,
pad_token_id: int,
left_pad: bool = True,
truncation: Literal["left", "right", "error"] = "error",
):
"""
Tokenize a prompt, then pad or truncate the result to max_length.
"""
assert truncation in ["left", "right", "error"]
input_data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
input_ids = input_data["input_ids"][0]
attention_mask = input_data["attention_mask"][0]
sequence_length = len(input_ids)
if sequence_length < max_length:
input_ids = pad_sequence_to_length(
input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad
)
attention_mask = pad_sequence_to_length(
attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad
)
elif sequence_length > max_length:
if truncation == "left":
# actually, left truncation may not be reasonable
input_ids = input_ids[-max_length:]
attention_mask = attention_mask[-max_length:]
elif truncation == "right":
input_ids = input_ids[:max_length]
attention_mask = attention_mask[:max_length]
elif truncation == "error":
raise NotImplementedError(f"{sequence_length=} is larger than {max_length=}")
else:
raise NotImplementedError(f"Unknown truncation method {truncation}")
return input_ids, attention_mask
def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):
"""Remove the pad token.
Args:
input_ids shape: [bs, seq_length]
attention_mask shape: [bs, seq_length]
Returns:
no_padding_batch (List[List[int]]): the token ids with padding removed, per sample.
"""
no_padding_batch = []
for ids, mask in zip(input_ids, attention_mask):
no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist())
return no_padding_batch
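# Illustrative sketch: strip left padding with the attention mask. The slice
# `ids[len(ids) - mask.sum():]` assumes left-padded inputs.
def _example_remove_pad_token():
    input_ids = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])
    attention_mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
    assert remove_pad_token(input_ids, attention_mask) == [[5, 6], [7, 8, 9]]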
def get_cosine_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
min_lr_ratio: float = 0.0,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function from the
initial lr set in the optimizer down to min_lr_ratio times the initial lr, after a warmup period during which
it increases linearly from 0 to the initial lr.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
The minimum lr ratio w.r.t the maximum.
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
following a half-cosine).
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
coef = (1 - min_lr_ratio) * 0.5
intercept = (1 + min_lr_ratio) * 0.5
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
return max(0.0, x * coef + intercept)
return LambdaLR(optimizer, lr_lambda, last_epoch)
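# Illustrative sketch with hypothetical hyper-parameters: the lr multiplier
# ramps linearly to 1.0 over the warmup steps, then follows a half cosine down
# towards min_lr_ratio.
def _example_cosine_schedule():
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=1.0)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=10, num_training_steps=100, min_lr_ratio=0.1
    )
    lrs = []
    for _ in range(100):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()
        scheduler.step()
    assert abs(lrs[10] - 1.0) < 1e-6  # peak right after warmup
    assert 0.1 <= lrs[-1] < 0.11  # close to min_lr_ratio at the end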
def get_constant_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
last_epoch: int = -1,
):
def lr_lambda(current_step):
return min(1, float(current_step) / float(max(1, num_warmup_steps)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
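# Illustrative sketch: flatten a right-padded batch into the varlen layout used
# by flash-attention (flat indices of real tokens plus cumulative seqlens).
def _example_get_unpad_data():
    attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
    indices, cu_seqlens, max_seqlen = get_unpad_data(attention_mask)
    assert indices.tolist() == [0, 1, 3, 4, 5]
    assert cu_seqlens.tolist() == [0, 2, 5]
    assert max_seqlen == 3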
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A unified tracking interface that supports logging data to different backends
"""
import os
from typing import List, Union
from verl.utils.logger.aggregate_logger import LocalLogger
class Tracking:
supported_backend = ["wandb", "mlflow", "swanlab", "console"]
def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = "console", config=None):
if isinstance(default_backend, str):
default_backend = [default_backend]
for backend in default_backend:
assert backend in self.supported_backend, f"{backend} is not supported"
self.logger = {}
if "wandb" in default_backend:
import wandb # type: ignore
wandb.init(project=project_name, name=experiment_name, config=config)
self.logger["wandb"] = wandb
if "mlflow" in default_backend:
import mlflow # type: ignore
mlflow.start_run(run_name=experiment_name)
mlflow.log_params(config)
self.logger["mlflow"] = _MlflowLoggingAdapter()
if "swanlab" in default_backend:
import swanlab # type: ignore
SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None)
SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog")
SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud")
if SWANLAB_API_KEY:
swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten
swanlab.init(
project=project_name,
experiment_name=experiment_name,
config=config,
logdir=SWANLAB_LOG_DIR,
mode=SWANLAB_MODE,
)
self.logger["swanlab"] = swanlab
if "console" in default_backend:
self.console_logger = LocalLogger(print_to_console=True)
self.logger["console"] = self.console_logger
def log(self, data, step, backend=None):
for default_backend, logger_instance in self.logger.items():
if backend is None or default_backend in backend:
logger_instance.log(data=data, step=step)
def __del__(self):
if "wandb" in self.logger:
self.logger["wandb"].finish(exit_code=0)
if "swanlab" in self.logger:
self.logger["swanlab"].finish()
class _MlflowLoggingAdapter:
def log(self, data, step):
import mlflow # type: ignore
mlflow.log_metrics(metrics=data, step=step)
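# Minimal usage sketch (project and experiment names are hypothetical): the
# console backend needs no credentials; the other backends share the same
# `log(data, step)` interface.
def _example_tracking_console():
    tracking = Tracking(project_name="demo_project", experiment_name="demo_run", default_backend="console")
    tracking.log(data={"actor/pg_loss": 0.1}, step=1)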
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for DeepSpeed Ulysses Sequence Parallelism.
DeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509
Inspired by: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py
"""
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
_ULYSSES_SEQUENCE_PARALLEL_GROUP = None
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
"""
Set ulysses sequence parallel process group.
"""
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
_ULYSSES_SEQUENCE_PARALLEL_GROUP = group
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
"""
Get ulysses sequence parallel process group.
"""
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
return _ULYSSES_SEQUENCE_PARALLEL_GROUP
def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
"""
Get ulysses sequence parallel world size.
"""
group = get_ulysses_sequence_parallel_group() if group is None else group
return dist.get_world_size(group) if group else 1
def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
"""
Get ulysses sequence parallel rank.
"""
group = get_ulysses_sequence_parallel_group() if group is None else group
return dist.get_rank(group) if group else 0
def gather_seq_scatter_heads(
x: Tensor,
seq_dim: int,
head_dim: int,
unpadded_dim_size: int = 0,
group: ProcessGroup = None,
) -> Tensor:
"""
All-to-all that syncs the embedding input in sequence parallelism:
gather the sequence dimension and scatter the head dimension.
e.g. seq_dim: 1, head_dim: 2
[bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...]
"""
group = get_ulysses_sequence_parallel_group() if group is None else group
if not group:
return x
sp_world = get_ulysses_sequence_parallel_world_size(group)
x = SeqAllToAll.apply(group, x, head_dim, seq_dim)
if unpadded_dim_size and unpadded_dim_size % sp_world != 0:
padding_size = x.size(seq_dim) - unpadded_dim_size
x = _unpad_tensor(x, seq_dim, padding_size)
return x
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:
"""
All-to-all that syncs the attention result in sequence parallelism:
gather the head dimension and scatter the sequence dimension.
e.g. seq_dim: 1, head_dim: 2
[bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...]
"""
group = get_ulysses_sequence_parallel_group() if group is None else group
if not group:
return x
dim_size = x.size(seq_dim)
sp_world = get_ulysses_sequence_parallel_world_size(group)
if dim_size % sp_world != 0:
padding_size = sp_world - (dim_size % sp_world)
x = _pad_tensor(x, seq_dim, padding_size)
return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
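# Illustrative single-process sketch (no process group involved) of the layout
# change performed by the two functions above: with sp world size n, each rank
# moves between [bsz, seq/n, h, ...] and [bsz, seq, h/n, ...].
def _example_ulysses_layouts():
    bsz, seq, heads, dim, n = 2, 8, 4, 16, 4
    full = torch.randn(bsz, seq, heads, dim)
    shard_seq = full.chunk(n, dim=1)[0]  # what rank 0 holds before the all-to-all
    shard_heads = full.chunk(n, dim=2)[0]  # what rank 0 holds after it
    assert shard_seq.shape == (bsz, seq // n, heads, dim)
    assert shard_heads.shape == (bsz, seq, heads // n, dim)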
def _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
shape = list(x.shape)
shape[dim] = padding_size
pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=dim)
def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
slc = [slice(None)] * len(x.shape)
slc[dim] = slice(0, -padding_size)
return x[slc]
def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor:
group = get_ulysses_sequence_parallel_group() if group is None else group
sp_world_size = dist.get_world_size(group)
sp_rank = get_ulysses_sequence_parallel_rank()
dim_size = x.size(dim)
# pad before slice
if padding and dim_size % sp_world_size:
padding_size = sp_world_size - (dim_size % sp_world_size)
x = _pad_tensor(x, dim, padding_size)
# slice the input tensor
parts = x.size(dim) // sp_world_size
slc = [slice(None)] * len(x.shape)
slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts)
return x[slc].contiguous()
def all_to_all_tensor(
local_input: Tensor,
scatter_dim: int,
gather_dim: int,
group: Optional[dist.ProcessGroup] = None,
async_op: bool = False,
):
group = get_ulysses_sequence_parallel_group() if group is None else group
seq_world_size = dist.get_world_size(group)
input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
if async_op:
def wait():
comm.wait()
return torch.cat(output_list, dim=gather_dim).contiguous()
return wait
return torch.cat(output_list, dim=gather_dim).contiguous()
def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False):
group = get_ulysses_sequence_parallel_group() if group is None else group
sp_world_size = dist.get_world_size(group=group)
output_shape = list(local_tensor.shape)
output_shape[0] = output_shape[0] * sp_world_size
output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device)
dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op)
return output
class SeqAllToAll(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
group: dist.ProcessGroup,
local_input: Tensor,
scatter_dim: int,
gather_dim: int,
async_op: bool = False,
) -> Tensor:
ctx.group = group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
ctx.async_op = async_op
return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None, None]:
if ctx.async_op:
input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
else:
input_t = grad_output[0]
return (
None,
all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
None,
None,
None,
None,
)
class Gather(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
group: dist.ProcessGroup,
local_tensor: Tensor,
gather_dim: int,
grad_scaler: bool = True,
async_op=False,
) -> Tensor:
ctx.group = group
ctx.gather_dim = gather_dim
ctx.grad_scaler = grad_scaler
ctx.async_op = async_op
sp_world_size = dist.get_world_size(group=group)
ctx.sp_world_size = sp_world_size
sp_rank = dist.get_rank(group=group)
ctx.sp_rank = sp_rank
local_shape = list(local_tensor.size())
split_size = local_shape[0]
part_size = local_shape[gather_dim] # store original size
ctx.part_size = part_size
output = all_gather_tensor(local_tensor, group, async_op)
return torch.cat(output.split(split_size, dim=0), dim=gather_dim)
@staticmethod
def backward(ctx: Any, grad_output: Tensor) -> Any:
if ctx.grad_scaler:
grad_output = grad_output * ctx.sp_world_size
return (
None,
grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(),
None,
None,
None,
None,
)
def gather_outpus_and_unpad(
x: Tensor,
gather_dim: int,
unpad_dim: int = None,
padding_size: int = 0,
grad_scaler: bool = True,
group: Optional[dist.ProcessGroup] = None,
):
group = get_ulysses_sequence_parallel_group() if group is None else group
# sp_size = get_ulysses_sequence_parallel_world_size()
if group is None:
return x
x = Gather.apply(group, x, gather_dim, grad_scaler)
if unpad_dim is not None:
assert isinstance(padding_size, int), "padding size is not given or is not an integer"
if padding_size == 0:
return x
x = _unpad_tensor(x, unpad_dim, padding_size)
return x
def ulysses_pad_and_slice_inputs(
input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1
):
"""
Pad input_ids and position_ids to be divisible by sp_size, then slice input_ids.
Note that both input_ids_rmpad and position_ids_rmpad are padded,
but only input_ids is sliced.
This is the pre-forward utility for ulysses sequence parallelism.
Args:
input_ids_rmpad: shape of [bsz, seqlen]
position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1
sp_size (int): ulysses sequence parallelism size
Returns:
torch.Tensor: padded and sliced input_ids
torch.Tensor: padded and sliced position_ids
int: pad size
"""
if position_ids_rmpad is not None:
assert position_ids_rmpad.size(0) == 1
assert input_ids_rmpad.size(1) == position_ids_rmpad.size(1)
if sp_size <= 1:
return input_ids_rmpad, position_ids_rmpad, 0
_, total_seq_len = input_ids_rmpad.shape
pad_size = (sp_size - total_seq_len % sp_size) % sp_size
if pad_size > 0:
input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0)
if position_ids_rmpad is not None:
pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0)
position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1)
# we don't need to slice position ids
input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False)
return input_ids_rmpad, position_ids_rmpad, pad_size
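# Illustrative sketch of the padding arithmetic above (the slice_input_tensor
# call itself needs an initialized process group, so it is not exercised here):
# the sequence is padded up to the next multiple of sp_size.
def _example_ulysses_pad_size():
    total_seq_len, sp_size = 10, 4
    pad_size = (sp_size - total_seq_len % sp_size) % sp_size
    assert pad_size == 2 and (total_seq_len + pad_size) % sp_size == 0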
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BasePPOActor
from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
from .dp_actor import DataParallelPPOActor
__all__ = [
"ActorConfig",
"BasePPOActor",
"DataParallelPPOActor",
"FSDPConfig",
"ModelConfig",
"OptimConfig",
"RefConfig",
]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The base class for Actor
"""
from abc import ABC, abstractmethod
from typing import Any, Dict
import torch
from verl import DataProto
from verl.workers.actor.config import ActorConfig
__all__ = ["BasePPOActor"]
class BasePPOActor(ABC):
def __init__(self, config: ActorConfig):
"""The base class for PPO actor
Args:
config (ActorConfig): a config passed to the PPOActor.
"""
self.config = config
@abstractmethod
def compute_log_prob(self, data: DataProto) -> torch.Tensor:
"""Compute logits given a batch of data.
Args:
data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```,
```attention_mask``` and ```position_ids```.
Returns:
torch.Tensor: the log probability tensor of the responses, ```log_probs```
"""
pass
@abstractmethod
def update_policy(self, data: DataProto) -> Dict[str, Any]:
"""Update the policy with an iterator of DataProto
Args:
data (DataProto): an iterator over DataProto returned by
```make_minibatch_iterator```
Returns:
Dict: a dictionary that may contain anything. Typically, it contains statistics collected while updating
the model, such as ```loss``` and ```grad_norm```.
"""
pass
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Actor config
"""
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
@dataclass
class ModelConfig:
model_path: Optional[str] = None
tokenizer_path: Optional[str] = None
override_config: Dict[str, Any] = field(default_factory=dict)
enable_gradient_checkpointing: bool = True
trust_remote_code: bool = True
def post_init(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
@dataclass
class OptimConfig:
lr: float = 1e-6
betas: Tuple[float, float] = (0.9, 0.999)
weight_decay: float = 1e-2
lr_warmup_steps_ratio: float = 0.0
min_lr_ratio: Optional[float] = None
warmup_style: str = "constant"
"""auto keys"""
training_steps: int = field(default=-1, init=False)
@dataclass
class FSDPConfig:
enable_full_shard: bool = True
param_offload: bool = False
optimizer_offload: bool = False
torch_dtype: Optional[str] = None
mp_param_dtype: str = "bf16"
mp_reduce_dtype: str = "fp32"
mp_buffer_dtype: str = "fp32"
@dataclass
class OffloadConfig:
param_offload: bool = False
optimizer_offload: bool = False
@dataclass
class ActorConfig:
strategy: str = "fsdp"
global_batch_size: int = 256
micro_batch_size_per_device_for_update: int = field(default=-1, init=False)
micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
max_grad_norm: float = 1.0
clip_ratio: float = 0.2
entropy_coeff: float = 1e-3
use_kl_loss: bool = True
kl_loss_coef: float = 1e-3
kl_loss_type: str = "low_var_kl"
ppo_epochs: int = 1
padding_free: bool = False
ulysses_sequence_parallel_size: int = 1
model: ModelConfig = field(default_factory=ModelConfig)
optim: OptimConfig = field(default_factory=OptimConfig)
fsdp: FSDPConfig = field(default_factory=FSDPConfig)
offload: OffloadConfig = field(default_factory=OffloadConfig)
"""auto keys"""
global_batch_size_per_device: int = field(default=-1, init=False)
def post_init(self):
if self.ppo_epochs != 1:
raise NotImplementedError
@dataclass
class RefConfig:
strategy: str = "fsdp"
offload: OffloadConfig = field(default_factory=OffloadConfig)
"""auto keys"""
micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
padding_free: bool = field(default=False, init=False)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement Actor
"""
import os
from collections import defaultdict
from typing import Any, Dict, Optional, Tuple
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from tqdm import tqdm
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.trainer import core_algos
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, masked_mean
from verl.workers.actor.base import BasePPOActor
from verl.workers.actor.config import ActorConfig
__all__ = ["DataParallelPPOActor"]
class DataParallelPPOActor(BasePPOActor):
def __init__(
self,
config: ActorConfig,
actor_module: nn.Module,
actor_optimizer: Optional[torch.optim.Optimizer] = None,
):
"""
When the optimizer is None, this actor serves as the reference policy
"""
super().__init__(config)
self.rank = int(os.getenv("RANK", "0"))
self.actor_module = actor_module
self.actor_optimizer = actor_optimizer
self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
def _forward_micro_batch(
self, micro_batch: Dict[str, torch.Tensor], temperature: float
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
entropy: # (bs, response_len)
log_probs: # (bs, response_len)
"""
input_ids = micro_batch["input_ids"]
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
responses = micro_batch["responses"]
response_length = responses.size(-1)
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
vision_inputs = {}
if "pixel_values" in micro_batch:
vision_inputs["pixel_values"] = torch.cat(micro_batch["pixel_values"], dim=0)
vision_inputs["image_grid_thw"] = torch.cat(micro_batch["image_grid_thw"], dim=0)
if self.config.padding_free:
# TODO (yaowei): preprocess data for padding_free and ulysses
raise NotImplementedError
else:
output = self.actor_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**vision_inputs,
use_cache=False,
)
logits: torch.Tensor = output.logits
logits.div_(temperature)
logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size)
log_probs = logprobs_from_logits(logits, responses) # (bsz, response_length)
entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
return entropy, log_probs
def _optimizer_step(self) -> torch.Tensor:
if isinstance(self.actor_module, FSDP):
grad_norm = self.actor_module.clip_grad_norm_(self.config.max_grad_norm)
else:
grad_norm = nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.max_grad_norm)
self.actor_optimizer.step()
return grad_norm
@torch.no_grad()
def compute_log_prob(self, data: DataProto) -> torch.Tensor:
"""Compute the log probability of the responses given input_ids, attention_mask and position_ids
Args:
data (DataProto): a DataProto containing keys
``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
concatenation of prompt and response, so ``sequence_length = prompt_length + response_length``.
``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
``responses``: tensor of shape [batch_size, response_length]. torch.int64.
Returns:
torch.Tensor: the log_prob tensor
"""
self.actor_module.eval()
temperature = data.meta_info["temperature"]
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
if "pixel_values" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
else:
non_tensor_select_keys = None
micro_batches = data.select(select_keys, non_tensor_select_keys).split(
self.config.micro_batch_size_per_device_for_experience
)
log_probs_lst = []
for micro_batch in tqdm(micro_batches, desc="Compute log probs", disable=(self.rank != 0)):
micro_batch.to("cuda")
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
_, log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
log_probs_lst.append(log_probs)
log_probs = torch.concat(log_probs_lst, dim=0)
return log_probs
def update_policy(self, data: DataProto) -> Dict[str, Any]:
self.actor_module.train()
temperature = data.meta_info["temperature"]  # temperature must be in data.meta_info to avoid a silent error
select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
if "pixel_values" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
else:
non_tensor_select_keys = None
# TODO (yaowei): support ppo epochs
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)
metrics = defaultdict(list)
n = len(mini_batches)
for i, mini_batch in enumerate(mini_batches):
gradient_accumulation = (
self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
)
micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
self.actor_optimizer.zero_grad()
for micro_batch in tqdm(micro_batches, desc=f"Update policy [{i + 1}/{n}]", disable=(self.rank != 0)):
micro_batch.to("cuda")
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
responses = model_inputs["responses"]
response_length = responses.size(1)
attention_mask = model_inputs["attention_mask"]
response_mask = attention_mask[:, -response_length:]
old_log_prob = model_inputs["old_log_probs"]
advantages = model_inputs["advantages"]
clip_ratio = self.config.clip_ratio
entropy_coeff = self.config.entropy_coeff
# all return: (bsz, response_length)
entropy, log_prob = self._forward_micro_batch(model_inputs, temperature=temperature)
pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio,
)
# compute entropy loss from entropy
entropy_loss = verl_F.masked_mean(entropy, response_mask)
# compute policy loss
policy_loss = pg_loss - entropy_loss * entropy_coeff
if self.config.use_kl_loss:
ref_log_prob = model_inputs["ref_log_prob"]
# compute kl loss
kld = core_algos.kl_penalty(
logprob=log_prob,
ref_logprob=ref_log_prob,
kl_penalty=self.config.kl_loss_type,
)
kl_loss = masked_mean(kld, response_mask)
policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
metrics["actor/kl_loss"] = kl_loss.detach().item()
metrics["actor/kl_coef"] = self.config.kl_loss_coef
loss = policy_loss / gradient_accumulation
loss.backward()
batch_metrics = {
"actor/entropy_loss": entropy_loss.detach().item(),
"actor/pg_loss": pg_loss.detach().item(),
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
}
append_to_dict(metrics, batch_metrics)
grad_norm = self._optimizer_step()
append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})
self.actor_optimizer.zero_grad()
return metrics
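# Illustrative sketch of the logits slicing in _forward_micro_batch: the logit
# at position t predicts token t + 1, so `logits[:, -response_length - 1 : -1]`
# lines up one-to-one with `responses`.
def _example_logit_alignment():
    seq_len, response_length, vocab = 6, 3, 7
    logits = torch.randn(1, seq_len, vocab)
    response_logits = logits[:, -response_length - 1 : -1, :]
    assert response_logits.shape == (1, response_length, vocab)
    # the position right before the first response token predicts that token
    assert torch.equal(response_logits[0, 0], logits[0, seq_len - response_length - 1])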
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ActorRolloutRef config
"""
from dataclasses import dataclass, field
from verl.workers.actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
from verl.workers.critic import CriticConfig
from verl.workers.reward import RewardConfig
from verl.workers.rollout import RolloutConfig
__all__ = [
"ActorConfig",
"CriticConfig",
"FSDPConfig",
"ModelConfig",
"OptimConfig",
"RefConfig",
"RewardConfig",
"RolloutConfig",
"WorkerConfig",
]
@dataclass
class WorkerConfig:
hybrid_engine: bool = True
actor: ActorConfig = field(default_factory=ActorConfig)
critic: CriticConfig = field(default_factory=CriticConfig)
ref: RefConfig = field(default_factory=RefConfig)
reward: RewardConfig = field(default_factory=RewardConfig)
rollout: RolloutConfig = field(default_factory=RolloutConfig)
def post_init(self):
self.ref.padding_free = self.actor.padding_free
self.ref.micro_batch_size_per_device_for_experience = self.actor.micro_batch_size_per_device_for_experience
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BasePPOCritic
from .config import CriticConfig, ModelConfig
from .dp_critic import DataParallelPPOCritic
__all__ = ["BasePPOCritic", "CriticConfig", "DataParallelPPOCritic", "ModelConfig"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for Critic
"""
from abc import ABC, abstractmethod
from typing import Any, Dict
import torch
from verl import DataProto
from verl.workers.critic.config import CriticConfig
__all__ = ["BasePPOCritic"]
class BasePPOCritic(ABC):
def __init__(self, config: CriticConfig):
self.config = config
@abstractmethod
def compute_values(self, data: DataProto) -> torch.Tensor:
"""Compute values"""
pass
@abstractmethod
def update_critic(self, data: DataProto) -> Dict[str, Any]:
"""Update the critic"""
pass
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Critic config
"""
from dataclasses import dataclass, field
from verl.workers.actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig
@dataclass
class CriticConfig:
strategy: str = "fsdp"
global_batch_size: int = 256
micro_batch_size_per_device_for_update: int = field(default=-1, init=False)
micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
max_grad_norm: float = 1.0
cliprange_value: float = 0.5
padding_free: bool = False
ulysses_sequence_parallel_size: int = 1
model: ModelConfig = field(default_factory=ModelConfig)
optim: OptimConfig = field(default_factory=OptimConfig)
fsdp: FSDPConfig = field(default_factory=FSDPConfig)
offload: OffloadConfig = field(default_factory=OffloadConfig)
"""auto keys"""
global_batch_size_per_device: int = field(default=-1, init=False)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement Critic
"""
import os
from collections import defaultdict
from typing import Any, Dict
import torch
import torch.distributed
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from tqdm import tqdm
from verl import DataProto
from verl.trainer import core_algos
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from verl.workers.critic.base import BasePPOCritic
from verl.workers.critic.config import CriticConfig
__all__ = ["DataParallelPPOCritic"]
class DataParallelPPOCritic(BasePPOCritic):
def __init__(self, config: CriticConfig, critic_module: nn.Module, critic_optimizer: torch.optim.Optimizer):
super().__init__(config)
self.rank = int(os.getenv("RANK", "0"))
self.critic_module = critic_module
self.critic_optimizer = critic_optimizer
def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Tensor:
input_ids = micro_batch["input_ids"]
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
responses = micro_batch["responses"]
response_length = responses.size(-1)
if position_ids.dim() == 3: # qwen2vl mrope
position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
vision_inputs = {}
if "pixel_values" in micro_batch:
vision_inputs["pixel_values"] = torch.cat(micro_batch["pixel_values"], dim=0)
vision_inputs["image_grid_thw"] = torch.cat(micro_batch["image_grid_thw"], dim=0)
if self.config.padding_free:
# TODO (yaowei): preprocess data for padding_free and ulysses
raise NotImplementedError
else:
output = self.critic_module(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**vision_inputs,
use_cache=False,
)
values: torch.Tensor = output.logits
values = values[:, -response_length - 1 : -1].squeeze(-1)  # (bsz, response_length)
return values
def _optimizer_step(self) -> torch.Tensor:
if isinstance(self.critic_module, FSDP):
grad_norm = self.critic_module.clip_grad_norm_(self.config.max_grad_norm)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.critic_module.parameters(), max_norm=self.config.max_grad_norm
)
self.critic_optimizer.step()
return grad_norm
@torch.no_grad()
def compute_values(self, data: DataProto) -> torch.Tensor:
self.critic_module.eval()
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
if "pixel_values" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
else:
non_tensor_select_keys = None
micro_batches = data.select(select_keys, non_tensor_select_keys).split(
self.config.micro_batch_size_per_device_for_experience
)
values_lst = []
for micro_batch in tqdm(micro_batches, "Compute values", disable=(self.rank != 0)):
micro_batch.to("cuda")
values = self._forward_micro_batch(micro_batch)
values_lst.append(values)
values = torch.concat(values_lst, dim=0)
responses = data.batch["responses"]
attention_mask = data.batch["attention_mask"]
response_length = responses.size(1)
values = values * attention_mask[:, -response_length - 1 : -1]
return values
def update_critic(self, data: DataProto) -> Dict[str, Any]:
self.critic_module.train()
select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"]
if "pixel_values" in data.non_tensor_batch.keys():
non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
else:
non_tensor_select_keys = None
# TODO (yaowei): support ppo epochs
# Split to make minibatch iterator for updating the critic
# See PPO paper for details. https://arxiv.org/abs/1707.06347
mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)
metrics = defaultdict(list)
n = len(mini_batches)
for i, mini_batch in enumerate(mini_batches):
gradient_accumulation = (
self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
)
micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
self.critic_optimizer.zero_grad()
for micro_batch in tqdm(micro_batches, desc=f"Update critic [{i + 1}/{n}]", disable=(self.rank != 0)):
micro_batch.to("cuda")
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
responses = model_inputs["responses"]
attention_mask = model_inputs["attention_mask"]
values = model_inputs["values"]
returns = model_inputs["returns"]
response_length = responses.size(1)
eos_mask = attention_mask[:, -response_length - 1 : -1]
vpreds = self._forward_micro_batch(model_inputs)
vf_loss, vf_clipfrac = core_algos.compute_value_loss(
vpreds=vpreds,
values=values,
returns=returns,
eos_mask=eos_mask,
cliprange_value=self.config.cliprange_value,
)
loss = vf_loss / gradient_accumulation
loss.backward()
batch_metrics = {
"critic/vf_loss": vf_loss.detach().item(),
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(vpreds, eos_mask).detach().item(),
}
append_to_dict(metrics, batch_metrics)
grad_norm = self._optimizer_step()
append_to_dict(metrics, {"critic/grad_norm": grad_norm.detach().item()})
self.critic_optimizer.zero_grad()
return metrics
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""
from typing import Literal
import torch
import torch.distributed as dist
from accelerate import init_empty_weights
from codetiming import Timer
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
GenerationConfig,
PreTrainedModel,
)
from transformers.modeling_utils import no_init_weights
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import get_tokenizer, get_processor
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fsdp_utils import (
get_fsdp_wrap_policy,
get_init_fn,
load_fsdp_model,
load_fsdp_optimizer,
offload_fsdp_model,
offload_fsdp_optimizer,
)
from verl.utils.model_utils import print_model_size
from verl.utils.performance import log_gpu_memory_usage
from verl.utils.torch_dtypes import PrecisionType
from verl.utils.torch_functional import get_constant_schedule_with_warmup
from verl.workers.actor import DataParallelPPOActor
from verl.workers.config import FSDPConfig, ModelConfig, OptimConfig, WorkerConfig
from verl.workers.critic import DataParallelPPOCritic
from verl.workers.rollout.vllm_rollout import vLLMRollout
from verl.workers.sharding_manager import FSDPVLLMShardingManager
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
class FSDPWorker(Worker):
def __init__(
self,
config: WorkerConfig,
role: Literal["actor", "critic", "rollout", "ref", "actor_rollout", "actor_rollout_ref"],
):
super().__init__()
self.config = config
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
# build device mesh for FSDP
# TODO: support FSDP hybrid shard for larger model
world_size = dist.get_world_size()
self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
# build device mesh for Ulysses Sequence Parallel
self.ulysses_sequence_parallel_size = self.config.actor.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
"cuda",
mesh_shape=(world_size // self.ulysses_sequence_parallel_size, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "sp"],
)
else:
self.ulysses_device_mesh = None
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
self.role = role
self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
self._is_critic = self.role == "critic"
self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
self._is_ref = self.role in ["ref", "actor_rollout_ref"]
self._use_param_offload = False
self._use_optimizer_offload = False
if self._is_actor:
self._use_param_offload = self.config.actor.offload.param_offload
self._use_optimizer_offload = self.config.actor.offload.optimizer_offload
elif self._is_critic:
self._use_param_offload = self.config.critic.offload.param_offload
self._use_optimizer_offload = self.config.critic.offload.optimizer_offload
elif self._is_ref:
# NOTE: it seems that manual offload is slower than FSDP offload
self._use_param_offload = self.config.ref.offload.param_offload
# normalize config
if self._is_actor:
self.config.actor.global_batch_size *= self.config.rollout.n
self.config.actor.global_batch_size_per_device = (
self.config.actor.global_batch_size // self.device_mesh.shape[0] * self.ulysses_sequence_parallel_size
)
assert (
self.config.actor.global_batch_size_per_device
% self.config.actor.micro_batch_size_per_device_for_update
== 0
)
elif self._is_critic:
self.config.critic.global_batch_size *= self.config.rollout.n
self.config.critic.global_batch_size_per_device = (
self.config.critic.global_batch_size // self.device_mesh.shape[0] * self.ulysses_sequence_parallel_size
)
assert (
self.config.critic.global_batch_size_per_device
% self.config.critic.micro_batch_size_per_device_for_update
== 0
)
def _build_model_optimizer(
self,
model_config: ModelConfig,
fsdp_config: FSDPConfig,
optim_config: OptimConfig,
padding_free: bool = False,
) -> None:
self.tokenizer = get_tokenizer(model_config.tokenizer_path, trust_remote_code=model_config.trust_remote_code)
self.processor = get_processor(model_config.tokenizer_path)
self.model_config = AutoConfig.from_pretrained(
model_config.model_path,
trust_remote_code=model_config.trust_remote_code,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
**model_config.override_config,
)
try:
self.generation_config = GenerationConfig.from_pretrained(model_config.model_path)
except Exception:
self.generation_config = GenerationConfig.from_model_config(self.model_config)
self.print_rank0(f"Model config: {self.model_config}")
if padding_free:
raise NotImplementedError("Padding free is not implemented yet.")
if fsdp_config.torch_dtype is None:
torch_dtype = torch.float32 if self._is_actor or self._is_critic else torch.bfloat16
else:
torch_dtype = PrecisionType.to_dtype(fsdp_config.torch_dtype)
if self._is_critic:
auto_class = AutoModelForTokenClassification
elif type(self.model_config) in AutoModelForVision2Seq._model_mapping.keys():
auto_class = AutoModelForVision2Seq
else:
auto_class = AutoModelForCausalLM
if self.rank == 0:
model = auto_class.from_pretrained(
model_config.model_path,
config=self.model_config,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
device_map="cpu",
low_cpu_mem_usage=True,
trust_remote_code=model_config.trust_remote_code,
)
else:
with no_init_weights(), init_empty_weights():
model = auto_class.from_config(
self.model_config,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
trust_remote_code=model_config.trust_remote_code,
)
assert isinstance(model, PreTrainedModel) # lint
model.tie_weights() # avoid hanging
model = model.to(torch_dtype)
if model_config.enable_gradient_checkpointing:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
dist.barrier()
if self.rank == 0:
print_model_size(model)
log_gpu_memory_usage("After init from huggingface model")
mixed_precision = MixedPrecision(
param_dtype=PrecisionType.to_dtype(fsdp_config.mp_param_dtype),
reduce_dtype=PrecisionType.to_dtype(fsdp_config.mp_reduce_dtype),
buffer_dtype=PrecisionType.to_dtype(fsdp_config.mp_buffer_dtype),
)
auto_wrap_policy = get_fsdp_wrap_policy(model)
if fsdp_config.enable_full_shard:
sharding_strategy = ShardingStrategy.FULL_SHARD
else:
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
if fsdp_config.param_offload or fsdp_config.optimizer_offload:
cpu_offload = CPUOffload(offload_params=fsdp_config.param_offload)
else:
cpu_offload = None
if self.rank == 0:
print(f"FSDP wrap policy: {auto_wrap_policy}.")
self.fsdp_module = FSDP(
model,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision,
param_init_fn=get_init_fn(model, device="cuda") if self.rank != 0 else None,
device_id=torch.cuda.current_device(),
sync_module_states=True,
forward_prefetch=False,
use_orig_params=False,
device_mesh=self.device_mesh,
)
log_gpu_memory_usage("After Actor FSDP init")
if self._is_actor or self._is_critic:
self.optimizer = torch.optim.AdamW(
self.fsdp_module.parameters(),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
)
num_warmup_steps = int(optim_config.lr_warmup_steps_ratio * optim_config.training_steps)
self.lr_scheduler = get_constant_schedule_with_warmup(
optimizer=self.optimizer, num_warmup_steps=num_warmup_steps
)
else:
self.optimizer, self.lr_scheduler = None, None
log_gpu_memory_usage("After actor optimizer init")
def _build_rollout(self) -> None:
# TODO(sgm): support FSDP hybrid shard for larger model
tp_size = self.config.rollout.tensor_parallel_size
dp_size = self.world_size // tp_size
assert self.world_size % tp_size == 0, (
f"rollout world_size: {self.world_size} is not divisible by tp_size: {tp_size}"
)
rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=["dp", "tp"])
log_gpu_memory_usage("Before building vllm rollout")
self.rollout = vLLMRollout(
model_path=self.config.actor.model.model_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
)
log_gpu_memory_usage("After building vllm rollout")
self.rollout_sharding_manager = FSDPVLLMShardingManager(
module=self.fsdp_module,
inference_engine=self.rollout.inference_engine,
device_mesh=rollout_device_mesh,
)
log_gpu_memory_usage("After building sharding manager")
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
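        """Build the model, optimizer, rollout engine, and checkpoint manager for this worker's roles."""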
if self._is_critic:
model_config = self.config.critic.model
fsdp_config = self.config.critic.fsdp
optim_config = self.config.critic.optim
padding_free = self.config.critic.padding_free
else:
model_config = self.config.actor.model
fsdp_config = self.config.actor.fsdp
optim_config = self.config.actor.optim
padding_free = self.config.actor.padding_free
if self._is_actor or self._is_critic or self._is_ref:
self._build_model_optimizer(
model_config=model_config,
fsdp_config=fsdp_config,
optim_config=optim_config,
padding_free=padding_free,
)
# get the original unwrapped module
self.unwrapped_model = self.fsdp_module._fsdp_wrapped_module
if self._use_optimizer_offload and not self._is_critic:
offload_fsdp_optimizer(optimizer=self.optimizer)
log_gpu_memory_usage("After offload actor optimizer during init")
if self._is_actor:
self.actor = DataParallelPPOActor(
config=self.config.actor,
actor_module=self.fsdp_module,
actor_optimizer=self.optimizer,
)
if self._is_critic:
self.critic = DataParallelPPOCritic(
config=self.config,
critic_module=self.fsdp_module,
critic_optimizer=self.optimizer,
)
if self._is_rollout:
self._build_rollout()
if self._is_ref:
self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.fsdp_module)
if self._is_actor or self._is_critic:
self.flops_counter = FlopsCounter(self.model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.fsdp_module,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
tokenizer=self.tokenizer,
processor=self.processor
)
torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, path: str, global_step: int = 0, remove_previous_ckpt: bool = False):
assert self._is_actor or self._is_critic
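        # Bring offloaded parameters back to GPU so full state dicts can be gathered.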
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
self.checkpoint_manager.save_checkpoint(
local_path=path,
global_step=global_step,
remove_previous_ckpt=remove_previous_ckpt,
)
dist.barrier()
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, path: str, del_local_after_load: bool = True):
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load)
dist.barrier()
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
"""ActorRolloutRefWorker"""
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
assert self._is_actor
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
if self._use_optimizer_offload:
load_fsdp_optimizer(optimizer=self.optimizer)
log_gpu_memory_usage("Before update policy")
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
with Timer(name="update_policy", logger=None) as timer:
metrics = self.actor.update_policy(data=data)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
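            # MFU: achieved FLOPs (scaled by ppo_epochs, since each batch is reused
            # once per epoch) relative to the hardware's promised FLOPs across ranks.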
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
self.lr_scheduler.step()
lr = self.lr_scheduler.get_last_lr()[0]
metrics["actor/lr"] = lr
log_gpu_memory_usage("After update policy")
# TODO: here, we should return all metrics
output = DataProto(meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
if self._use_optimizer_offload:
offload_fsdp_optimizer(optimizer=self.optimizer)
torch.cuda.empty_cache()
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
assert self._is_rollout
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
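        # Fall back to the tokenizer's special tokens when the model ships no generation config.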
meta_info = {
"eos_token_id": self.generation_config.eos_token_id
if self.generation_config is not None
else self.tokenizer.eos_token_id,
"pad_token_id": self.generation_config.pad_token_id
if self.generation_config is not None
else self.tokenizer.pad_token_id,
}
prompts.meta_info.update(meta_info)
with self.rollout_sharding_manager:
# after parameters sync with rollout, offload actor model to CPU
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
if self._use_optimizer_offload:
offload_fsdp_optimizer(optimizer=self.optimizer)
log_gpu_memory_usage("After entering rollout sharding manager")
prompts = self.rollout_sharding_manager.preprocess_data(prompts)
output = self.rollout.generate_sequences(prompts=prompts)
log_gpu_memory_usage("After rollout generation")
output = self.rollout_sharding_manager.postprocess_data(output)
output = output.to("cpu")
torch.cuda.empty_cache() # clear kv cache
log_gpu_memory_usage("After recompute log prob")
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
assert self._is_actor
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info["temperature"] = self.config.rollout.temperature
# perform recompute log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.actor.compute_log_prob(data=data)
output = DataProto.from_dict(
tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
)
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.fsdp_module._handle.reshard(True)
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
torch.cuda.empty_cache()
log_gpu_memory_usage("After compute_log_prob")
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
assert self._is_ref
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
data.meta_info["temperature"] = self.config.rollout.temperature
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.ref_policy.compute_log_prob(data=data)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.fsdp_module._handle.reshard(True)
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
torch.cuda.empty_cache()
log_gpu_memory_usage("After compute_ref_log_prob")
return output
"""CriticWorker"""
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
assert self._is_critic
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
values = self.critic.compute_values(data=data)
output = DataProto.from_dict(tensors={"values": values})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
torch.cuda.empty_cache()
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
if self._use_param_offload:
load_fsdp_model(self.fsdp_module)
if self._use_optimizer_offload:
load_fsdp_optimizer(optimizer=self.optimizer)
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
with Timer(name="update_critic", logger=None) as timer:
metrics = self.critic.update_critic(data=data)
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["mfu/critic"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
self.lr_scheduler.step()
lr = self.lr_scheduler.get_last_lr()[0]
metrics["critic/lr"] = lr
output = DataProto(batch=None, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)
if self._use_optimizer_offload:
offload_fsdp_optimizer(optimizer=self.optimizer)
torch.cuda.empty_cache()
return output
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import RewardConfig
from .custom import CustomRewardManager
__all__ = ["CustomRewardManager", "RewardConfig"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reward config
"""
from dataclasses import dataclass
@dataclass
class RewardConfig:
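    # reward_type: reward source; "function" appears to mean a rule-based scoring function.
    # compute_score: which scorer to apply ("math" or "r1v"; see CustomRewardManager).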
reward_type: str = "function"
compute_score: str = "math"
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import PreTrainedTokenizer
from verl import DataProto
from verl.utils.reward_score import math_compute_score, r1v_compute_score
class CustomRewardManager:
def __init__(self, tokenizer: PreTrainedTokenizer, num_examine: int, compute_score: str):
self.tokenizer = tokenizer
self.num_examine = num_examine
if compute_score == "math":
self.compute_score = math_compute_score
elif compute_score == "r1v":
self.compute_score = r1v_compute_score
else:
raise NotImplementedError()
def __call__(self, data: DataProto) -> torch.Tensor:
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
already_print = 0
for i in range(len(data)):
data_item = data[i] # DataProtoItem
prompt_ids = data_item.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
valid_prompt_ids = prompt_ids[-valid_prompt_length:]
response_ids = data_item.batch["responses"]
valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
ground_truth = data_item.non_tensor_batch["answer"]
score = self.compute_score(response_str, ground_truth)
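            # Sparse reward: place the scalar score on the last valid response token;
            # every other position stays zero.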
reward_tensor[i, valid_response_length - 1] = score
if already_print < self.num_examine:
already_print += 1
print("[prompt]", prompt_str)
print("[response]", response_str)
print("[ground_truth]", ground_truth)
print("[score]", score)
return reward_tensor
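
# Minimal usage sketch (illustrative only; `tokenizer` and `batch` are assumed to
# come from the trainer):
#
#   reward_fn = CustomRewardManager(tokenizer=tokenizer, num_examine=1, compute_score="math")
#   rewards = reward_fn(batch)  # (batch_size, response_length); nonzero only at final response tokens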
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import RolloutConfig
__all__ = ["RolloutConfig"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from verl import DataProto
__all__ = ["BaseRollout"]
class BaseRollout(ABC):
@abstractmethod
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate responses for a batch of prompts and return them as a DataProto."""
        pass
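
# A concrete rollout only has to implement generate_sequences. A hypothetical
# minimal subclass, for illustration only:
#
#   class EchoRollout(BaseRollout):
#       def generate_sequences(self, prompts: DataProto) -> DataProto:
#           return prompts  # a real engine (e.g. vLLM) would decode new tokens here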