Commit 9bce6a82 authored by mashun1

huatuogpt-o1

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import math
import os
import textwrap
import time
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Tuple, Union
import random
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import broadcast, gather_object
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
DataCollatorWithPadding,
FeatureExtractionMixin,
GenerationConfig,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
TrainerControl,
is_wandb_available,
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_peft_available
from transformers.utils.deprecation import deprecate_kwarg
from trl.core import masked_mean, masked_whiten
from trl.models import create_reference_model
from trl.models.utils import unwrap_model_for_generation
from trl.trainer.utils import (
OnlineTrainerState,
batch_generation,
disable_dropout_in_model,
exact_div,
first_true_indices,
forward,
get_reward,
prepare_deepspeed,
print_rich_table,
truncate_response,
)
from trl.trainer.ppo_config import PPOConfig
from trl.trainer.utils import generate_model_card, peft_module_casting_to_bf16
if is_peft_available():
from peft import PeftConfig, PeftModel, get_peft_model
if is_wandb_available():
import wandb
INVALID_LOGPROB = 1.0
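# Placeholder value written into log-probs at padded positions; these entries are
# masked out (via `padding_mask`) before any loss or KL computation uses them.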
# Running per-batch reward statistics for o1-style reward logging
accumulate_rewards = []
# Verifier-based reward used in place of TRL's `get_reward` for scoring responses
def get_reward_o1(
model, response_ids, tokenizer, reward_tokenizer, pad_token_id, sub_answer, max_length=4000
) -> torch.Tensor:
tmp = """<Model Response>
{}
</Model Response>
<Reference Answer>
{}
</Reference Answer>
Your task is to evaluate the model response by comparing it to the reference answer. If the model response is correct and aligns with the reference answer, output "True" . If it is incorrect or fails to select the correct option (if options are provided), output "False" . {}"""
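# Only responses containing exactly one '## Thinking' section and one '## Final Response'
# section are extracted; the text after '## Final Response' is judged by the verifier model
# against the reference answer. Anything else falls back to a dummy answer and gets zero reward.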
output_pattern = r"## Final Response\n\n(.*)"
processed_batch = []
output_matches = []
for i in range(len(sub_answer)):
response = tokenizer.decode(response_ids[i], skip_special_tokens=True)
count_en = response.count('## Final Response\n\n')
count_thinking_en = response.count('## Thinking')
if '## Final Response\n\n' in response and count_en == 1 and count_thinking_en == 1:
output_match = re.search(output_pattern, response, re.S)
else:
output_match = None
output_matches.append(output_match)
if output_match is None:
response = 'I do not know the answer.'
else:
response = output_match.group(1).strip()
format_response = tmp.format(response, sub_answer[i], reward_tokenizer.eos_token)
processed_batch.append(format_response)
input_batch = reward_tokenizer(processed_batch, return_tensors="pt", add_special_tokens=False, max_length=max_length, padding=True, truncation=True).to(model.device)
with torch.no_grad():
logits = model(**input_batch,return_dict=True).logits
probabilities = F.softmax(logits, dim=-1)
rewards = probabilities[:, 1] * 10
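# Note: the scaled probability above is superseded below by a discrete reward scheme:
# 0.0 for malformed outputs, 1.0 when the verifier's class-1 ("correct") probability
# exceeds 0.4, and 0.1 otherwise.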
rewards_list = []
for i in range(len(sub_answer)):
if output_matches[i] is None:
rewards_list.append(0.0)
else:
p = probabilities[i, 1].item()
if p > 0.4:
rewards_list.append(1.0)
else:
rewards_list.append(0.1)
rewards = torch.tensor(rewards_list, device=probabilities.device, dtype=probabilities.dtype)
# Update global reward statistics
global accumulate_rewards
accumulate_rewards.append(rewards.sum().item() / len(processed_batch))
# Debugging rewards
if random.random() < 0.1:
for ii in range(len(processed_batch)):
print('[reward_input]', processed_batch[ii], flush=True)
print('[reward]', rewards[ii].item(), '\n', flush=True)
print('-----------[avg_rewards]----------', sum(accumulate_rewards[-50:]) / len(accumulate_rewards[-50:]), '\n', flush=True)
return rewards
# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29
# we did this so we can do a single `model = accelerator.prepare(model)`
class PolicyAndValueWrapper(nn.Module):
def __init__(self, policy, value_model) -> None:
super().__init__()
self.policy = policy
self.value_model = value_model
self.critic_backbone = getattr(value_model, value_model.base_model_prefix)
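# `critic_backbone` is the value model's transformer trunk (e.g. `model` for Llama-style
# architectures); its last hidden states feed the value head (`score`), so policy and value
# model can be prepared together with a single `accelerator.prepare(model)` call.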
def forward(self, **kwargs):
output = self.critic_backbone(
**kwargs,
)
logits = self.value_model.score(output.hidden_states[-1])
return self.policy(**kwargs), logits
class PPOTrainer(Trainer):
_tag_names = ["trl", "ppo"]
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
def __init__(
self,
config: PPOConfig,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
],
reward_processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
],
policy: nn.Module,
ref_policy: Optional[nn.Module],
reward_model: nn.Module,
train_dataset: Dataset,
value_model: Optional[nn.Module] = None,
data_collator: Optional[DataCollatorWithPadding] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
# less commonly used
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
callbacks: Optional[List[TrainerCallback]] = None,
peft_config: Optional["PeftConfig"] = None,
) -> None:
if ref_policy is policy:
raise ValueError(
"`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
"same as `policy`, you must make a copy of it, or `None` if you use peft."
)
self.args = config
args = config
self.processing_class = processing_class
self.reward_processing_class = reward_processing_class
self.policy = policy
# Define the collator if not provided
if data_collator is None:
data_collator = DataCollatorWithPadding(self.processing_class)
self.policy.generation_config.eos_token_id = (
None # disable `pad_token_id` and `eos_token_id` because we just want to
)
self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding
# peft support
if not is_peft_available() and peft_config is not None:
raise ImportError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
# if model is a peft model and we have a peft_config, we merge and unload it first
if isinstance(self.policy, PeftModel):
self.policy = self.policy.merge_and_unload()
# get peft model with the given config
self.policy = get_peft_model(self.policy, peft_config)
if args.bf16 and getattr(self.policy, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(self.policy)
self.is_peft_model = is_peft_available() and isinstance(self.policy, PeftModel)
self.model_adapter_name = args.model_adapter_name
self.ref_adapter_name = args.ref_adapter_name
if ref_policy:
self.ref_policy = ref_policy
elif self.is_peft_model:
self.ref_policy = None
else:
self.ref_policy = create_reference_model(self.policy)
self.reward_model = reward_model
self.train_dataset = train_dataset
self.train_dataset_len = len(train_dataset)
self.value_model = value_model
self.data_collator = data_collator
self.eval_dataset = eval_dataset
self.optimizer, self.lr_scheduler = optimizers
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
#########
# calculate various batch sizes
#########
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
self.accelerator = accelerator
args.world_size = accelerator.num_processes
args.local_batch_size = (
args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
)
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
args.batch_size = int(args.local_batch_size * args.world_size)
args.mini_batch_size = exact_div(
args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
)
args.local_mini_batch_size = exact_div(
args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
)
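# Batch-size hierarchy: local_batch_size = per_device_train_batch_size * gradient_accumulation_steps
# * num_mini_batches (one rollout per rank); batch_size = local_batch_size * world_size (one rollout
# across all ranks); each rollout is then split into `num_mini_batches` mini-batches for PPO updates.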
if args.whiten_rewards:
assert (
args.local_mini_batch_size >= 8
), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
# `per_rank_rollout_batch_size` is our `args.local_batch_size`
# `per_rank_minibatch_size` is our `args.local_mini_batch_size`
args.num_total_batches = math.ceil(
args.total_episodes / args.batch_size
) # we may train for more than `total_episodes`
time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
if args.num_sample_generations > 0:
self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
self.local_dataloader_batch_size = args.local_batch_size
#########
# setup model, optimizer, and others
#########
for module in [self.policy, self.ref_policy, self.value_model, self.reward_model]:
if module is not None:
disable_dropout_in_model(module)
if args.stop_token and args.stop_token == "eos":
args.stop_token_id = processing_class.eos_token_id
self.model = PolicyAndValueWrapper(self.policy, self.value_model)
self.model.config = self.policy.config # needed for pushing to hub
self.create_optimizer_and_scheduler(
num_training_steps=args.num_total_batches
) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
#########
### trainer specifics
#########
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
self.control = TrainerControl()
self.state = OnlineTrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
],
)
self.current_flos = 0
self.hp_search_backend = None
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
# Create distant repo and output directory if needed
self.hub_model_id = None
if self.args.push_to_hub:
self.init_hf_repo()
if self.args.should_save:
os.makedirs(self.args.output_dir, exist_ok=True)
# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
#########
### setup dataloader
#########
self.dataloader = DataLoader(
self.train_dataset,
batch_size=self.local_dataloader_batch_size,
shuffle=True,
collate_fn=self.data_collator,
drop_last=True, # needed; otherwise the last batch will be of ragged shape
)
# sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
# see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
torch.manual_seed(args.seed)
self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
torch.manual_seed(self.local_seed) # reset the local seed again
self.eval_dataloader = DataLoader(
self.eval_dataset,
batch_size=args.per_device_eval_batch_size,
collate_fn=self.data_collator,
drop_last=True,
) # no need to shuffle eval dataset
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
if self.is_deepspeed_enabled:
self.reward_model = prepare_deepspeed(
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
)
if self.ref_policy is None:
if not self.is_peft_model:
raise ValueError("No reference model and model is not a Peft model.")
else:
self.ref_policy = prepare_deepspeed(
self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
)
else:
print("DeepSpeed is not enabled; moving reference and reward models to the accelerator device.", flush=True)
if self.ref_policy is None:
if not self.is_peft_model:
raise ValueError("No reference model and model is not a Peft model.")
else:
self.ref_policy = self.ref_policy.to(self.accelerator.device)
self.reward_model = self.reward_model.to(self.accelerator.device)
def get_train_dataloader(self) -> DataLoader:
return self.dataloader
def get_eval_dataloader(self) -> DataLoader:
return self.eval_dataloader
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model.policy
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
if self.ref_adapter_name:
self.model.policy.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.policy.set_adapter(self.model_adapter_name or "default")
# Fix save_model: save only the policy weights instead of the full PolicyAndValueWrapper
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
backup_model = self.model
self.model = self.model.policy # save only the policy
Trainer.save_model(self, output_dir, _internal_call)
self.model = backup_model
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if self.is_deepspeed_enabled:
state_dict = {name.removeprefix('policy.'): param for name, param in state_dict.items()
if name.startswith('policy.')}
super()._save(output_dir, state_dict)
def train(self):
args = self.args
accelerator = self.accelerator
optimizer = self.optimizer
model = self.model
ref_policy = self.ref_policy
reward_model = self.reward_model
processing_class = self.processing_class
dataloader = self.dataloader
device = accelerator.device
def repeat_generator():
while True:
yield from dataloader
iter_dataloader = iter(repeat_generator())
generation_config = GenerationConfig(
max_new_tokens=args.response_length,
temperature=(args.temperature + 1e-7),
top_k=0.0,
top_p=1.0,
do_sample=True,
)
accelerator.print("===training policy===")
start_time = time.time()
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
approxkl_stats = torch.zeros(stats_shape, device=device)
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
pg_loss_stats = torch.zeros(stats_shape, device=device)
vf_loss_stats = torch.zeros(stats_shape, device=device)
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
entropy_stats = torch.zeros(stats_shape, device=device)
ratio_stats = torch.zeros(stats_shape, device=device)
model.train()
# trainer state initialization
self.state.global_step = 0
self.state.episode = 0
self.state.max_steps = args.num_total_batches * args.num_mini_batches
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model
self.model_wrapped = self.model
for update in range(1, args.num_total_batches + 1):
self.state.episode += 1 * args.batch_size
data = next(iter_dataloader)
with torch.no_grad():
queries = data["input_ids"].to(device)
allanswer = data["answer"]
context_length = queries.shape[1]
responses = []
postprocessed_responses = []
logprobs = []
ref_logprobs = []
scores = []
sequence_lengths = []
values = []
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
query_responses, logitss = batch_generation(
unwrapped_model.policy,
queries,
args.local_rollout_forward_batch_size,
processing_class.pad_token_id,
generation_config,
)
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
query = queries[i : i + args.local_rollout_forward_batch_size]
sub_answer = allanswer[i : i + args.local_rollout_forward_batch_size]
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
logits = logitss[i : i + args.local_rollout_forward_batch_size]
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del logits, all_logprob
torch.cuda.empty_cache()
if ref_policy is None:
with self.null_ref_context():
ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
else:
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
torch.cuda.empty_cache()
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, processing_class.pad_token_id, response
)
# Response Processing 2. run reward model on the truncated responses
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
unwrapped_value_model = accelerator.unwrap_model(model).value_model
full_value, _, _ = get_reward(
unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
)
value = full_value[:, context_length - 1 : -1].squeeze(-1)
score = get_reward_o1(
reward_model, postprocessed_response, processing_class, self.reward_processing_class, processing_class.pad_token_id, sub_answer
)
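# `score` is a single scalar per response from the verifier-based reward (`get_reward_o1`),
# while `value` holds per-token value estimates from the shared value head over the
# generated tokens.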
responses.append(response)
postprocessed_responses.append(postprocessed_response)
logprobs.append(logprob)
ref_logprobs.append(ref_logprob)
sequence_lengths.append(sequence_length)
scores.append(score)
values.append(value)
responses = torch.cat(responses, 0)
postprocessed_responses = torch.cat(postprocessed_responses, 0)
logprobs = torch.cat(logprobs, 0)
ref_logprobs = torch.cat(ref_logprobs, 0)
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
values = torch.cat(values, 0)
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
torch.cuda.empty_cache()
gc.collect()
# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
# Completions not passing that filter will receive a lower score.
contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
if self.args.missing_eos_penalty is not None:
scores[~contain_eos_token] -= self.args.missing_eos_penalty
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
sequence_lengths_p1 = sequence_lengths + 1
padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
values = torch.masked_fill(values, padding_mask_p1, 0)
# 4. compute rewards
kl = logprobs - ref_logprobs
non_score_reward = -args.kl_coef * kl
rewards = non_score_reward.clone()
actual_start = torch.arange(rewards.size(0), device=rewards.device)
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
rewards[[actual_start, actual_end]] += scores
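# Per-token reward is the KL penalty (-kl_coef * KL to the reference policy); the scalar
# verifier score is added at a single position at the end of each response (the token after
# the last non-padding token, clamped to the sequence boundary).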
# 5. whiten rewards
if args.whiten_rewards:
rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
# 6. compute advantages and returns
lastgaelam = 0
advantages_reversed = []
gen_length = responses.shape[1]
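# Generalized Advantage Estimation (GAE), computed backwards over the response:
# delta_t = r_t + gamma * V_{t+1} - V_t and A_t = delta_t + gamma * lam * A_{t+1},
# with returns = advantages + values.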
for t in reversed(range(gen_length)):
nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
lastgaelam = delta + args.gamma * args.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values
advantages = masked_whiten(advantages, ~padding_mask)
advantages = torch.masked_fill(advantages, padding_mask, 0)
torch.cuda.empty_cache()
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
for ppo_epoch_idx in range(args.num_ppo_epochs):
b_inds = np.random.permutation(args.local_batch_size)
minibatch_idx = 0
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
mini_batch_end = mini_batch_start + args.local_mini_batch_size
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
gradient_accumulation_idx = 0
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
with accelerator.accumulate(model):
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
mb_advantage = advantages[micro_batch_inds]
mb_responses = responses[micro_batch_inds]
mb_query_responses = query_responses[micro_batch_inds]
mb_logprobs = logprobs[micro_batch_inds]
mb_return = returns[micro_batch_inds]
mb_values = values[micro_batch_inds]
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)
vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
vpredclipped = torch.clamp(
vpred,
mb_values - args.cliprange_value,
mb_values + args.cliprange_value,
)
vf_losses1 = torch.square(vpred - mb_return)
vf_losses2 = torch.square(vpredclipped - mb_return)
vf_loss_max = torch.max(vf_losses1, vf_losses2)
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
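# Value loss is clipped around the rollout-time value predictions (cliprange_value)
# and takes the elementwise max of the clipped/unclipped squared errors.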
vf_clipfrac = masked_mean(
(vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
)
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
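# Standard PPO clipped surrogate: the elementwise max of the unclipped and clipped
# policy losses gives a pessimistic bound, averaged over non-padding response tokens.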
loss = pg_loss + args.vf_coef * vf_loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
with torch.no_grad():
pg_clipfrac = masked_mean(
(pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
)
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
approxkl = 0.5 * (logprobs_diff**2).mean()
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
pg_clipfrac
)
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
vf_clipfrac
)
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
gradient_accumulation_idx += 1
minibatch_idx += 1
# del everything and empty cache
# fmt: off
del (
output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped,
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on
torch.cuda.empty_cache()
with torch.no_grad():
mean_kl = kl.sum(1).mean()
mean_entropy = (-logprobs).sum(1).mean()
mean_non_score_reward = non_score_reward.sum(1).mean()
rlhf_reward = mean_non_score_reward + scores.mean()
eps = int(self.state.episode / (time.time() - start_time))
metrics = {}
metrics["eps"] = eps
metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item()
metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item()
metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item()
metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item()
metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item()
metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item()
metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item()
metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item()
metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item()
metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item()
metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item()
metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item()
metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item()
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
metrics["episode"] = self.state.episode
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
self.state.global_step += 1
self.log(metrics)
self.lr_scheduler.step()
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
torch.cuda.empty_cache()
gc.collect()
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
self.generate_completions(sampling=True)
torch.cuda.empty_cache()
del (
query_responses,
responses,
postprocessed_responses,
logprobs,
ref_logprobs,
values,
sequence_lengths,
contain_eos_token,
sequence_lengths_p1,
response_idxs,
padding_mask,
padding_mask_p1,
rewards,
actual_start,
actual_end,
advantages,
returns,
)
torch.cuda.empty_cache()
# HF trainer specifics
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
if self.control.should_save:
self._save_checkpoint(model, trial=None, metrics=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
def generate_completions(self, sampling: bool = False):
args = self.args
processing_class = self.processing_class
generation_config = GenerationConfig(
max_new_tokens=self.args.response_length,
temperature=(0.01 + 1e-7),
top_k=0.0,
top_p=1.0,
do_sample=True,
)
table = defaultdict(list)
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
for batch in self.eval_dataloader:
query = batch["input_ids"]
with torch.no_grad():
context_length = query.shape[1]
query_response, _ = batch_generation(
unwrapped_model.policy,
query,
query.shape[0],
processing_class.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, processing_class.pad_token_id, response
)
table["query"].extend(
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
)
table["model response"].extend(
gather_object(processing_class.batch_decode(postprocessed_response))
)
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
_, score, _ = get_reward(
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
)
table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())
if sampling:
break
df = pd.DataFrame(table)
if self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, List[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*, defaults to `None`):
The name of the model.
dataset_name (`str`, *optional*, defaults to `None`):
The name of the dataset used for training.
tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
tags = tags or []
if isinstance(tags, str):
tags = [tags]
if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")
citation = textwrap.dedent("""\
@article{mziegler2019fine-tuning,
title = {{Fine-Tuning Language Models from Human Preferences}},
author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
year = 2019,
eprint = {arXiv:1909.08593}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
trainer_name="PPO",
trainer_citation=citation,
paper_title="Fine-Tuning Language Models from Human Preferences",
paper_id="1909.08593",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))
trl==0.13.0
accelerate==0.34.2
# torch==2.5.1
transformers==4.46.2
# deepspeed==0.15.4
# xformers==0.0.28
# vllm==0.6.4
# torchvision==0.20.1
retrying
\ No newline at end of file
# %%
import os
import random
import json
from tqdm import tqdm
import multiprocessing
from multiprocessing import Pool
from concurrent.futures import ThreadPoolExecutor
import random
import requests
from retrying import retry
import argparse
import re
import traceback
import copy
class GPT:
def __init__(self, model_name, api_url, api_key):
self.model_name = model_name
self.api_url = api_url
self.api_key = api_key
print(f"Using model: {self.model_name}")
def call(self, content, additional_args={}):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
payload = {
"model": self.model_name,
"messages": [{'role': 'user', 'content': content}],
**additional_args,
}
response = requests.post(self.api_url, headers=headers, json=payload)
response_data = response.json()
if 'error' in response_data:
raise ValueError(f"API Error: {response_data}")
return response_data['choices'][0]['message']['content']
@retry(wait_fixed=3000, stop_max_attempt_number=3)
def retry_call(self, content, additional_args={"max_tokens": 8192}):
return self.call(content, additional_args)
verify_prompt = """<Model Response>
{}
</Model Response>
<Reference Answer>
{}
</Reference Answer>
You are provided with a model-generated response (<Model Response>) and a reference answer (<Reference Answer>). Compare the model response with the reference answer and determine its correctness. Your task is to simply output "True" if the response is correct, and "False" otherwise."""
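# This verification prompt mirrors the evaluation template used by `get_reward_o1` in the
# PPO trainer: the judge sees the model response and the reference answer and outputs only
# "True" or "False".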
query_prompt_init = """<question>
{}
</question>
Please respond to the above question <question> using the Chain of Thought (CoT) reasoning method. Your response should consist of multiple steps, each of which includes three types of actions: **"Inner Thinking"**, **"Final Conclusion"**, and **"Verification"**:
- **'Inner Thinking'**: This is the step where thinking is done. Note that multiple 'Inner Thinking' steps are required to describe thorough reasoning. Each step should first generate a brief title.
- **'Final Conclusion'**: At this stage, you summarize the correct reasoning from previous 'Inner Thinking' steps and provide the final answer. No title is required here.
- **'Verification'**: At this stage, you verify the conclusion from the "Final Conclusion" step. If the conclusion holds, end the process. If not, return to "Inner Thinking" for further reasoning. No title is required here.
The output format must strictly follow the JSON structure below:
```json
{{
"CoT": [
{{"action": "Inner Thinking", "title": "...", "content": "..."}},
...,
{{"action": "Final Conclusion", "content": "..."}},
{{"action": "Verification", "content": "..."}}
]
}}
```"""
gen_prompt_rethink_Backtracking = """<question>
{}
</question>
<previous reasoning>
{}
</previous reasoning>
<response requirements>
Your response must include the following steps, each composed of three types of actions: **"Inner Thinking"**, **"Final Conclusion"**, and **"Verification"**:
1. **Inner Thinking**: Break down the reasoning process into multiple concise steps. Each step should start with a brief title to clarify its purpose.
2. **Final Conclusion**: Summarize the correct reasoning from all previous 'Inner Thinking' steps and provide the final answer. No title is needed for this section.
3. **Verification**: Verify the accuracy of the "Final Conclusion". If it holds, conclude the process. Otherwise, return to "Inner Thinking" for further refinement.
</response requirements>
<question> represents the question to be answered, and <previous reasoning> contains your prior reasoning. Your task is to continue from the current 'Verification' step. I have manually reviewed the reasoning and determined that the **Final Conclusion** is false. Your 'Verification' results must align with mine. Proceed to refine the reasoning using **backtracking** to revisit earlier points of reasoning and construct a new Final Conclusion.
### Output Format
Strictly follow the JSON structure below. You do not need to repeat your previous reasoning. Begin directly from the next 'Verification' stage.
```json
{{
"CoT": [
{{"action": "Verification", "content": "..."}},
{{"action": "Inner Thinking", "title": "...", "content": "..."}},
...,
{{"action": "Final Conclusion", "content": "..."}},
{{"action": "Verification", "content": "..."}}
]
}}
```"""
gen_prompt_rethink_Exploring_New_Path = """<question>
{}
</question>
<previous reasoning>
{}
</previous reasoning>
<response requirements>
Your response must include the following steps, each composed of three types of actions: **"Inner Thinking"**, **"Final Conclusion"**, and **"Verification"**:
1. **Inner Thinking**: Break down the reasoning process into multiple concise steps. Each step should start with a brief title to clarify its purpose.
2. **Final Conclusion**: Summarize the correct reasoning from all previous 'Inner Thinking' steps and provide the final answer. No title is needed for this section.
3. **Verification**: Verify the accuracy of the "Final Conclusion". If it holds, conclude the process. Otherwise, return to "Inner Thinking" for further refinement.
</response requirements>
<question> represents the question to be answered, and <previous reasoning> contains your prior reasoning. Your task is to continue from the current 'Verification' step. I have manually reviewed the reasoning and determined that the **Final Conclusion** is false. Your 'Verification' results must align with mine. Proceed to refine the reasoning by exploring new approaches to solving this problem and construct a new Final Conclusion.
### Output Format
Strictly follow the JSON structure below. You do not need to repeat your previous reasoning. Begin directly from the next 'Verification' stage.
```json
{{
"CoT": [
{{"action": "Verification", "content": "..."}},
{{"action": "Inner Thinking", "title": "...", "content": "..."}},
...,
{{"action": "Final Conclusion", "content": "..."}},
{{"action": "Verification", "content": "..."}}
]
}}
```"""
gen_prompt_rethink_Verification = """<question>
{}
</question>
<previous reasoning>
{}
</previous reasoning>
<response requirements>
Your response must include the following steps, each composed of three types of actions: **"Inner Thinking"**, **"Final Conclusion"**, and **"Verification"**:
1. **Inner Thinking**: Break down the reasoning process into multiple concise steps. Each step should start with a brief title to clarify its purpose.
2. **Final Conclusion**: Summarize the correct reasoning from all previous 'Inner Thinking' steps and provide the final answer. No title is needed for this section.
3. **Verification**: Verify the accuracy of the "Final Conclusion". If it holds, conclude the process. Otherwise, return to "Inner Thinking" for further refinement.
</response requirements>
<question> represents the question to be answered, and <previous reasoning> contains your prior reasoning. Your task is to continue from the current 'Verification' step. I have manually reviewed the reasoning and determined that the **Final Conclusion** is false. Your 'Verification' results must align with mine. Proceed to refine the reasoning by conducting a thorough **validation** process to ensure validity and construct a new Final Conclusion.
### Output Format
Strictly follow the JSON structure below. You do not need to repeat your previous reasoning. Begin directly from the next 'Verification' stage.
```json
{{
"CoT": [
{{"action": "Verification", "content": "..."}},
{{"action": "Inner Thinking", "title": "...", "content": "..."}},
...,
{{"action": "Final Conclusion", "content": "..."}},
{{"action": "Verification", "content": "..."}}
]
}}
```"""
gen_prompt_rethink_Correction = """<question>
{}
</question>
<previous reasoning>
{}
</previous reasoning>
<response requirements>
Your response must include the following steps, each composed of three types of actions: **"Inner Thinking"**, **"Final Conclusion"**, and **"Verification"**:
1. **Inner Thinking**: Break down the reasoning process into multiple concise steps. Each step should start with a brief title to clarify its purpose.
2. **Final Conclusion**: Summarize the correct reasoning from all previous 'Inner Thinking' steps and provide the final answer. No title is needed for this section.
3. **Verification**: Verify the accuracy of the "Final Conclusion". If it holds, conclude the process. Otherwise, return to "Inner Thinking" for further refinement.
</response requirements>
<question> represents the question to be answered, and <previous reasoning> contains your prior reasoning. Your task is to continue from the current 'Verification' step. I have manually reviewed the reasoning and determined that the **Final Conclusion** is false. Your 'Verification' results must align with mine. Proceed to refine the reasoning by making precise **corrections** to address prior flaws and construct a new Final Conclusion.
### Output Format
Strictly follow the JSON structure below. You do not need to repeat your previous reasoning. Begin directly from the next 'Verification' stage.
```json
{{
"CoT": [
{{"action": "Verification", "content": "..."}},
{{"action": "Inner Thinking", "title": "...", "content": "..."}},
...,
{{"action": "Final Conclusion", "content": "..."}},
{{"action": "Verification", "content": "..."}}
]
}}
```"""
gen_prompt_w_label = """<question>
{}
</question>
<previous reasoning>
{}
</previous reasoning>
<response requirements>
Your response must include the following steps, each composed of three types of actions: **"Inner Thinking"**, **"Final Conclusion"**, and **"Verification"**:
1. **Inner Thinking**: Break down the reasoning process into multiple concise steps. Each step should start with a brief title to clarify its purpose.
2. **Final Conclusion**: Summarize the correct reasoning from all previous 'Inner Thinking' steps and provide the final answer. No title is needed for this section.
3. **Verification**: Verify the accuracy of the "Final Conclusion". If it holds, conclude the process. Otherwise, return to "Inner Thinking" for further refinement.
</response requirements>
<question> represents the question to be answered, and <previous reasoning> contains your prior reasoning. Your task is to continue from the current 'Verification' step. Now, I'll secretly tell you that the labeled answer is "{}", but you must pretend not to know. Your 'Verification' requires careful consideration, and if incorrect, you need to provide new Inner Thinking steps and a new Final Conclusion to ensure the final answer aligns with the correct one.
### Output Format
Strictly follow the JSON structure below. You do not need to repeat your previous reasoning. Begin directly from the next 'Verification' stage.
```json
{{
"CoT": [
{{"action": "Verification", "content": "..."}},
{{"action": "Inner Thinking", "title": "...", "content": "..."}},
...,
{{"action": "Final Conclusion", "content": "..."}},
{{"action": "Verification", "content": "..."}}
]
}}
```"""
reformat_to_complex_cot_prompt = """<Thought Process>
{}
</Thought Process>
<Question>
{}
</Question>
The <Thought Process> above reflects the model's reasoning based on the <Question>. Your task is to rewrite the <Thought Process> to resemble a more human-like, intuitive natural thinking process. The new version should:
1. Be presented as step-by-step reasoning, with each thought on a new line separated by a line break.
2. Avoid structured titles or formatting, focusing on natural transitions. Use casual and natural language for transitions or validations, such as "hmm," "oh," "also," or "wait."
3. Expand the content, making the reasoning richer, more detailed, and logically clear while still being conversational and intuitive.
Return directly the revised natural thinking in JSON format as follows:
```json
{{
"NaturalReasoning": "..."
}}
```"""
get_final_response_prompt = """<Internal Thinking>
{}
</Internal Thinking>
<Question>
{}
</Question>
The <Internal Thinking> represents your internal thoughts about the <Question>. Based on this, generate a rich and high-quality final response to the user. If there is a clear answer, provide it first. Ensure your final response closely follows the <Question>. The response style should resemble GPT-4's style as much as possible. Output only your final response, without any additional content."""
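# Overall pipeline: run a structured CoT search (Init_CoT -> rethinking strategies ->
# optional Label_CoT), rewrite the verified reasoning into natural-language "Complex_CoT",
# then generate the final user-facing "Response" from it.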
# search strategies
search_strategies = [('Backtracking',gen_prompt_rethink_Backtracking),('Exploring New Paths',gen_prompt_rethink_Exploring_New_Path),('Verification',gen_prompt_rethink_Verification),('Correction',gen_prompt_rethink_Correction)]
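# On the first rethinking step, 'Backtracking' is excluded and a strategy is drawn from the
# remaining three; later steps sample from all four (see write_piece_order_data below).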
def extract_bracket_content(text):
# Extract content between the first '{' and the last '}'
match = re.search(r'\{.*\}', text, re.DOTALL)
return match.group(0) if match else None
def parse_gpt_response(response):
try:
if '{' != response[0]:
response = extract_bracket_content(response)
da = json.loads(response.replace('\n',''))
assert isinstance(da["CoT"],list), "CoT should be list"
assert da['CoT'][-3]['action'] == 'Inner Thinking', 'Inner Thinking should be the third last action'
assert da['CoT'][-2]['action'] == 'Final Conclusion', 'Final Conclusion should be the second last action'
assert da['CoT'][-1]['action'] == 'Verification', 'Verification should be the last action'
return True,da
except Exception as e:
print(e)
traceback.print_exc()
return False,None
def parse_gpt_response_reformat(response):
try:
if '{' != response[0]:
response = extract_bracket_content(response)
da = json.loads(response.replace('\n',''))
assert isinstance(da["NaturalReasoning"],str), "NaturalReasoning should be str"
assert '\n' in da["NaturalReasoning"], "NaturalReasoning should have \\n"
return True,da
except Exception as e:
print(e)
traceback.print_exc()
return False,None
def get_stream_of_search(longcot):
temp = '### {}\n{}\n'
resstr = []
for x in longcot:
if 'title' in x:
resstr.append(temp.format(x['title'],x['content']))
else:
resstr.append(temp.format(x['action'].replace('Final Conclusion','Conclusion'),x['content']))
return '\n'.join(resstr).strip()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, required=True, help="Path to the input JSON data file.")
parser.add_argument("--model_name", type=str, default="gpt-4", help="Name of the GPT model to use.")
parser.add_argument("--api_key", type=str, required=True, help="OpenAI API key.")
parser.add_argument("--api_url", type=str, default="https://api.openai.com/v1/chat/completions", help="OpenAI API URL.")
parser.add_argument("--max_search_attempts", type=int, default=1, help="Maximum number of search attempts.")
parser.add_argument("--max_search_depth", type=int, default=2, help="Maximum search depth.")
parser.add_argument("--efficient_search", type=bool, default=True, help="Enable efficient search strategy.")
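# Caveat: argparse's `type=bool` converts any non-empty string (including "False") to True;
# omit the flag to keep the default, or pass an empty string to disable it.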
parser.add_argument("--num_process", type=int, default=5, help="Number of parallel processes.")
parser.add_argument("--limit_num", type=int, help="Limit the number of processed items.")
args = parser.parse_args()
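# Example invocation (all values below are placeholders, not taken from the repo):
#   python <this script> --data_path data/verifiable_questions.json --model_name gpt-4o \
#       --api_key $OPENAI_API_KEY --num_process 5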
def filter_data(tmpdata):
filtered_data = []
for da in tmpdata:
if 'Open-ended Verifiable Question' not in da or 'Ground-True Answer' not in da:
continue
filtered_data.append(da)
print(f"Original data size: {len(tmpdata)}, Filtered data size: {len(filtered_data)}")
return filtered_data
with open(args.data_path) as f:
tmpdata = json.load(f)
tmp_id = 1
for da in tmpdata:
da['process_id'] = tmp_id
tmp_id += 1
data = filter_data(tmpdata)
if args.limit_num:
data = data[:args.limit_num]
print(f"read data:{len(data)}")
task_name = f'{os.path.split(args.data_path)[-1].replace(".json","")}_CoT_search'
save_dir = f'output_data/{task_name}'
gpt_instance = GPT(model_name=args.model_name, api_url=args.api_url, api_key=args.api_key)
def verify_gpt(conclusion,answer,d):
query = verify_prompt.format(conclusion,answer)
response = gpt_instance.retry_call(query)
d['gpt4_query_cot'].append(query)
d['gpt4_response_cot'].append(response)
if 'true' in response.lower():
d['verify'].append(True)
return True
else:
d['verify'].append(False)
return False
global wrongtime
wrongtime = 0
def write_piece_order_data(d):
global wrongtime
try:
retry_time = 1
d['verify'] = []
d['Long_CoT'] = []
d['gpt4_query_cot'] = []
d['gpt4_response_cot'] = []
d['response_struct'] = []
d['response_type'] = []
d['prior_fail_try'] = []
save_path = os.path.join(save_dir, str(d['process_id']) + ".json")
# init reason
query = query_prompt_init.format(d['Open-ended Verifiable Question'])
d['gpt4_query_cot'].append(query)
for ii in range(retry_time):
response = gpt_instance.retry_call(query)
if ii == 0:
d['gpt4_response_cot'].append(response)
flag, struct = parse_gpt_response(response)
if flag:
d['response_struct'].append(struct["CoT"])
d['Long_CoT'] = struct["CoT"]
d['response_type'].append('Init_CoT')
break
else:
print(f'retrying Init_CoT',flush=True)
if not flag:
raise Exception('init error')
verify_gpt(d['Long_CoT'][-2]['content'],d['Ground-True Answer'],d)
for rethinking_try_time in range(args.max_search_attempts):
if rethinking_try_time > 0:
# Archive the failed state
del d['prior_fail_try']
save_d['prior_fail_try'].append(d)
# Replace with a new state
d = save_d
# Save the initial state
save_d = copy.deepcopy(d)
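# Each search attempt works on a deep copy (`save_d`) of the pre-search state; when a later
# attempt starts, the previous failed trajectory is archived into `prior_fail_try` and the
# state is rolled back before searching again.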
# Begin search
for rethink_time in range(args.max_search_depth):
if d['verify'][-1]:
break
reasoning = json.dumps(d['Long_CoT'][:-1],ensure_ascii=False,indent=2)
# Search strategy
if rethink_time > 0:
strategy_name,strategy = random.choice(search_strategies)
else:
# exclude Backtracking
strategy_name,strategy = random.choice(search_strategies[1:])
query = strategy.format(d['Open-ended Verifiable Question'],reasoning)
d['gpt4_query_cot'].append(query)
for ii in range(retry_time):
response = gpt_instance.retry_call(query)
flag, struct = parse_gpt_response(response)
if flag:
d['gpt4_response_cot'].append(response)
d['response_struct'].append(struct["CoT"])
d['Long_CoT'] = d['Long_CoT'][:-1] + struct["CoT"]
d['response_type'].append(f'Re_CoT_{strategy_name}')
break
else:
print(f'retrying strategy {strategy_name}',flush=True)
if not flag:
raise Exception('rethink error')
verify_gpt(d['Long_CoT'][-2]['content'],d['Ground-True Answer'],d)
if d['verify'][-1]:
break
# If it is still incorrect, generate a final Label_CoT round
if not d['verify'][-1] and args.efficient_search:
reasoning = json.dumps(d['Long_CoT'][:-1],ensure_ascii=False,indent=2)
query = gen_prompt_w_label.format(d['Open-ended Verifiable Question'],reasoning,d['Ground-True Answer'])
d['gpt4_query_cot'].append(query)
for ii in range(retry_time):
response = gpt_instance.retry_call(query)
flag, struct = parse_gpt_response(response)
if flag:
d['gpt4_response_cot'].append(response)
d['response_struct'].append(struct["CoT"])
d['Long_CoT'] = d['Long_CoT'][:-1] + struct["CoT"]
d['response_type'].append('Label_CoT')
# ignore verify
d['verify'].append(True)
break
else:
print(f'retrying Label_CoT',flush=True)
if not flag:
raise Exception('label error')
if d['verify'][-1]:
# Generate complex CoT and final response (Complex_CoT, response)
sos = get_stream_of_search(d['Long_CoT'])
query = reformat_to_complex_cot_prompt.format(sos,d['Open-ended Verifiable Question'])
d['gpt4_query_cot'].append(query)
for ii in range(retry_time):
response = gpt_instance.retry_call(query)
flag, struct = parse_gpt_response_reformat(response)
if flag:
d['gpt4_response_cot'].append(response)
d["Complex_CoT"] = struct["NaturalReasoning"]
# get response
query = get_final_response_prompt.format(d['Complex_CoT'],d['Open-ended Verifiable Question'])
d['gpt4_query_cot'].append(query)
response = gpt_instance.retry_call(query)
d['gpt4_response_cot'].append(response)
d["Response"] = response
d['Question'] = d['Open-ended Verifiable Question']
break
with open(save_path, mode="w", encoding="utf-8") as fw:
json.dump(d, fw, ensure_ascii=False,indent=2)
wrongtime = 0
except Exception as e:
traceback.print_exc()
wrongtime += 1
if wrongtime > 20:
raise RuntimeError('too many consecutive failures while processing samples')
return 1
def deduplicate_data(data, processed_data):
processed_ids = {item['process_id'] for item in processed_data}
return [item for item in data if item['process_id'] not in processed_ids]
def merge_saved_files(save_dir):
_, _, filenames = next(os.walk(save_dir))
json_files = [f for f in filenames if f.endswith('.json')]
res = []
for file_path in json_files:
try:
with open(os.path.join(save_dir, file_path), encoding="utf-8") as f:
da = json.loads(f.read())
assert 'Complex_CoT' in da and 'Response' in da
res.append(da)
except Exception as e:
continue
return res
os.makedirs(save_dir, exist_ok=True)
# Merge previously processed files
processed_data = merge_saved_files(save_dir)
print(f"Previously processed items: {len(processed_data)}")
input_data = deduplicate_data(data, processed_data)
print(f"Items remaining for processing: {len(input_data)}")
with ThreadPoolExecutor(max_workers=args.num_process) as executor:
list(tqdm(executor.map(write_piece_order_data, input_data), total=len(input_data), desc="Processing samples", unit="sample"))
# Merge and save final output
final_data = merge_saved_files(save_dir)
output_path = f"{task_name}_{len(final_data)}.json"
print(f"Processed {len(final_data)} items. Saving to {output_path}")
with open(output_path, 'w', encoding='utf-8') as file:
json.dump(final_data, file, ensure_ascii=False, indent=2)
if __name__ == '__main__':
main()
\ No newline at end of file