# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
`datasketch` (https://github.com/ekzhu/datasketch) doesn't offer a way to store the hash tables on disk. This
is a custom implementation that uses `diskcache` to store the hash tables on disk.
Note: this implementation is not optimized for performance, but it could be worth
contributing a PR to `datasketch`.
"""
import shutil
import struct
from pathlib import Path
from typing import Callable, Dict, Final, Optional, Tuple
from datasketch import MinHashLSH as _MinHashLSH
from datasketch.lsh import _optimal_param
from datasketch.storage import OrderedStorage, UnorderedStorage, _random_name
from datasketch.storage import ordered_storage as _ordered_storage
from datasketch.storage import unordered_storage as _unordered_storage
KEY_VALUE_DISK_DIR: Path = Path.home() / ".cache" / "distilabel" / "key_value_store"
KV_DISK_LIST_NAME: Final[str] = "diskcache_list_storage"
KV_DISK_SET_NAME: Final[str] = "diskcache_set_storage"
class DiskCacheListStorage(OrderedStorage):
def __init__(self, config, name) -> None:
path = config.get("path", self._get_db_name(name))
try:
from diskcache import Index
except ImportError as e:
raise ImportError(
"`diskcache` is required for disk storage using `MinHashDedup`. "
"Please install it using `pip install 'distilabel[minhash]'`."
) from e
        # Start with a clean database on each pipeline run
if Path(path).exists():
shutil.rmtree(path)
self._db = Index(path)
def _get_db_name(self, name):
return str(KEY_VALUE_DISK_DIR / f"{name}_{KV_DISK_LIST_NAME}")
def keys(self):
return self._db.keys()
def get(self, key):
return self._db.get(key, [])
    def remove(self, *keys):
        # Remove only the given keys rather than wiping the whole index.
        for key in keys:
            self._db.pop(key, None)
    def remove_val(self, key, val):
        # `Index.get` returns a copy, so mutate it and write it back.
        res = self.get(key)
        res.remove(val)
        self._db[key] = res
def insert(self, key, *vals, **kwargs):
res = self.get(key)
res.extend(vals)
self._db[key] = res
def size(self):
return len(self._db)
def itemcounts(self, **kwargs):
return {k: len(v) for k, v in self._db.items()}
def has_key(self, key):
return key in self._db
def close(self):
self._db._cache.close()
class DiskCacheSetStorage(UnorderedStorage, DiskCacheListStorage):
def _get_db_name(self, name):
return str(KEY_VALUE_DISK_DIR / f"{name}_{KV_DISK_SET_NAME}")
def get(self, key):
return self._db.get(key, set())
def insert(self, key, *vals, **kwargs):
res = self.get(key)
res.update(vals)
self._db[key] = res
def ordered_storage(config, name=None):
"""Copy of `datasketch.storage.ordered_storage` with the addition of `DiskCacheListStorage`."""
tp = config["type"]
if tp == "disk":
return DiskCacheListStorage(config, name=name)
return _ordered_storage(config, name=name)
def unordered_storage(config, name=None):
"""Copy of `datasketch.storage.ordered_storage` with the addition of `DiskCacheSetStorage`."""
tp = config["type"]
if tp == "disk":
return DiskCacheSetStorage(config, name=name)
return _unordered_storage(config, name=name)
class MinHashLSH(_MinHashLSH):
    """Custom implementation of `datasketch.MinHashLSH` that allows passing a custom
    storage configuration to store the hash tables on disk.
    This could be merged into the original repository; the only changes are the
    additional `close` method and the use of our custom `ordered_storage` and
    `unordered_storage` functions in `__init__`.
    """
def __init__(
self,
threshold: float = 0.9,
num_perm: int = 128,
weights: Tuple[float, float] = (0.5, 0.5),
params: Optional[Tuple[int, int]] = None,
storage_config: Optional[Dict] = None,
prepickle: Optional[bool] = None,
hashfunc: Optional[Callable[[bytes], bytes]] = None,
) -> None:
storage_config = {"type": "dict"} if not storage_config else storage_config
self._buffer_size = 50000
if threshold > 1.0 or threshold < 0.0:
raise ValueError("threshold must be in [0.0, 1.0]")
if num_perm < 2:
raise ValueError("Too few permutation functions")
if any(w < 0.0 or w > 1.0 for w in weights):
raise ValueError("Weight must be in [0.0, 1.0]")
if sum(weights) != 1.0:
raise ValueError("Weights must sum to 1.0")
self.h = num_perm
if params is not None:
self.b, self.r = params
if self.b * self.r > num_perm:
raise ValueError(
"The product of b and r in params is "
"{} * {} = {} -- it must be less than num_perm {}. "
"Did you forget to specify num_perm?".format(
self.b, self.r, self.b * self.r, num_perm
)
)
else:
false_positive_weight, false_negative_weight = weights
self.b, self.r = _optimal_param(
threshold, num_perm, false_positive_weight, false_negative_weight
)
            if self.b < 2:
                raise ValueError("The number of bands is too small (b < 2)")
        self.prepickle = (
            storage_config["type"] == "redis" if prepickle is None else prepickle
        )
self.hashfunc = hashfunc
if hashfunc:
self._H = self._hashed_byteswap
else:
self._H = self._byteswap
basename = storage_config.get("basename", _random_name(11))
self.hashtables = [
unordered_storage(
storage_config,
name=b"".join([basename, b"_bucket_", struct.pack(">H", i)]),
)
for i in range(self.b)
]
self.hashranges = [(i * self.r, (i + 1) * self.r) for i in range(self.b)]
self.keys = ordered_storage(storage_config, name=b"".join([basename, b"_keys"]))
def close(self):
"""Closes the internal connections."""
if isinstance(self.hashtables[0], DiskCacheListStorage):
for ht in self.hashtables:
ht.close()
self.keys.close()
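# Minimal usage sketch (illustrative only, not part of the library API): shows how
# the custom `MinHashLSH` above accepts `storage_config={"type": "disk"}` so the
# hash tables live in diskcache-backed indexes. The `_example_disk_backed_lsh`
# name and the threshold/num_perm values are arbitrary choices for this sketch.
def _example_disk_backed_lsh() -> None:  # pragma: no cover - illustrative only
    from datasketch import MinHash

    lsh = MinHashLSH(threshold=0.9, num_perm=128, storage_config={"type": "disk"})
    minhash = MinHash(num_perm=128)
    for token in (b"near", b"duplicate", b"text"):
        minhash.update(token)
    lsh.insert("doc-1", minhash)
    assert lsh.query(minhash) == ["doc-1"]
    lsh.close()  # close the on-disk indexes (see `close` above)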
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
import numpy as np
from pydantic import Field
from rich.progress import track
from typing_extensions import override
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import GlobalStep, StepInput
if TYPE_CHECKING:
from distilabel.typing import StepOutput
class EmbeddingDedup(GlobalStep):
"""Deduplicates text using embeddings.
`EmbeddingDedup` is a Step that detects near-duplicates in datasets, using
embeddings to compare the similarity between the texts. The typical workflow with this step
would include having a dataset with embeddings precomputed, and then (possibly using the
`FaissNearestNeighbour`) using the `nn_indices` and `nn_scores`, determine the texts that
are duplicate.
Attributes:
threshold: the threshold to consider 2 examples as duplicates.
It's dependent on the type of index that was used to generate the embeddings.
            For example, if the embeddings were generated using cosine similarity, a threshold
            of `0.9` would mark as duplicates all texts with a cosine similarity above that
            value. Higher values detect fewer duplicates with such an index, but that should
            be taken into account when building it. Defaults to `0.9`.
Runtime Parameters:
- `threshold`: the threshold to consider 2 examples as duplicates.
Input columns:
- nn_indices (`List[int]`): a list containing the indices of the `k` nearest neighbours
in the inputs for the row.
- nn_scores (`List[float]`): a list containing the score or distance to each `k`
nearest neighbour in the inputs.
Output columns:
        - keep_row_after_embedding_filtering (`bool`): boolean indicating whether the piece of `text` is
            not a duplicate, i.e. this text should be kept.
Categories:
- filtering
Examples:
Deduplicate a list of texts using embedding information:
```python
from distilabel.pipeline import Pipeline
from distilabel.steps import EmbeddingDedup
from distilabel.steps import LoadDataFromDicts
        with Pipeline() as pipeline:
            batch_size = 500  # define the batch size used by the steps below
            data = LoadDataFromDicts(
data=[
{
"persona": "A chemistry student or academic researcher interested in inorganic or physical chemistry, likely at an advanced undergraduate or graduate level, studying acid-base interactions and chemical bonding.",
"embedding": [
0.018477669046149742,
-0.03748236608841726,
0.001919870620352492,
0.024918478063770535,
0.02348063521315178,
0.0038251285566308375,
-0.01723884983037716,
0.02881971942372201,
],
"nn_indices": [0, 1],
"nn_scores": [
0.9164746999740601,
0.782106876373291,
],
},
{
"persona": "A music teacher or instructor focused on theoretical and practical piano lessons.",
"embedding": [
-0.0023464179614082125,
-0.07325472251663565,
-0.06058678419516501,
-0.02100326928586996,
-0.013462744792362657,
0.027368447064244242,
-0.003916070100455717,
0.01243614518480423,
],
"nn_indices": [0, 2],
"nn_scores": [
0.7552462220191956,
0.7261884808540344,
],
},
{
"persona": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.",
"embedding": [
-0.01630817942328242,
-0.023760151552345232,
-0.014249650090627883,
-0.005713686451446624,
-0.016033059279131567,
0.0071440908501058786,
-0.05691099643425161,
0.01597412704817784,
],
"nn_indices": [1, 2],
"nn_scores": [
0.8107735514640808,
0.7172299027442932,
],
},
],
batch_size=batch_size,
)
# In general you should do something like this before the deduplication step, to obtain the
# `nn_indices` and `nn_scores`. In this case the embeddings are already normalized, so there's
# no need for it.
# nn = FaissNearestNeighbour(
# k=30,
# metric_type=faiss.METRIC_INNER_PRODUCT,
# search_batch_size=50,
# train_size=len(dataset), # The number of embeddings to use for training
# string_factory="IVF300_HNSW32,Flat" # To use an index (optional, maybe required for big datasets)
# )
# Read more about the `string_factory` here:
# https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
embedding_dedup = EmbeddingDedup(
threshold=0.8,
input_batch_size=batch_size,
)
data >> embedding_dedup
if __name__ == "__main__":
distiset = pipeline.run(use_cache=False)
ds = distiset["default"]["train"]
# Filter out the duplicates
ds_dedup = ds.filter(lambda x: x["keep_row_after_embedding_filtering"])
```
"""
threshold: Optional[RuntimeParameter[float]] = Field(
default=0.9,
description="The threshold to consider 2 examples as duplicates. It's dependent "
"on the type of index that was used to generate the embeddings. For example, if "
"the embeddings were generated using cosine similarity, a threshold of `0.9` "
"would make all the texts with a cosine similarity above the value duplicates. "
"Higher values detect less duplicates in such an index, but that should be "
"taken into account when building it.",
)
@property
def inputs(self) -> List[str]:
return ["nn_scores", "nn_indices"]
@property
def outputs(self) -> List[str]:
return ["keep_row_after_embedding_filtering"]
@override
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
rows_to_remove = set()
for input in track(inputs, description="Running Embedding deduplication..."):
input["keep_row_after_embedding_filtering"] = True
indices_scores = np.array(input["nn_scores"]) > self.threshold
indices = np.array(input["nn_indices"])[indices_scores]
if len(indices) > 0: # If there are any rows found over the threshold
rows_to_remove.update(list(indices))
        # Mark every row identified as a duplicate so it can be filtered out downstream
for idx in rows_to_remove:
inputs[idx]["keep_row_after_embedding_filtering"] = False
yield inputs
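# Minimal sketch (illustrative only) of the filtering rule implemented in
# `process` above: any neighbour whose score exceeds `threshold` is flagged as a
# duplicate row. The `_example_embedding_filter` name and the toy rows/threshold
# values are arbitrary choices for this sketch.
def _example_embedding_filter() -> None:  # pragma: no cover - illustrative only
    rows = [
        {"nn_indices": [1], "nn_scores": [0.95]},  # row 0: row 1 is a near-dup
        {"nn_indices": [0], "nn_scores": [0.95]},  # row 1: row 0 is a near-dup
        {"nn_indices": [0], "nn_scores": [0.30]},  # row 2: no close neighbour
    ]
    threshold = 0.9
    rows_to_remove = set()
    for row in rows:
        over = np.array(row["nn_scores"]) > threshold
        rows_to_remove.update(np.array(row["nn_indices"])[over].tolist())
    keep = [i not in rows_to_remove for i in range(len(rows))]
    assert keep == [False, False, True]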
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import uuid
from functools import partial
from itertools import tee
from typing import (
TYPE_CHECKING,
Callable,
Iterable,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Union,
)
from pydantic import PrivateAttr
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
from datasketch import MinHash, MinHashLSH
from distilabel.typing import StepOutput
# Copied from: https://github.com/huggingface/datatrove/blob/main/src/datatrove/utils/text.py#L89C1-L95C65
def ngrams(sequence: Iterable[str], n: int) -> Iterator[Tuple[str, ...]]:
    iterables = tee(sequence, n)
    for i, sub_iterable in enumerate(iterables):  # For each of the n iterators,
        for _ in range(i):  # advance it by its offset i,
            next(sub_iterable, None)  # so the iterators are staggered.
    return zip(*iterables)  # Zip the staggered iterators into ngram tuples.
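# For instance (illustrative): list(ngrams("abcd", n=2)) yields
# [('a', 'b'), ('b', 'c'), ('c', 'd')] -- every window of n consecutive items.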
def tokenized_on_words(texts: Iterable[str]) -> List[Set[bytes]]:
"""Tokenizes a list of texts into words, using `nltk.word_tokenize`.
Args:
texts: List of documents to be tokenized.
Returns:
List with the set of tokens for each document.
"""
from nltk.tokenize import word_tokenize
return [{w.encode("utf-8") for w in word_tokenize(text)} for text in texts]
def tokenize_on_ngrams(texts: Iterable[str], n: int = 1) -> List[Set[bytes]]:
"""Tokenizes a list of texts into ngrams, and returns the set of them as bytes.
Args:
texts: List of documents to be tokenized.
n: The size of the ngrams, defaults to 1 (single letters).
Returns:
List with the set of tokens for each document.
"""
return [
{"".join(ngram).encode("utf-8") for ngram in ngrams(text, n=n)}
for text in texts
]
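# For instance (illustrative): tokenize_on_ngrams(["abc"], n=2) returns
# [{b"ab", b"bc"}] -- the set of byte-encoded character bigrams per document.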
class MinHashDedup(Step):
"""Deduplicates text using `MinHash` and `MinHashLSH`.
`MinHashDedup` is a Step that detects near-duplicates in datasets. The idea roughly translates
to the following steps:
1. Tokenize the text into words or ngrams.
2. Create a `MinHash` for each text.
3. Store the `MinHashes` in a `MinHashLSH`.
    4. Check if the `MinHash` is already in the `LSH`; if so, it is a duplicate.
Attributes:
num_perm: the number of permutations to use. Defaults to `128`.
seed: the seed to use for the MinHash. Defaults to `1`.
        tokenizer: the tokenizer to use. Available ones are `words` or `ngrams`.
            If `words` is selected, it tokenizes the text into words using nltk's
            word tokenizer. `ngrams` splits the text into ngrams of size `n`.
            Defaults to `words`.
n: the size of the ngrams to use. Only relevant if `tokenizer="ngrams"`. Defaults to `5`.
threshold: the threshold to consider two MinHashes as duplicates.
Values closer to 0 detect more duplicates. Defaults to `0.9`.
        storage: the storage to use for the LSH. Can be `dict` to store the index
            in memory, or `disk`. Keep in mind `disk` is an experimental feature
            not available in `datasketch`, based on DiskCache's `Index` class.
            It behaves like a `dict` backed by disk, but depending on the system
            it can be slower. Defaults to `dict`.
Input columns:
- text (`str`): the texts to be filtered.
Output columns:
        - keep_row_after_minhash_filtering (`bool`): boolean indicating whether the piece of `text` is
            not a duplicate, i.e. this text should be kept.
Categories:
- filtering
References:
- [`datasketch documentation`](https://ekzhu.github.io/datasketch/lsh.html)
- [Identifying and Filtering Near-Duplicate Documents](https://cs.brown.edu/courses/cs253/papers/nearduplicate.pdf)
- [Diskcache's Index](https://grantjenks.com/docs/diskcache/api.html#diskcache.Index)
Examples:
Deduplicate a list of texts using MinHash and MinHashLSH:
```python
from distilabel.pipeline import Pipeline
from distilabel.steps import MinHashDedup
from distilabel.steps import LoadDataFromDicts
with Pipeline() as pipeline:
ds_size = 1000
batch_size = 500 # Bigger batch sizes work better for this step
data = LoadDataFromDicts(
data=[
{"text": "This is a test document."},
{"text": "This document is a test."},
{"text": "Test document for duplication."},
{"text": "Document for duplication test."},
{"text": "This is another unique document."},
]
* (ds_size // 5),
batch_size=batch_size,
)
minhash_dedup = MinHashDedup(
tokenizer="words",
threshold=0.9, # lower values will increase the number of duplicates
storage="dict", # or "disk" for bigger datasets
)
data >> minhash_dedup
if __name__ == "__main__":
distiset = pipeline.run(use_cache=False)
ds = distiset["default"]["train"]
# Filter out the duplicates
ds_dedup = ds.filter(lambda x: x["keep_row_after_minhash_filtering"])
```
"""
num_perm: int = 128
seed: int = 1
tokenizer: Literal["words", "ngrams"] = "words"
n: Optional[int] = 5
threshold: float = 0.9
storage: Literal["dict", "disk"] = "dict"
    _hasher: Union[Callable, None] = PrivateAttr(None)
    _tokenizer: Union[Callable, None] = PrivateAttr(None)
    _lsh: Union["MinHashLSH", None] = PrivateAttr(None)
def load(self) -> None:
super().load()
        if importlib.util.find_spec("datasketch") is None:
raise ImportError(
"`datasketch` is needed to deduplicate with MinHash, but is not installed. "
"Please install it using `pip install 'distilabel[minhash]'`."
)
from datasketch import MinHash
from distilabel.steps.filtering._datasketch import MinHashLSH
self._hasher = MinHash.bulk
self._lsh = MinHashLSH(
num_perm=self.num_perm,
threshold=self.threshold,
storage_config={"type": self.storage},
)
if self.tokenizer == "words":
            if importlib.util.find_spec("nltk") is None:
raise ImportError(
"`nltk` is needed to tokenize based on words, but is not installed. "
"Please install it using `pip install 'distilabel[minhash]'`. Then run `nltk.download('punkt_tab')`."
)
self._tokenizer = tokenized_on_words
else:
self._tokenizer = partial(tokenize_on_ngrams, n=self.n)
def unload(self) -> None:
super().unload()
        # In case the LSH is stored on disk, we need to close the underlying files.
if self.storage == "disk":
self._lsh.close()
@property
def inputs(self) -> List[str]:
return ["text"]
@property
def outputs(self) -> List[str]:
return ["keep_row_after_minhash_filtering"]
def process(self, inputs: StepInput) -> "StepOutput":
tokenized_texts = []
for input in inputs:
tokenized_texts.append(self._tokenizer([input[self.inputs[0]]])[0])
minhashes = self._hasher(
tokenized_texts, num_perm=self.num_perm, seed=self.seed
)
for input, minhash in zip(inputs, minhashes):
# Check if the text is already in the LSH index
if self._lsh.query(minhash):
input["keep_row_after_minhash_filtering"] = False
else:
self._lsh.insert(str(uuid.uuid4()), minhash)
input["keep_row_after_minhash_filtering"] = True
yield inputs
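# Standalone sketch (illustrative only) of the query-then-insert loop used in
# `process` above, written against `datasketch` directly. The tokens and the
# threshold/num_perm values are arbitrary choices for this sketch.
def _example_minhash_roundtrip() -> None:  # pragma: no cover - illustrative only
    from datasketch import MinHash, MinHashLSH

    lsh = MinHashLSH(threshold=0.9, num_perm=128)
    tokens = {b"the", b"quick", b"brown", b"fox"}
    (minhash,) = MinHash.bulk([tokens], num_perm=128)
    assert not lsh.query(minhash)  # unseen -> the row would be kept
    lsh.insert("doc-0", minhash)
    assert lsh.query(minhash)  # an identical MinHash is now a duplicate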
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
from distilabel.typing import StepColumns, StepOutput
class ConversationTemplate(Step):
"""Generate a conversation template from an instruction and a response.
Input columns:
- instruction (`str`): The instruction to be used in the conversation.
- response (`str`): The response to be used in the conversation.
Output columns:
- conversation (`ChatType`): The conversation template.
Categories:
- format
- chat
- template
Examples:
Create a conversation from an instruction and a response:
```python
from distilabel.steps import ConversationTemplate
conv_template = ConversationTemplate()
conv_template.load()
result = next(
conv_template.process(
[
{
"instruction": "Hello",
"response": "Hi",
}
],
)
)
# >>> result
# [{'instruction': 'Hello', 'response': 'Hi', 'conversation': [{'role': 'user', 'content': 'Hello'}, {'role': 'assistant', 'content': 'Hi'}]}]
```
"""
@property
def inputs(self) -> "StepColumns":
"""The instruction and response."""
return ["instruction", "response"]
@property
def outputs(self) -> "StepColumns":
"""The conversation template."""
return ["conversation"]
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Generate a conversation template from an instruction and a response.
Args:
inputs: The input data.
Yields:
The input data with the conversation template.
"""
for input in inputs:
input["conversation"] = [
{"role": "user", "content": input["instruction"]},
{"role": "assistant", "content": input["response"]},
]
yield inputs
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from typing import TYPE_CHECKING, List
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
from distilabel.typing import StepColumns, StepOutput
class FormatTextGenerationDPO(Step):
"""Format the output of your LLMs for Direct Preference Optimization (DPO).
`FormatTextGenerationDPO` is a `Step` that formats the output of the combination of a `TextGeneration`
task with a preference `Task` i.e. a task generating `ratings`, so that those are used to rank the
existing generations and provide the `chosen` and `rejected` generations based on the `ratings`.
Use this step to transform the output of a combination of a `TextGeneration` + a preference task such as
`UltraFeedback` following the standard formatting from frameworks such as `axolotl` or `alignment-handbook`.
Note:
        The `generations` column should contain at least two generations, and the `ratings` column
        should contain the same number of ratings as generations.
Input columns:
- system_prompt (`str`, optional): The system prompt used within the `LLM` to generate the
`generations`, if available.
- instruction (`str`): The instruction used to generate the `generations` with the `LLM`.
- generations (`List[str]`): The generations produced by the `LLM`.
- generation_models (`List[str]`, optional): The model names used to generate the `generations`,
only available if the `model_name` from the `TextGeneration` task/s is combined into a single
column named this way, otherwise, it will be ignored.
- ratings (`List[float]`): The ratings for each of the `generations`, produced by a preference
task such as `UltraFeedback`.
Output columns:
- prompt (`str`): The instruction used to generate the `generations` with the `LLM`.
- prompt_id (`str`): The `SHA256` hash of the `prompt`.
- chosen (`List[Dict[str, str]]`): The `chosen` generation based on the `ratings`.
- chosen_model (`str`, optional): The model name used to generate the `chosen` generation,
if the `generation_models` are available.
- chosen_rating (`float`): The rating of the `chosen` generation.
- rejected (`List[Dict[str, str]]`): The `rejected` generation based on the `ratings`.
- rejected_model (`str`, optional): The model name used to generate the `rejected` generation,
if the `generation_models` are available.
- rejected_rating (`float`): The rating of the `rejected` generation.
Categories:
- format
- text-generation
- preference
- instruction
- generations
Examples:
Format your dataset for DPO fine tuning:
```python
from distilabel.steps import FormatTextGenerationDPO
format_dpo = FormatTextGenerationDPO()
format_dpo.load()
# NOTE: Both "system_prompt" and "generation_models" can be added optionally.
result = next(
format_dpo.process(
[
{
"instruction": "What's 2+2?",
"generations": ["4", "5", "6"],
"ratings": [1, 0, -1],
}
]
)
)
# >>> result
# [
# { 'instruction': "What's 2+2?",
# 'generations': ['4', '5', '6'],
# 'ratings': [1, 0, -1],
# 'prompt': "What's 2+2?",
# 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
# 'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
# 'chosen_rating': 1,
# 'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}],
# 'rejected_rating': -1
# }
# ]
```
"""
@property
def inputs(self) -> "StepColumns":
"""List of inputs required by the `Step`, which in this case are: `instruction`, `generations`,
and `ratings`."""
return {
"system_prompt": False,
"instruction": True,
"generations": True,
"generation_models": False,
"ratings": True,
}
@property
def optional_inputs(self) -> List[str]:
"""List of optional inputs, which are not required by the `Step` but used if available,
which in this case are: `system_prompt`, and `generation_models`."""
return ["system_prompt", "generation_models"]
@property
def outputs(self) -> "StepColumns":
"""List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `chosen`,
        `chosen_model`, `chosen_rating`, `rejected`, `rejected_model`, `rejected_rating`. Both
        the `chosen_model` and `rejected_model` are optional and only used if `generation_models`
        is available.
        Reference:
            - Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
"""
return [
"prompt",
"prompt_id",
"chosen",
"chosen_model",
"chosen_rating",
"rejected",
"rejected_model",
"rejected_rating",
]
def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore
"""The `process` method formats the received `StepInput` or list of `StepInput`
according to the DPO formatting standard.
Args:
*inputs: A list of `StepInput` to be combined.
Yields:
A `StepOutput` with batches of formatted `StepInput` following the DPO standard.
"""
for input in inputs:
for item in input:
messages = [
{"role": "user", "content": item["instruction"]}, # type: ignore
]
if (
"system_prompt" in item
and isinstance(item["system_prompt"], str) # type: ignore
and len(item["system_prompt"]) > 0 # type: ignore
):
messages.insert(
0,
{"role": "system", "content": item["system_prompt"]}, # type: ignore
)
item["prompt"] = item["instruction"]
item["prompt_id"] = hashlib.sha256(
item["prompt"].encode("utf-8") # type: ignore
).hexdigest()
chosen_idx = max(enumerate(item["ratings"]), key=lambda x: x[1])[0]
item["chosen"] = messages + [
{
"role": "assistant",
"content": item["generations"][chosen_idx],
}
]
if "generation_models" in item:
item["chosen_model"] = item["generation_models"][chosen_idx]
item["chosen_rating"] = item["ratings"][chosen_idx]
rejected_idx = min(enumerate(item["ratings"]), key=lambda x: x[1])[0]
item["rejected"] = messages + [
{
"role": "assistant",
"content": item["generations"][rejected_idx],
}
]
if "generation_models" in item:
item["rejected_model"] = item["generation_models"][rejected_idx]
item["rejected_rating"] = item["ratings"][rejected_idx]
yield input
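# Illustrative sketch: the chosen/rejected selection above reduces to an
# argmax/argmin over `ratings` used to index `generations`. The
# `_example_dpo_selection` name is arbitrary; the values mirror the docstring example.
def _example_dpo_selection() -> None:  # pragma: no cover - illustrative only
    generations = ["4", "5", "6"]
    ratings = [1, 0, -1]
    chosen_idx = max(enumerate(ratings), key=lambda x: x[1])[0]
    rejected_idx = min(enumerate(ratings), key=lambda x: x[1])[0]
    assert generations[chosen_idx] == "4"  # highest rating -> chosen
    assert generations[rejected_idx] == "6"  # lowest rating -> rejected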
class FormatChatGenerationDPO(Step):
"""Format the output of a combination of a `ChatGeneration` + a preference task for Direct Preference Optimization (DPO).
    `FormatChatGenerationDPO` is a `Step` that formats the output of the combination of a `ChatGeneration`
    task with a preference `Task`, i.e. a task generating `ratings` such as `UltraFeedback`, so that those
    are used to rank the existing generations and provide the `chosen` and `rejected` generations based on
    the `ratings`, following the standard formatting from frameworks such as `axolotl` or `alignment-handbook`.
Note:
        The `messages` column should contain at least one message from the user, the `generations`
        column should contain at least two generations, and the `ratings` column should contain the
        same number of ratings as generations.
Input columns:
- messages (`List[Dict[str, str]]`): The conversation messages.
- generations (`List[str]`): The generations produced by the `LLM`.
- generation_models (`List[str]`, optional): The model names used to generate the `generations`,
only available if the `model_name` from the `ChatGeneration` task/s is combined into a single
column named this way, otherwise, it will be ignored.
- ratings (`List[float]`): The ratings for each of the `generations`, produced by a preference
task such as `UltraFeedback`.
Output columns:
- prompt (`str`): The user message used to generate the `generations` with the `LLM`.
- prompt_id (`str`): The `SHA256` hash of the `prompt`.
- chosen (`List[Dict[str, str]]`): The `chosen` generation based on the `ratings`.
- chosen_model (`str`, optional): The model name used to generate the `chosen` generation,
if the `generation_models` are available.
- chosen_rating (`float`): The rating of the `chosen` generation.
- rejected (`List[Dict[str, str]]`): The `rejected` generation based on the `ratings`.
- rejected_model (`str`, optional): The model name used to generate the `rejected` generation,
if the `generation_models` are available.
- rejected_rating (`float`): The rating of the `rejected` generation.
Categories:
- format
- chat-generation
- preference
- messages
- generations
Examples:
Format your dataset for DPO fine tuning:
```python
from distilabel.steps import FormatChatGenerationDPO
format_dpo = FormatChatGenerationDPO()
format_dpo.load()
# NOTE: "generation_models" can be added optionally.
result = next(
format_dpo.process(
[
{
"messages": [{"role": "user", "content": "What's 2+2?"}],
"generations": ["4", "5", "6"],
"ratings": [1, 0, -1],
}
]
)
)
# >>> result
# [
# {
# 'messages': [{'role': 'user', 'content': "What's 2+2?"}],
# 'generations': ['4', '5', '6'],
# 'ratings': [1, 0, -1],
# 'prompt': "What's 2+2?",
# 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
# 'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
# 'chosen_rating': 1,
# 'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}],
# 'rejected_rating': -1
# }
# ]
```
"""
@property
def inputs(self) -> "StepColumns":
"""List of inputs required by the `Step`, which in this case are: `messages`, `generations`,
and `ratings`."""
return ["messages", "generations", "ratings"]
@property
def optional_inputs(self) -> List[str]:
"""List of optional inputs, which are not required by the `Step` but used if available,
which in this case is: `generation_models`."""
return ["generation_models"]
@property
def outputs(self) -> "StepColumns":
"""List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `chosen`,
        `chosen_model`, `chosen_rating`, `rejected`, `rejected_model`, `rejected_rating`. Both
        the `chosen_model` and `rejected_model` are optional and only used if `generation_models`
        is available.
        Reference:
            - Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
"""
return [
"prompt",
"prompt_id",
"chosen",
"chosen_model",
"chosen_rating",
"rejected",
"rejected_model",
"rejected_rating",
]
def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore
"""The `process` method formats the received `StepInput` or list of `StepInput`
according to the DPO formatting standard.
Args:
*inputs: A list of `StepInput` to be combined.
Yields:
A `StepOutput` with batches of formatted `StepInput` following the DPO standard.
"""
for input in inputs:
for item in input:
item["prompt"] = next(
(
turn["content"]
for turn in item["messages"]
if turn["role"] == "user"
),
None,
)
item["prompt_id"] = hashlib.sha256(
item["prompt"].encode("utf-8") # type: ignore
).hexdigest()
chosen_idx = max(enumerate(item["ratings"]), key=lambda x: x[1])[0]
item["chosen"] = item["messages"] + [
{
"role": "assistant",
"content": item["generations"][chosen_idx],
}
]
if "generation_models" in item:
item["chosen_model"] = item["generation_models"][chosen_idx]
item["chosen_rating"] = item["ratings"][chosen_idx]
rejected_idx = min(enumerate(item["ratings"]), key=lambda x: x[1])[0]
item["rejected"] = item["messages"] + [
{
"role": "assistant",
"content": item["generations"][rejected_idx],
}
]
if "generation_models" in item:
item["rejected_model"] = item["generation_models"][rejected_idx]
item["rejected_rating"] = item["ratings"][rejected_idx]
yield input
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from typing import TYPE_CHECKING, List
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
from distilabel.typing import StepColumns, StepOutput
class FormatTextGenerationSFT(Step):
"""Format the output of a `TextGeneration` task for Supervised Fine-Tuning (SFT).
`FormatTextGenerationSFT` is a `Step` that formats the output of a `TextGeneration` task for
Supervised Fine-Tuning (SFT) following the standard formatting from frameworks such as `axolotl`
or `alignment-handbook`. The output of the `TextGeneration` task is formatted into a chat-like
conversation with the `instruction` as the user message and the `generation` as the assistant
message. Optionally, if the `system_prompt` is available, it is included as the first message
in the conversation.
Input columns:
- system_prompt (`str`, optional): The system prompt used within the `LLM` to generate the
`generation`, if available.
- instruction (`str`): The instruction used to generate the `generation` with the `LLM`.
- generation (`str`): The generation produced by the `LLM`.
Output columns:
- prompt (`str`): The instruction used to generate the `generation` with the `LLM`.
- prompt_id (`str`): The `SHA256` hash of the `prompt`.
- messages (`List[Dict[str, str]]`): The chat-like conversation with the `instruction` as
the user message and the `generation` as the assistant message.
Categories:
- format
- text-generation
- instruction
- generation
Examples:
Format your dataset for SFT fine tuning:
```python
from distilabel.steps import FormatTextGenerationSFT
format_sft = FormatTextGenerationSFT()
format_sft.load()
# NOTE: "system_prompt" can be added optionally.
result = next(
format_sft.process(
[
{
"instruction": "What's 2+2?",
"generation": "4"
}
]
)
)
# >>> result
# [
# {
    #     'instruction': "What's 2+2?",
    #     'generation': '4',
    #     'prompt': "What's 2+2?",
# 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
# 'messages': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}]
# }
# ]
```
"""
@property
def inputs(self) -> "StepColumns":
"""List of inputs required by the `Step`, which in this case are: `instruction`, and `generation`."""
return {
"system_prompt": False,
"instruction": True,
"generation": True,
}
@property
def optional_inputs(self) -> List[str]:
"""List of optional inputs, which are not required by the `Step` but used if available,
which in this case is: `system_prompt`."""
return ["system_prompt"]
@property
def outputs(self) -> "StepColumns":
"""List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `messages`.
Reference:
            - Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
"""
return ["prompt", "prompt_id", "messages"]
def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore
"""The `process` method formats the received `StepInput` or list of `StepInput`
according to the SFT formatting standard.
Args:
*inputs: A list of `StepInput` to be combined.
Yields:
A `StepOutput` with batches of formatted `StepInput` following the SFT standard.
"""
for input in inputs:
for item in input:
item["prompt"] = item["instruction"]
item["prompt_id"] = hashlib.sha256(
item["prompt"].encode("utf-8") # type: ignore
).hexdigest()
item["messages"] = [
{"role": "user", "content": item["instruction"]}, # type: ignore
{"role": "assistant", "content": item["generation"]}, # type: ignore
]
if (
"system_prompt" in item
and isinstance(item["system_prompt"], str) # type: ignore
and len(item["system_prompt"]) > 0 # type: ignore
):
item["messages"].insert(
0,
{"role": "system", "content": item["system_prompt"]}, # type: ignore
)
yield input
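# Illustrative sketch: `prompt_id` above is simply the SHA-256 hex digest of the
# prompt string, so identical prompts always map to the same id. The
# `_example_prompt_id` name is arbitrary.
def _example_prompt_id(prompt: str) -> str:  # pragma: no cover - illustrative only
    return hashlib.sha256(prompt.encode("utf-8")).hexdigest()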
class FormatChatGenerationSFT(Step):
"""Format the output of a `ChatGeneration` task for Supervised Fine-Tuning (SFT).
`FormatChatGenerationSFT` is a `Step` that formats the output of a `ChatGeneration` task for
Supervised Fine-Tuning (SFT) following the standard formatting from frameworks such as `axolotl`
    or `alignment-handbook`. The output of the `ChatGeneration` task is formatted into a chat-like
    conversation by appending the `generation` as the assistant message to the existing `messages`.
    Input columns:
        - messages (`List[Dict[str, str]]`): The conversation messages.
        - generation (`str`): The generation produced by the `LLM`.
Output columns:
- prompt (`str`): The instruction used to generate the `generation` with the `LLM`.
- prompt_id (`str`): The `SHA256` hash of the `prompt`.
- messages (`List[Dict[str, str]]`): The chat-like conversation with the `instruction` as
the user message and the `generation` as the assistant message.
Categories:
- format
- chat-generation
- instruction
- generation
Examples:
Format your dataset for SFT:
```python
from distilabel.steps import FormatChatGenerationSFT
format_sft = FormatChatGenerationSFT()
format_sft.load()
# NOTE: "system_prompt" can be added optionally.
result = next(
format_sft.process(
[
{
"messages": [{"role": "user", "content": "What's 2+2?"}],
"generation": "4"
}
]
)
)
# >>> result
# [
# {
# 'messages': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
# 'generation': '4',
    #     'prompt': "What's 2+2?",
# 'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
# }
# ]
```
"""
@property
def inputs(self) -> "StepColumns":
"""List of inputs required by the `Step`, which in this case are: `instruction`, and `generation`."""
return ["messages", "generation"]
@property
def outputs(self) -> "StepColumns":
"""List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `messages`.
Reference:
            - Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
"""
return ["prompt", "prompt_id", "messages"]
def process(self, *inputs: StepInput) -> "StepOutput": # type: ignore
"""The `process` method formats the received `StepInput` or list of `StepInput`
according to the SFT formatting standard.
Args:
*inputs: A list of `StepInput` to be combined.
Yields:
A `StepOutput` with batches of formatted `StepInput` following the SFT standard.
"""
for input in inputs:
for item in input:
item["prompt"] = next(
(
turn["content"]
for turn in item["messages"]
if turn["role"] == "user"
),
None,
)
item["prompt_id"] = hashlib.sha256(
item["prompt"].encode("utf-8") # type: ignore
).hexdigest()
item["messages"] = item["messages"] + [
{"role": "assistant", "content": item["generation"]}, # type: ignore
]
yield input
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List
from pydantic import Field
from typing_extensions import override
from distilabel.steps.base import GeneratorStep
if TYPE_CHECKING:
from distilabel.typing import GeneratorStepOutput
class LoadDataFromDicts(GeneratorStep):
"""Loads a dataset from a list of dictionaries.
`GeneratorStep` that loads a dataset from a list of dictionaries and yields it in
batches.
Attributes:
data: The list of dictionaries to load the data from.
Runtime parameters:
- `batch_size`: The batch size to use when processing the data.
Output columns:
        - dynamic (based on the keys found in the first dictionary of the list): The columns
            of the dataset.
Categories:
- load
Examples:
Load data from a list of dictionaries:
```python
from distilabel.steps import LoadDataFromDicts
loader = LoadDataFromDicts(
data=[{"instruction": "What are 2+2?"}] * 5,
batch_size=2
)
loader.load()
result = next(loader.process())
# >>> result
# ([{'instruction': 'What are 2+2?'}, {'instruction': 'What are 2+2?'}], False)
```
"""
data: List[Dict[str, Any]] = Field(default_factory=list, exclude=True)
@override
def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore
"""Yields batches from a list of dictionaries.
Args:
offset: The offset to start the generation from. Defaults to `0`.
Yields:
A list of Python dictionaries as read from the inputs (propagated in batches)
and a flag indicating whether the yield batch is the last one.
"""
if offset:
self.data = self.data[offset:]
while self.data:
batch = self.data[: self.batch_size]
self.data = self.data[self.batch_size :]
            yield (
                batch,
                len(self.data) == 0,
            )
@property
def outputs(self) -> List[str]:
"""Returns a list of strings with the names of the columns that the step will generate."""
return list(self.data[0].keys())
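# Illustrative trace of the batching above: with data=[{"a": 1}] * 3 and
# batch_size=2, `process` yields ([{"a": 1}, {"a": 1}], False) and then
# ([{"a": 1}], True) -- the flag marks the last batch.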
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import TYPE_CHECKING, Any, Dict, List
from pydantic import Field
from typing_extensions import override
from distilabel.steps.base import GeneratorStep
if TYPE_CHECKING:
from distilabel.steps.base import GeneratorStepOutput
class DataSampler(GeneratorStep):
"""Step to sample from a dataset.
`GeneratorStep` that samples from a dataset and yields it in batches.
    This step is useful when you have a pipeline that can benefit from using examples
    in the prompts, e.g. for few-shot learning, and those examples should change on each row.
    For example, you can pass a list of dictionaries with N examples and generate M samples
    from it (assuming you have another step loading data, M should have the same size
    as the data being loaded in that step). The `size` argument S is the number of samples
    drawn per generated row, so each output row contains S examples to use as few-shot examples.
Attributes:
data: The list of dictionaries to sample from.
size: Number of samples per example. For example in a few-shot learning scenario,
the number of few-shot examples that will be generated per example. Defaults to 2.
samples: Number of examples that will be generated by the step in total.
If used with another loader step, this should be the same as the number
of samples in the loader step. Defaults to 100.
Output columns:
        - dynamic (based on the keys found in the first dictionary of the list): The columns
            of the dataset.
Categories:
- load
Examples:
Sample data from a list of dictionaries:
```python
from distilabel.steps import DataSampler
sampler = DataSampler(
data=[{"sample": f"sample {i}"} for i in range(30)],
samples=10,
size=2,
batch_size=4
)
sampler.load()
result = next(sampler.process())
# >>> result
# ([{'sample': ['sample 7', 'sample 0']}, {'sample': ['sample 2', 'sample 21']}, {'sample': ['sample 17', 'sample 12']}, {'sample': ['sample 2', 'sample 14']}], False)
```
Pipeline with a loader and a sampler combined in a single stream:
```python
from datasets import load_dataset
        from distilabel.steps import CombineOutputs, DataSampler, LoadDataFromDicts
from distilabel.steps.tasks.apigen.utils import PrepareExamples
from distilabel.pipeline import Pipeline
ds = (
load_dataset("Salesforce/xlam-function-calling-60k", split="train")
.shuffle(seed=42)
.select(range(500))
.to_list()
)
data = [
{
"func_name": "final_velocity",
"func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
},
{
"func_name": "permutation_count",
"func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
},
{
"func_name": "getdivision",
"func_desc": "Divides two numbers by making an API call to a division service.",
},
]
with Pipeline(name="APIGenPipeline") as pipeline:
loader_seeds = LoadDataFromDicts(data=data)
sampler = DataSampler(
data=ds,
size=2,
samples=len(data),
batch_size=8,
)
            prep_examples = PrepareExamples()
            combine_steps = CombineOutputs()
            sampler >> prep_examples
            (
                [loader_seeds, prep_examples]
                >> combine_steps
            )
# Now we have a single stream of data with the loader and the sampler data
```
"""
data: List[Dict[str, Any]] = Field(default_factory=list, exclude=True)
size: int = Field(
default=2,
description=(
"Number of samples per example. For example in a few-shot learning scenario, the number "
"of few-shot examples that will be generated per example."
),
)
samples: int = Field(
default=100,
description=(
"Number of examples that will be generated by the step in total. "
"If used with another loader step, this should be the same as the number of "
"samples in the loader step."
),
)
@override
def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore
"""Yields batches from a list of dictionaries.
Args:
offset: The offset to start the generation from. Defaults to `0`.
Yields:
A list of Python dictionaries as read from the inputs (propagated in batches)
and a flag indicating whether the yield batch is the last one.
"""
        total_samples = 0
        while total_samples < self.samples:
            batch = []
            # Don't overshoot the requested number of samples in the last batch.
            bs = min(self.batch_size, self.samples - total_samples)
            for _ in range(bs):
                choices = random.choices(self.data, k=self.size)
                choices = self._transform_data(choices)
                batch.extend(choices)
            total_samples += bs
            yield (batch, total_samples >= self.samples)
@staticmethod
def _transform_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
if not data:
return []
result = {key: [] for key in data[0].keys()}
for item in data:
for key, value in item.items():
result[key].append(value)
return [result]
@property
def outputs(self) -> List[str]:
return list(self.data[0].keys())
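# Illustrative note: `_transform_data` columnarises the sampled rows so each
# generated row carries lists of `size` values, e.g.
# DataSampler._transform_data([{"a": 1}, {"a": 2}]) == [{"a": [1, 2]}].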
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from collections import defaultdict
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Dict,
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from datasets import (
Dataset,
DatasetInfo,
IterableDataset,
get_dataset_infos,
load_dataset,
load_from_disk,
)
from pydantic import Field, PrivateAttr
from upath import UPath
from distilabel.distiset import Distiset
from distilabel.errors import DistilabelUserError
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import GeneratorStep
if TYPE_CHECKING:
from distilabel.typing import GeneratorStepOutput
T = TypeVar("T")
# To avoid using repo_id in LoadDataFromFileSystem:
# https://github.com/pydantic/pydantic/discussions/7076#discussioncomment-6699138
ExcludedField = Annotated[T, Field(exclude=True)]
class LoadDataFromHub(GeneratorStep):
"""Loads a dataset from the Hugging Face Hub.
`GeneratorStep` that loads a dataset from the Hugging Face Hub using the `datasets`
library.
Attributes:
repo_id: The Hugging Face Hub repository ID of the dataset to load.
split: The split of the dataset to load.
config: The configuration of the dataset to load. This is optional and only needed
if the dataset has multiple configurations.
Runtime parameters:
- `batch_size`: The batch size to use when processing the data.
- `repo_id`: The Hugging Face Hub repository ID of the dataset to load.
- `split`: The split of the dataset to load. Defaults to 'train'.
- `config`: The configuration of the dataset to load. This is optional and only
needed if the dataset has multiple configurations.
- `revision`: The revision of the dataset to load. Defaults to the latest revision.
- `streaming`: Whether to load the dataset in streaming mode or not. Defaults to
`False`.
- `num_examples`: The number of examples to load from the dataset.
By default will load all examples.
- `storage_options`: Key/value pairs to be passed on to the file-system backend, if any.
Defaults to `None`.
Output columns:
- dynamic (`all`): The columns that will be generated by this step, based on the
datasets loaded from the Hugging Face Hub.
Categories:
- load
Examples:
Load data from a dataset in Hugging Face Hub:
```python
from distilabel.steps import LoadDataFromHub
loader = LoadDataFromHub(
repo_id="distilabel-internal-testing/instruction-dataset-mini",
split="test",
batch_size=2
)
loader.load()
# Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
result = next(loader.process())
# >>> result
# ([{'prompt': 'Arianna has 12...', False)
```
"""
repo_id: RuntimeParameter[str] = Field(
default=None,
description="The Hugging Face Hub repository ID of the dataset to load.",
)
split: RuntimeParameter[str] = Field(
default="train",
description="The split of the dataset to load. Defaults to 'train'.",
)
config: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The configuration of the dataset to load. This is optional and only"
" needed if the dataset has multiple configurations.",
)
revision: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The revision of the dataset to load. Defaults to the latest revision.",
)
streaming: RuntimeParameter[bool] = Field(
default=False,
description="Whether to load the dataset in streaming mode or not. Defaults to False.",
)
num_examples: Optional[RuntimeParameter[int]] = Field(
default=None,
description="The number of examples to load from the dataset. By default will load all examples.",
)
storage_options: Optional[Dict[str, Any]] = Field(
default=None,
description="The storage options to use when loading the dataset.",
)
_dataset: Union[IterableDataset, Dataset, None] = PrivateAttr(None)
def load(self) -> None:
"""Load the dataset from the Hugging Face Hub"""
super().load()
if self._dataset is not None:
# Here to simplify the functionality of distilabel.steps.generators.util.make_generator_step
return
self._dataset = load_dataset(
self.repo_id, # type: ignore
self.config,
split=self.split,
revision=self.revision,
streaming=self.streaming,
)
num_examples = self._get_dataset_num_examples()
self.num_examples = (
min(self.num_examples, num_examples) if self.num_examples else num_examples
)
if not self.streaming:
self._dataset = self._dataset.select(range(self.num_examples))
def process(self, offset: int = 0) -> "GeneratorStepOutput":
"""Yields batches from the loaded dataset from the Hugging Face Hub.
Args:
            offset: The offset to start yielding the data from. Will be used during the caching
                process to help skip already processed data.
Yields:
A tuple containing a batch of rows and a boolean indicating if the batch is
the last one.
"""
num_returned_rows = 0
for batch_num, batch in enumerate(
self._dataset.iter(batch_size=self.batch_size) # type: ignore
):
if batch_num * self.batch_size < offset:
continue
transformed_batch = self._transform_batch(batch)
batch_size = len(transformed_batch)
num_returned_rows += batch_size
yield transformed_batch, num_returned_rows >= self.num_examples
@property
def outputs(self) -> List[str]:
"""The columns that will be generated by this step, based on the datasets loaded
from the Hugging Face Hub.
Returns:
The columns that will be generated by this step.
"""
return self._get_dataset_columns()
def _transform_batch(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Transform a batch of data from the Hugging Face Hub into a list of rows.
Args:
batch: The batch of data from the Hugging Face Hub.
Returns:
A list of rows, where each row is a dictionary of column names and values.
"""
length = len(next(iter(batch.values())))
rows = []
for i in range(length):
rows.append({col: values[i] for col, values in batch.items()})
return rows
def _get_dataset_num_examples(self) -> int:
"""Get the number of examples in the dataset, based on the `split` and `config`
runtime parameters provided.
Returns:
The number of examples in the dataset.
"""
default_config = self.config
if not default_config:
default_config = list(self._dataset_info.keys())[0]
return self._dataset_info[default_config].splits[self.split].num_examples
def _get_dataset_columns(self) -> List[str]:
"""Get the columns of the dataset, based on the `config` runtime parameter provided.
Returns:
The columns of the dataset.
"""
return list(
self._dataset_info[
self.config if self.config else "default"
].features.keys()
)
@cached_property
def _dataset_info(self) -> Dict[str, DatasetInfo]:
"""Calls the Datasets Server API from Hugging Face to obtain the dataset information.
Returns:
The dataset information.
"""
try:
return get_dataset_infos(self.repo_id)
except Exception as e:
warnings.warn(
f"Failed to get dataset info from Hugging Face Hub, trying to get it loading the dataset. Error: {e}",
UserWarning,
stacklevel=2,
)
            ds = load_dataset(self.repo_id, name=self.config, split=self.split)
            # Key the info by config name so callers can index this mapping just
            # like the result of `get_dataset_infos`.
            return {self.config or "default": ds.info}
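# Illustrative note: `_transform_batch` above converts the columnar batches
# produced by `Dataset.iter` into row dictionaries, e.g.
# {"a": [1, 2], "b": ["x", "y"]} -> [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}].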
class LoadDataFromFileSystem(LoadDataFromHub):
"""Loads a dataset from a file in your filesystem.
    `GeneratorStep` that creates a dataset from a file in the filesystem using the Hugging Face
    `datasets` library. Take a look at [Hugging Face Datasets](https://huggingface.co/docs/datasets/loading)
    for more information on the supported file types.
Attributes:
        data_files: The path to the file, or directory containing the files that make up
            the dataset.
split: The split of the dataset to load (typically will be `train`, `test` or `validation`).
Runtime parameters:
- `batch_size`: The batch size to use when processing the data.
        - `data_files`: The path to the file, or directory containing the files that make up
            the dataset.
- `split`: The split of the dataset to load. Defaults to 'train'.
- `streaming`: Whether to load the dataset in streaming mode or not. Defaults to
`False`.
- `num_examples`: The number of examples to load from the dataset.
By default will load all examples.
- `storage_options`: Key/value pairs to be passed on to the file-system backend, if any.
Defaults to `None`.
- `filetype`: The expected filetype. If not provided, it will be inferred from the file extension.
For more than one file, it will be inferred from the first file.
Output columns:
        - dynamic (`all`): The columns that will be generated by this step, based on the
            dataset loaded from the file system.
Categories:
- load
Examples:
Load data from a Hugging Face dataset in your file system:
```python
from distilabel.steps import LoadDataFromFileSystem
loader = LoadDataFromFileSystem(data_files="path/to/dataset.jsonl")
loader.load()
# Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
result = next(loader.process())
# >>> result
# ([{'type': 'function', 'function':...', False)
```
    Specify the filetype when it cannot be inferred from the file extension:
```python
from distilabel.steps import LoadDataFromFileSystem
loader = LoadDataFromFileSystem(filetype="csv", data_files="path/to/dataset.txtr")
loader.load()
# Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
result = next(loader.process())
# >>> result
# ([{'type': 'function', 'function':...', False)
```
Load data from a file in your cloud provider:
```python
from distilabel.steps import LoadDataFromFileSystem
loader = LoadDataFromFileSystem(
data_files="gcs://path/to/dataset",
storage_options={"project": "experiments-0001"}
)
loader.load()
# Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
result = next(loader.process())
# >>> result
# ([{'type': 'function', 'function':...', False)
```
Load data passing a glob pattern:
```python
from distilabel.steps import LoadDataFromFileSystem
loader = LoadDataFromFileSystem(
data_files="path/to/dataset/*.jsonl",
streaming=True
)
loader.load()
# Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
result = next(loader.process())
# >>> result
# ([{'type': 'function', 'function':...', False)
```
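
    Limit the number of loaded examples (a sketch relying on the `num_examples`
    runtime parameter documented above; the path is a placeholder):

    ```python
    from distilabel.steps import LoadDataFromFileSystem
    loader = LoadDataFromFileSystem(
        data_files="path/to/dataset.jsonl",
        num_examples=10,
    )
    loader.load()
    result = next(loader.process())
    # >>> result
    # ([...at most 10 rows...], True)
    ```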
"""
data_files: RuntimeParameter[Union[str, Path]] = Field(
default=None,
description="The data files, or directory containing the data files, to generate the dataset from.",
)
filetype: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The expected filetype. If not provided, it will be inferred from the file extension.",
)
repo_id: ExcludedField[Union[str, None]] = None
def load(self) -> None:
"""Load the dataset from the file/s in disk."""
GeneratorStep.load(self)
data_path = UPath(self.data_files, storage_options=self.storage_options)
(data_files, self.filetype) = self._prepare_data_files(data_path)
self._dataset = load_dataset(
self.filetype,
data_files=data_files,
split=self.split,
streaming=self.streaming,
storage_options=self.storage_options,
)
if not self.streaming and self.num_examples:
self._dataset = self._dataset.select(range(self.num_examples))
if not self.num_examples:
if self.streaming:
                # There's no direct way to get the number of examples of a streaming
                # dataset, so for now we load it again (without streaming) just to count.
self.num_examples = len(
load_dataset(
self.filetype, data_files=self.data_files, split=self.split
)
)
else:
self.num_examples = len(self._dataset)
@staticmethod
def _prepare_data_files( # noqa: C901
data_path: UPath,
) -> Tuple[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]], str]:
"""Prepare the loading process by setting the `data_files` attribute.
Args:
data_path: The path to the data files, or directory containing the data files.
Returns:
Tuple with the data files and the filetype.
"""
def get_filetype(data_path: UPath) -> str:
filetype = data_path.suffix.lstrip(".")
if filetype == "jsonl":
filetype = "json"
return filetype
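        # A single file, or a glob pattern that matches at least one file.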
        if data_path.is_file() or any(data_path.parent.glob(data_path.name)):
filetype = get_filetype(data_path)
data_files = str(data_path)
elif data_path.is_dir():
file_sequence = []
file_map = defaultdict(list)
for file_or_folder in data_path.iterdir():
if file_or_folder.is_file():
file_sequence.append(str(file_or_folder))
elif file_or_folder.is_dir():
for file in file_or_folder.iterdir():
file_sequence.append(str(file))
file_map[str(file_or_folder)].append(str(file))
data_files = file_sequence or file_map
# Try to obtain the filetype from any of the files, assuming all files have the same type.
if file_sequence:
filetype = get_filetype(UPath(file_sequence[0]))
else:
                filetype = get_filetype(UPath(next(iter(file_map.values()))[0]))
return data_files, filetype
@property
def outputs(self) -> List[str]:
"""The columns that will be generated by this step, based on the datasets from a file
in disk.
Returns:
The columns that will be generated by this step.
"""
        # We assume this is a Dataset/IterableDataset, not its `...Dict` counterpart.
if self._dataset is None:
self.load()
return self._dataset.column_names
class LoadDataFromDisk(LoadDataFromHub):
"""Load a dataset that was previously saved to disk.
    If you previously saved your dataset using the `save_to_disk` method, or
    `Distiset.save_to_disk`, you can load it again to build a new pipeline using this class.
Attributes:
dataset_path: The path to the dataset or distiset.
split: The split of the dataset to load (typically will be `train`, `test` or `validation`).
        config: The configuration of the dataset to load. Defaults to `default`. If there are
            multiple configurations in the dataset, this must be supplied or an error is raised.
Runtime parameters:
- `batch_size`: The batch size to use when processing the data.
- `dataset_path`: The path to the dataset or distiset.
- `is_distiset`: Whether the dataset to load is a `Distiset` or not. Defaults to False.
- `split`: The split of the dataset to load. Defaults to 'train'.
        - `config`: The configuration of the dataset to load. Defaults to `default`. If there are
            multiple configurations in the dataset, this must be supplied or an error is raised.
- `num_examples`: The number of examples to load from the dataset.
By default will load all examples.
- `storage_options`: Key/value pairs to be passed on to the file-system backend, if any.
Defaults to `None`.
Output columns:
        - dynamic (`all`): The columns that will be generated by this step, based on the
            dataset loaded from disk.
Categories:
- load
Examples:
    Load data from a Hugging Face Dataset saved to disk:
```python
from distilabel.steps import LoadDataFromDisk
loader = LoadDataFromDisk(dataset_path="path/to/dataset")
loader.load()
# Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
result = next(loader.process())
# >>> result
# ([{'type': 'function', 'function':...', False)
```
Load data from a distilabel Distiset:
```python
from distilabel.steps import LoadDataFromDisk
# Specify the configuration to load.
loader = LoadDataFromDisk(
dataset_path="path/to/dataset",
is_distiset=True,
config="leaf_step_1"
)
loader.load()
# Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
result = next(loader.process())
# >>> result
# ([{'a': 1}, {'a': 2}, {'a': 3}], True)
```
Load data from a Hugging Face Dataset or Distiset in your cloud provider:
```python
from distilabel.steps import LoadDataFromDisk
loader = LoadDataFromDisk(
dataset_path="gcs://path/to/dataset",
storage_options={"project": "experiments-0001"}
)
loader.load()
# Just like we saw with LoadDataFromDicts, the `process` method will yield batches.
result = next(loader.process())
# >>> result
# ([{'type': 'function', 'function':...', False)
```
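
    Load a specific split and cap the number of examples (a sketch using the `split`
    and `num_examples` runtime parameters documented above; the path is a placeholder):

    ```python
    from distilabel.steps import LoadDataFromDisk
    loader = LoadDataFromDisk(
        dataset_path="path/to/dataset",
        split="train",
        num_examples=10,
    )
    loader.load()
    result = next(loader.process())
    # >>> result
    # ([...at most 10 rows from the 'train' split...], True)
    ```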
"""
dataset_path: RuntimeParameter[Union[str, Path]] = Field(
default=None,
description="Path to the dataset or distiset.",
)
config: Optional[RuntimeParameter[str]] = Field(
default="default",
description=(
"The configuration of the dataset to load. Will default to 'default'",
" which corresponds to a distiset with a single configuration.",
),
)
is_distiset: Optional[RuntimeParameter[bool]] = Field(
default=False,
description="Whether the dataset to load is a `Distiset` or not. Defaults to False.",
)
keep_in_memory: Optional[RuntimeParameter[bool]] = Field(
default=None,
description="Whether to copy the dataset in-memory, see `datasets.Dataset.load_from_disk` "
" for more information. Defaults to `None`.",
)
split: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The split of the dataset to load. By default will load the whole Dataset/Distiset.",
)
repo_id: ExcludedField[Union[str, None]] = None
def load(self) -> None:
"""Load the dataset from the file/s in disk."""
super(GeneratorStep, self).load()
if self.is_distiset:
ds = Distiset.load_from_disk(
self.dataset_path,
keep_in_memory=self.keep_in_memory,
storage_options=self.storage_options,
)
if self.config not in ds.keys():
raise DistilabelUserError(
f"Configuration '{self.config}' not found in the Distiset, available ones"
f" are: {list(ds.keys())}. Please try changing the `config` parameter to one "
"of the available configurations.\n\n",
page="sections/how_to_guides/advanced/distiset/#using-the-distiset-dataset-object",
)
ds = ds[self.config]
else:
ds = load_from_disk(
self.dataset_path,
keep_in_memory=self.keep_in_memory,
storage_options=self.storage_options,
)
if self.split:
ds = ds[self.split]
self._dataset = ds
if self.num_examples:
self._dataset = self._dataset.select(range(self.num_examples))
else:
self.num_examples = len(self._dataset)
@property
def outputs(self) -> List[str]:
"""The columns that will be generated by this step, based on the datasets from a file
in disk.
Returns:
The columns that will be generated by this step.
"""
        # We assume this is a Dataset/IterableDataset, not its `...Dict` counterpart.
if self._dataset is None:
self.load()
return self._dataset.column_names