# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import re
from collections import defaultdict
from typing import List, Optional, Union

import datasets
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin

import verl.utils.torch_functional as verl_F
from verl.utils.model import compute_position_id_with_mask


def collate_fn(data_list: list[dict]) -> dict:
    """Collate a list of samples: stack tensor fields and gather non-tensor fields into object arrays."""
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors[key].append(val)
            else:
                non_tensors[key].append(val)

    for key, val in tensors.items():
        tensors[key] = torch.stack(val, dim=0)

    for key, val in non_tensors.items():
        non_tensors[key] = np.array(val, dtype=object)

    return {**tensors, **non_tensors}
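# ---------------------------------------------------------------------------
# Illustrative sketch of how `collate_fn` behaves (added for clarity, not part
# of the verl API): tensor fields are stacked along a new batch dimension,
# while non-tensor fields are gathered into object-dtype numpy arrays. The
# field names below ("input_ids", "raw_prompt") are hypothetical.
#
#   batch = collate_fn([
#       {"input_ids": torch.tensor([1, 2]), "raw_prompt": "hello"},
#       {"input_ids": torch.tensor([3, 4]), "raw_prompt": "world"},
#   ])
#   batch["input_ids"].shape   # torch.Size([2, 2])
#   batch["raw_prompt"]        # array(['hello', 'world'], dtype=object)
# ---------------------------------------------------------------------------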
class RLHFDataset(Dataset):
    """
    We assume the dataset contains a column holding the prompts, along with any other per-sample information.
    """

    def __init__(
        self,
        data_files: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
        if not isinstance(data_files, (List, ListConfig)):
            data_files = [data_files]

        self.data_files = copy.deepcopy(data_files)
        self.original_data_files = copy.deepcopy(data_files)  # used for resuming
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config

        self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
        self.prompt_key = config.get("prompt_key", "prompt")
        self.image_key = config.get("image_key", "images")
        self.video_key = config.get("video_key", "videos")
        self.max_prompt_length = config.get("max_prompt_length", 1024)

        self.return_raw_chat = config.get("return_raw_chat", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)

        self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
        self.num_workers = min(self.num_workers, os.cpu_count())

        # whether to store the dataset in state_dict()
        # not stored by default
        self.serialize_dataset = False

        self._download()
        self._read_files_and_tokenize()

    def _download(self, use_origin_parquet=False):
        from verl.utils.fs import copy_to_local

        data_files = self.data_files if not use_origin_parquet else self.original_data_files
        for i, parquet_file in enumerate(data_files):
            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir)

    def _read_files_and_tokenize(self):
        dataframes = []
        for parquet_file in self.data_files:
            # read parquet files and cache
            dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
            dataframes.append(dataframe)
        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

        print(f"dataset len: {len(self.dataframe)}")

        # filter out prompts that are too long
        if self.filter_overlong_prompts:
            tokenizer = self.tokenizer
            prompt_key = self.prompt_key
            self.dataframe = self.dataframe.filter(
                lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))
                <= self.max_prompt_length,
                num_proc=self.num_workers,
                desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
            )

            print(f"filter dataset len: {len(self.dataframe)}")

    def resume_dataset_state(self):
        self.serialize_dataset = not hasattr(self, "original_data_files")
        # resume the dataframe if it was not serialized into data.pt
        if not self.serialize_dataset:
            self._download(use_origin_parquet=True)  # download and resume from the original parquet files
            self._read_files_and_tokenize()
        else:
            print("an old dataloader ckpt file is being used; please train from scratch for better ckpt performance")

    def __len__(self):
        return len(self.dataframe)

    def _build_messages(self, example: dict):
        messages: list = example.pop(self.prompt_key)

        if self.image_key in example or self.video_key in example:
            for message in messages:
                content = message["content"]
                content_list = []
                for segment in re.split("(<image>|<video>)", content):