# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
- Preprocess the data and split the training set into 75% for training the RM and 25% for validating the RM.
- All of the training data is used to train SFT and RL.
- Both the chosen and rejected responses are used to train SFT.
"""

import argparse
import os

import pandas as pd
from datasets import load_dataset
from tqdm.auto import tqdm

from verl.utils.fs import copy, makedirs


def generate_sft_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlhf/sft"):
    dataset = load_dataset("Dahoas/full-hh-rlhf")
    output = {"prompt": [], "response": []}
    for data in tqdm(dataset["train"]):
        # add the chosen response
        output["prompt"].append(data["prompt"])
        output["response"].append(data["chosen"])
        # add the rejected response
        output["prompt"].append(data["prompt"])
        output["response"].append(data["rejected"])

    df = pd.DataFrame(output)

    local_dir = os.path.expanduser(local_dir)
    os.makedirs(local_dir, exist_ok=True)
    local_path = os.path.join(local_dir, "train.parquet")
    df.to_parquet(path=local_path)

    if target_hdfs_path_dir is not None:
        hdfs_dir = target_hdfs_path_dir + "/" + "train.parquet"
        makedirs(hdfs_dir)
        copy(local_path, hdfs_dir)


def generate_rm_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlhf/rm"):
    train_dataset = load_dataset("Dahoas/full-hh-rlhf", split="train[:75%]")
    test_dataset = load_dataset("Dahoas/full-hh-rlhf", split="train[-25%:]")

    local_dir = os.path.expanduser(local_dir)
    os.makedirs(local_dir, exist_ok=True)

    for dataset, name in zip([train_dataset, test_dataset], ["train", "test"], strict=True):
        output = {"prompt": [], "chosen": [], "rejected": []}
        for data in tqdm(dataset):
            # keep both the chosen and rejected responses for pairwise RM training
            output["prompt"].append(data["prompt"])
            output["chosen"].append(data["chosen"])
            output["rejected"].append(data["rejected"])

        df = pd.DataFrame(output)
        local_path = os.path.join(local_dir, name + ".parquet")
        df.to_parquet(path=local_path)

        if target_hdfs_path_dir is not None:
            hdfs_dir = target_hdfs_path_dir + "/" + name + ".parquet"
            makedirs(hdfs_dir)
            copy(local_path, hdfs_dir)


def generate_rl_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlhf/rl"):
    dataset = load_dataset("Dahoas/full-hh-rlhf")
    train_dataset = dataset["train"]

    data_source = "Dahoas/full-hh-rlhf"

    # add a row to each data item that represents a unique id
    def make_map_fn(split):
        def process_fn(example, idx):
            prompt = example.pop("prompt")
            response = example.pop("response")

            data = {
                "data_source": data_source,
                "prompt": [{"role": "user", "content": prompt}],
                "ability": "alignment",
                "reward_model": {
                    "style": "model",
                    "ground_truth": response,  # should not be used
                },
                "extra_info": {"split": split, "index": idx},
            }
            return data

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)

    local_dir = os.path.expanduser(local_dir)
    os.makedirs(local_dir, exist_ok=True)  # ensure the output directory exists before writing
    local_path = os.path.join(local_dir, "train.parquet")
    train_dataset.to_parquet(local_path)

    if target_hdfs_path_dir is not None:
        hdfs_dir = target_hdfs_path_dir + "/" + "train.parquet"
        makedirs(hdfs_dir)
        copy(local_path, hdfs_dir)
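# For reference, each RL record produced by make_map_fn above has the shape
# below (field values are illustrative placeholders, not real dataset rows):
#
#     {
#         "data_source": "Dahoas/full-hh-rlhf",
#         "prompt": [{"role": "user", "content": "<full conversation prompt>"}],
#         "ability": "alignment",
#         "reward_model": {"style": "model", "ground_truth": "<chosen response>"},
#         "extra_info": {"split": "train", "index": 0},
#     }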
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", type=str, choices=["sft", "rm", "rl"], required=True)
    parser.add_argument("--local_dir", type=str, default="~/data/full_hh_rlhf")
    parser.add_argument("--hdfs_dir", type=str, required=False, default=None)

    args = parser.parse_args()

    if args.split == "sft":
        generate_sft_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split))
    elif args.split == "rm":
        generate_rm_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split))
    elif args.split == "rl":
        generate_rl_dataset(args.hdfs_dir, os.path.join(args.local_dir, args.split))
    else:
        raise NotImplementedError
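# Example invocations (a sketch; the script filename and the HDFS path are
# assumptions for illustration):
#
#     python full_hh_rlhf.py --split sft
#     python full_hh_rlhf.py --split rm --local_dir ~/data/full_hh_rlhf
#     python full_hh_rlhf.py --split rl --hdfs_dir hdfs://path/to/output
#
# Each run writes parquet files under <local_dir>/<split>/ and, when
# --hdfs_dir is given, mirrors them to HDFS via verl.utils.fs.copy.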