Commit 4d4d8f59 authored by chenzk

v1.0
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typer
from distilabel.cli.pipeline import app as pipeline_app
app = typer.Typer(name="distilabel")
app.add_typer(pipeline_app, name="pipeline")
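# A minimal usage sketch (the CLI arguments below are illustrative): the Typer
# app defined above can be exercised programmatically, e.g. in tests, via
# `typer.testing.CliRunner`.
from typer.testing import CliRunner

_runner = CliRunner()
_result = _runner.invoke(app, ["pipeline", "--help"])
assert _result.exit_code == 0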
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.cli.pipeline.app import app
__all__ = ["app"]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any, List, Optional, Tuple
import typer
from typing_extensions import Annotated
RUNTIME_PARAM_REGEX = re.compile(r"(?P<key>[^.]+(?:\.[^=]+)+)=(?P<value>.+)")
app = typer.Typer(help="Commands to run and inspect Distilabel pipelines.")
ConfigOption = Annotated[
str, typer.Option(help="Path or URL to the Distilabel pipeline configuration file.")
]
def parse_runtime_param(value: str) -> Tuple[List[str], str]:
match = RUNTIME_PARAM_REGEX.match(value)
if not match:
raise typer.BadParameter(
"Runtime parameters must be in the format `key.subkey=value` or"
" `key.subkey.subsubkey=value`"
)
return match.group("key").split("."), match.group("value")
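# For example (a sketch; the step/parameter names are assumptions): the dotted
# key is split on ".", and everything after "=" is kept as the raw string value.
_keys, _value = parse_runtime_param("load_data.batch_size=32")
assert _keys == ["load_data", "batch_size"] and _value == "32"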
@app.command(name="run", help="Run a Distilabel pipeline.")
def run(
    # `param` is `List[Tuple[List[str], str]]` after parsing
param: Annotated[
List[Any],
        typer.Option(help="A runtime parameter for the pipeline, in the format `key.subkey=value`.", parser=parse_runtime_param, default_factory=list),
],
config: Optional[str] = typer.Option(
None, help="Path or URL to the distilabel pipeline configuration file."
),
script: Optional[str] = typer.Option(
None,
help="URL pointing to a python script containing a distilabel pipeline.",
),
pipeline_variable_name: str = typer.Option(
default="pipeline",
help="Name of the pipeline in a script. I.e. the 'pipeline' variable in `with Pipeline(...) as pipeline:...`.",
),
ignore_cache: bool = typer.Option(
False, help="Whether to ignore the cache and re-run the pipeline from scratch."
),
repo_id: str = typer.Option(
None,
help="The Hugging Face Hub repository ID to push the resulting dataset to.",
),
commit_message: str = typer.Option(
None, help="The commit message to use when pushing the dataset."
),
private: bool = typer.Option(
False, help="Whether to make the resulting dataset private on the Hub."
),
token: str = typer.Option(
None, help="The Hugging Face Hub API token to use when pushing the dataset."
),
) -> None:
from distilabel.cli.pipeline.utils import get_pipeline, parse_runtime_parameters
if script:
if config:
typer.secho(
"Only one of `--config` or `--script` can be informed.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)
do_run = typer.prompt("This will run a remote script, are you sure? (y/n)")
if do_run.lower() != "y":
raise typer.Exit(code=0)
if not config and not script:
typer.secho(
"`--config` or `--script` must be informed.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)
try:
pipeline = get_pipeline(config or script, pipeline_name=pipeline_variable_name)
except Exception as e:
typer.secho(str(e), fg=typer.colors.RED, bold=True)
raise typer.Exit(code=1) from e
parameters = parse_runtime_parameters(param)
distiset = pipeline.run(parameters=parameters, use_cache=not ignore_cache)
if repo_id is not None:
distiset.push_to_hub(
repo_id=repo_id,
commit_message=commit_message,
private=private,
token=token,
)
@app.command(name="info", help="Get information about a Distilabel pipeline.")
def info(config: ConfigOption) -> None:
from distilabel.cli.pipeline.utils import display_pipeline_information, get_pipeline
try:
pipeline = get_pipeline(config)
display_pipeline_information(pipeline)
except Exception as e:
typer.secho(str(e), fg=typer.colors.RED, bold=True)
raise typer.Exit(code=1) from e
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import requests
import yaml
from pydantic import HttpUrl, ValidationError
from pydantic.type_adapter import TypeAdapter
from distilabel.constants import ROUTING_BATCH_FUNCTION_ATTR_NAME, STEP_ATTR_NAME
from distilabel.errors import DistilabelUserError
from distilabel.pipeline.local import Pipeline
if TYPE_CHECKING:
from rich.panel import Panel
from distilabel.pipeline.base import BasePipeline
def parse_runtime_parameters(
params: List[Tuple[List[str], str]],
) -> Dict[str, Dict[str, Any]]:
"""Parses the runtime parameters from the CLI format to the format expected by the
`Pipeline.run` method. The CLI format is a list of tuples, where the first element is
a list of keys and the second element is the value.
Args:
params: A list of tuples, where the first element is a list of keys and the
second element is the value.
Returns:
A dictionary with the runtime parameters in the format expected by the
`Pipeline.run` method.
"""
runtime_params = {}
for keys, value in params:
current = runtime_params
for i, key in enumerate(keys):
if i == len(keys) - 1:
current[key] = value
else:
current = current.setdefault(key, {})
return runtime_params
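# For example (a sketch; the step/parameter names are assumptions): the tuples
# produced by `parse_runtime_param` are folded into the nested mapping that
# `Pipeline.run` expects.
assert parse_runtime_parameters(
    [(["load_data", "batch_size"], "32"), (["task", "llm", "temperature"], "0.7")]
) == {"load_data": {"batch_size": "32"}, "task": {"llm": {"temperature": "0.7"}}}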
def valid_http_url(url: str) -> bool:
"""Check if the URL is a valid HTTP URL.
Args:
url: The URL to check.
Returns:
`True`, if the URL is a valid HTTP URL. `False`, otherwise.
"""
try:
TypeAdapter(HttpUrl).validate_python(url) # type: ignore
except ValidationError:
return False
return True
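# For example (a sketch): full HTTP(S) URLs validate, bare local paths do not.
assert valid_http_url("https://example.com/pipeline.yaml")
assert not valid_http_url("config/pipeline.yaml")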
def _download_remote_file(url: str) -> "requests.Response":
    """Downloads a file from a URL, adding an authentication header for the
    Hugging Face Hub when the `HF_TOKEN` environment variable is set.
    Args:
        url: URL of the file to download.
    Returns:
        The `requests.Response` object.
    """
if "huggingface.co" in url and "HF_TOKEN" in os.environ:
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
else:
headers = None
response = requests.get(url, headers=headers)
response.raise_for_status()
return response
def get_config_from_url(url: str) -> Dict[str, Any]:
"""Loads the pipeline configuration from a URL pointing to a JSON or YAML file.
Args:
url: The URL pointing to the pipeline configuration file.
Returns:
The pipeline configuration as a dictionary.
Raises:
ValueError: If the file format is not supported.
"""
if not url.endswith((".json", ".yaml", ".yml")):
raise DistilabelUserError(
f"Unsupported file format for '{url}'. Only JSON and YAML are supported",
page="sections/how_to_guides/basic/pipeline/?h=seriali#serializing-the-pipeline",
)
response = _download_remote_file(url)
if url.endswith((".yaml", ".yml")):
content = response.content.decode("utf-8")
return yaml.safe_load(content)
return response.json()
def get_pipeline_from_url(url: str, pipeline_name: str = "pipeline") -> "BasePipeline":
"""Downloads the file to the current working directory and loads the pipeline object
from a python script.
Args:
url: The URL pointing to the python script with the pipeline definition.
pipeline_name: The name of the pipeline in the script.
I.e: `with Pipeline(...) as pipeline:...`.
Returns:
The pipeline instantiated.
Raises:
ValueError: If the file format is not supported.
"""
if not url.endswith(".py"):
raise DistilabelUserError(
f"Unsupported file format for '{url}'. It must be a python file.",
page="sections/how_to_guides/advanced/cli/#distilabel-pipeline-run",
)
response = _download_remote_file(url)
content = response.content.decode("utf-8")
script_local = Path.cwd() / Path(url).name
script_local.write_text(content)
# Add the current working directory to sys.path
sys.path.insert(0, os.getcwd())
module = importlib.import_module(str(Path(url).stem))
pipeline = getattr(module, pipeline_name, None)
if not pipeline:
raise ImportError(
f"The script must contain an object with the pipeline named: '{pipeline_name}' that can be imported"
)
return pipeline
def get_pipeline(
config_or_script: str, pipeline_name: str = "pipeline"
) -> "BasePipeline":
"""Get a pipeline from a configuration file or a remote python script.
Args:
config_or_script: The path or URL to the pipeline configuration file
or URL to a python script.
pipeline_name: The name of the pipeline in the script.
I.e: `with Pipeline(...) as pipeline:...`.
Returns:
The pipeline.
Raises:
ValueError: If the file format is not supported.
FileNotFoundError: If the configuration file does not exist.
"""
config = script = None
if config_or_script.endswith((".json", ".yaml", ".yml")):
config = config_or_script
elif config_or_script.endswith(".py"):
script = config_or_script
else:
raise DistilabelUserError(
"The file must be a valid config file or python script with a pipeline.",
page="sections/how_to_guides/advanced/cli/#distilabel-pipeline-run",
)
if valid_http_url(config_or_script):
if config:
data = get_config_from_url(config)
return Pipeline.from_dict(data)
return get_pipeline_from_url(script, pipeline_name=pipeline_name)
if not config:
raise ValueError(
f"To run a pipeline from a python script, run it as `python {script}`"
)
if Path(config).is_file():
return Pipeline.from_file(config)
raise FileNotFoundError(f"File '{config_or_script}' does not exist.")
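# Usage sketch (the paths and URLs are assumptions): a local serialized config is
# loaded with `Pipeline.from_file`, a remote config is fetched and deserialized,
# and a remote `.py` script is downloaded and imported:
#
#   pipeline = get_pipeline("pipeline.yaml")
#   pipeline = get_pipeline("https://example.com/pipeline.yaml")
#   pipeline = get_pipeline("https://example.com/pipeline.py", pipeline_name="pipeline")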
def display_pipeline_information(pipeline: "BasePipeline") -> None:
"""Displays the pipeline information to the console.
Args:
pipeline: The pipeline.
"""
from rich.console import Console
Console().print(_build_pipeline_panel(pipeline))
def _build_pipeline_panel(pipeline: "BasePipeline") -> "Panel":
"""Builds a panel to display the information of the pipeline.
Args:
pipeline: The pipeline
Returns:
A `rich.panel.Panel` containing the information of the pipeline.
"""
from rich.console import Group
from rich.panel import Panel
information: List[Any] = [f"[bold][magenta]Name:[/bold][/magenta] {pipeline.name}"]
if pipeline.description:
information.append(
f"[bold][magenta]Description:[/bold][/magenta] {pipeline.description}"
)
information.extend(
[
"\n",
_build_steps_panel(pipeline),
"\n",
_build_steps_connection_panel(pipeline),
]
)
if any(
pipeline.dag.get_step(step).get(ROUTING_BATCH_FUNCTION_ATTR_NAME) is not None
for step in pipeline.dag.G.nodes
):
information.extend(
[
"\n",
_build_routing_batch_function_panel(pipeline),
]
)
return Panel(
Group(*information),
title="[magenta]Pipeline Information[/magenta]",
expand=False,
style="light_cyan3",
)
def _build_steps_panel(pipeline: "BasePipeline") -> "Panel":
"""Builds a panel to display the information of the steps.
Args:
pipeline: The pipeline
Returns:
A `rich.panel.Panel` containing the information of the steps.
"""
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
def _add_rows(
table: Table,
runtime_params: List[Dict[str, Any]],
prefix: str = "",
) -> None:
for param in runtime_params:
if isinstance(param, str):
_add_rows(table, runtime_params[param], f"{prefix}{param}.")
continue
# nested (for example `LLM` in `Task`)
if "runtime_parameters_info" in param:
_add_rows(
table=table,
runtime_params=param["runtime_parameters_info"],
prefix=f"{prefix}{param['name']}.",
)
# `LLM` special case
elif "keys" in param:
_add_rows(
table=table,
runtime_params=param["keys"],
prefix=f"{prefix}{param['name']}.",
)
return
else:
optional = param.get("optional", "")
if optional != "":
optional = "Yes" if optional else "No"
table.add_row(
prefix + param["name"], param.get("description"), optional
)
steps = []
for step_name, runtime_params in pipeline.get_runtime_parameters_info().items():
step = pipeline.dag.get_step(step_name)[STEP_ATTR_NAME]
class_name = step.__class__.__name__
table = Table(
title=f"{step.name} ([bold][magenta]{class_name}[/bold][/magenta])",
show_header=True,
header_style="bold magenta",
expand=True,
)
table.add_column("Runtime parameter", style="dim", width=60)
table.add_column("Description", width=100)
table.add_column("Optional", justify="right")
_add_rows(table, runtime_params)
steps.append(table)
return Panel(
Group(*steps),
title="[magenta]Steps[/magenta]",
expand=False,
padding=(1, 1, 0, 1),
style="light_cyan3",
)
def _build_steps_connection_panel(pipeline: "BasePipeline") -> "Panel":
"""Builds a panel to display the connections of the steps of the pipeline.
Args:
pipeline: The pipeline
Returns:
A `rich.panel.Panel` containing the connection of the steps.
"""
from rich.panel import Panel
from rich.table import Table
table = Table(show_header=True, header_style="bold magenta", expand=True)
table.add_column("From step", style="dim", width=18)
table.add_column("To steps", style="dim")
G = pipeline.dag.G
for node in G.nodes:
if successors := list(G.successors(node)):
# Convert list of successors to string
successors_str = ", ".join(map(str, successors))
table.add_row(str(node), successors_str)
continue
# If a node has no successors, indicate it as such
table.add_row(str(node), "No downstream steps")
return Panel(
table,
title="[magenta]Steps connections[/magenta]",
style="light_cyan3",
expand=True,
)
def _build_routing_batch_function_panel(pipeline: "BasePipeline") -> "Panel":
"""Builds a panel to display the routing batch function of the pipeline.
Args:
pipeline: The pipeline
Returns:
A `rich.panel.Panel` containing the routing batch function of the pipeline.
"""
from rich.panel import Panel
from rich.table import Table
table = Table(show_header=True, header_style="bold magenta", expand=True)
table.add_column("Step", style="dim", width=18)
table.add_column("Function", style="dim")
table.add_column("Description", width=90)
G = pipeline.dag.G
for step_name in G.nodes:
node = pipeline.dag.get_step(step_name)
if routing_batch_function := node.get(ROUTING_BATCH_FUNCTION_ATTR_NAME):
table.add_row(
step_name,
routing_batch_function.routing_function.__name__,
routing_batch_function.description,
)
continue
return Panel(
table,
title="[magenta]Routing Batch Function[/magenta]",
style="light_cyan3",
expand=True,
)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Final
# Steps related constants
DISTILABEL_METADATA_KEY: Final[str] = "distilabel_metadata"
# Cache
BASE_CACHE_DIR = Path.home() / ".cache" / "distilabel"
PIPELINES_CACHE_DIR = BASE_CACHE_DIR / "pipelines"
# Pipeline dag related constants
STEP_ATTR_NAME: Final[str] = "step"
INPUT_QUEUE_ATTR_NAME: Final[str] = "input_queue"
RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches"
ROUTING_BATCH_FUNCTION_ATTR_NAME: Final[str] = "routing_batch_function"
CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step"
LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent"
# Pipeline execution related constants
PIPELINE_NAME_ENV_NAME = "DISTILABEL_PIPELINE_NAME"
PIPELINE_CACHE_ID_ENV_NAME = "DISTILABEL_PIPELINE_CACHE_ID"
SIGINT_HANDLER_CALLED_ENV_NAME = "sigint_handler_called"
# Data paths constants
STEPS_OUTPUTS_PATH = "steps_outputs"
STEPS_ARTIFACTS_PATH = "steps_artifacts"
# Distiset related constants
DISTISET_CONFIG_FOLDER: Final[str] = "distiset_configs"
DISTISET_ARTIFACTS_FOLDER: Final[str] = "artifacts"
PIPELINE_CONFIG_FILENAME: Final[str] = "pipeline.yaml"
PIPELINE_LOG_FILENAME: Final[str] = "pipeline.log"
# Docs page for the custom errors
DISTILABEL_DOCS_URL: Final[str] = "https://distilabel.argilla.io/latest/"
__all__ = [
"BASE_CACHE_DIR",
"CONVERGENCE_STEP_ATTR_NAME",
"DISTILABEL_DOCS_URL",
"DISTILABEL_METADATA_KEY",
"DISTISET_ARTIFACTS_FOLDER",
"DISTISET_CONFIG_FOLDER",
"INPUT_QUEUE_ATTR_NAME",
"LAST_BATCH_SENT_FLAG",
"PIPELINES_CACHE_DIR",
"PIPELINE_CONFIG_FILENAME",
"PIPELINE_LOG_FILENAME",
"RECEIVES_ROUTED_BATCHES_ATTR_NAME",
"ROUTING_BATCH_FUNCTION_ATTR_NAME",
"SIGINT_HANDLER_CALLED_ENV_NAME",
"STEPS_ARTIFACTS_PATH",
"STEPS_OUTPUTS_PATH",
"STEP_ATTR_NAME",
]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import json
import logging
import os.path as posixpath
import re
import sys
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
import fsspec
import yaml
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from datasets.filesystems import is_remote_filesystem
from huggingface_hub import DatasetCardData, HfApi, upload_file, upload_folder
from huggingface_hub.file_download import hf_hub_download
from pyarrow.lib import ArrowInvalid
from typing_extensions import Self
from distilabel.constants import (
DISTISET_ARTIFACTS_FOLDER,
DISTISET_CONFIG_FOLDER,
PIPELINE_CONFIG_FILENAME,
PIPELINE_LOG_FILENAME,
STEP_ATTR_NAME,
STEPS_ARTIFACTS_PATH,
STEPS_OUTPUTS_PATH,
)
from distilabel.utils.card.dataset_card import (
DistilabelDatasetCard,
size_categories_parser,
)
from distilabel.utils.docstring import get_bibtex, parse_google_docstring
from distilabel.utils.files import list_files_in_dir
from distilabel.utils.huggingface import get_hf_token
if TYPE_CHECKING:
from distilabel.pipeline._dag import DAG
def is_PIL_available() -> bool:
"""Checks if the PIL library is available.
Returns:
True if the PIL library is available, False otherwise.
"""
    # `importlib.util.find_spec` returns `None` (rather than raising) when the
    # package is not installed, so check the result directly.
    return importlib.util.find_spec("PIL") is not None
class Distiset(dict):
"""Convenient wrapper around `datasets.Dataset` to push to the Hugging Face Hub.
It's a dictionary where the keys correspond to the different leaf_steps from the internal
`DAG` and the values are `datasets.Dataset`.
Attributes:
_pipeline_path: Optional path to the `pipeline.yaml` file that generated the dataset.
Defaults to `None`.
_artifacts_path: Optional path to the directory containing the generated artifacts
by the pipeline steps. Defaults to `None`.
        _log_filename_path: Optional path to the `pipeline.log` file that was written
            by the pipeline. Defaults to `None`.
_citations: Optional list containing citations that will be included in the dataset
card. Defaults to `None`.
"""
_pipeline_path: Optional[Path] = None
_artifacts_path: Optional[Path] = None
_log_filename_path: Optional[Path] = None
_citations: Optional[List[str]] = None
def push_to_hub(
self,
repo_id: str,
private: bool = False,
token: Optional[str] = None,
generate_card: bool = True,
include_script: bool = False,
**kwargs: Any,
) -> None:
"""Pushes the `Distiset` to the Hugging Face Hub, each dataset will be pushed as a different configuration
corresponding to the leaf step that generated it.
Args:
repo_id:
The ID of the repository to push to in the following format: `<user>/<dataset_name>` or
`<org>/<dataset_name>`. Also accepts `<dataset_name>`, which will default to the namespace
of the logged-in user.
private:
Whether the dataset repository should be set to private or not. Only affects repository creation:
a repository that already exists will not be affected by that parameter.
token:
An optional authentication token for the Hugging Face Hub. If no token is passed, will default
to the token saved locally when logging in with `huggingface-cli login`. Will raise an error
if no token is passed and the user is not logged-in.
generate_card:
Whether to generate a dataset card or not. Defaults to True.
            include_script:
                Whether you want to push the pipeline script to the Hugging Face Hub to share it.
                If set to True, the name of the script that was run to create the distiset will be
                automatically determined, and that will be the name of the file uploaded to your
                repository. Note that this operation only makes sense for a distiset obtained
                from calling the `Pipeline.run()` method. Defaults to False.
**kwargs:
Additional keyword arguments to pass to the `push_to_hub` method of the `datasets.Dataset` object.
Raises:
ValueError: If no token is provided and couldn't be retrieved automatically.
"""
script_filename = sys.argv[0]
filename_py = (
script_filename.split("/")[-1]
if "/" in script_filename
else script_filename
)
script_path = Path.cwd() / script_filename
if token is None:
token = get_hf_token(self.__class__.__name__, "token")
for name, dataset in self.items():
dataset.push_to_hub(
repo_id=repo_id,
config_name=name,
private=private,
token=token,
**kwargs,
)
if self.artifacts_path:
upload_folder(
repo_id=repo_id,
folder_path=self.artifacts_path,
path_in_repo="artifacts",
token=token,
repo_type="dataset",
commit_message="Include pipeline artifacts",
)
if include_script and script_path.exists():
upload_file(
path_or_fileobj=script_path,
path_in_repo=filename_py,
repo_id=repo_id,
repo_type="dataset",
token=token,
commit_message="Include pipeline script",
)
if generate_card:
self._generate_card(
repo_id, token, include_script=include_script, filename_py=filename_py
)
def _get_card(
self,
repo_id: str,
token: Optional[str] = None,
include_script: bool = False,
filename_py: Optional[str] = None,
) -> DistilabelDatasetCard:
"""Generates the dataset card for the `Distiset`.
Note:
If `repo_id` and `token` are provided, it will extract the metadata from the README.md file
on the hub.
Args:
repo_id: Name of the repository to push to, or the path for the distiset if saved to disk.
token: The token to authenticate with the Hugging Face Hub.
We assume that if it's provided, the dataset will be in the Hugging Face Hub,
so the README metadata will be extracted from there.
            include_script: Whether to upload the script to the Hugging Face repository.
filename_py: The name of the script. If `include_script` is True, the script will
be uploaded to the repository using this name, otherwise it won't be used.
Returns:
The dataset card for the `Distiset`.
"""
sample_records = {}
for name, dataset in self.items():
record = (
dataset[0] if not isinstance(dataset, dict) else dataset["train"][0]
)
if is_PIL_available():
from PIL import ImageFile
else:
ImageFile = None
for key, value in record.items():
                # If the value is an image, set it to an empty string to avoid
                # making the generated `README.md` huge
                if ImageFile and isinstance(value, ImageFile.ImageFile):
                    record[key] = ""
                # If the list is too big, truncate it so the generated `README.md` isn't huge
elif isinstance(value, list):
length = len(value)
if length < 10:
continue
record[key] = value[:10]
record[key].append(
f"... (truncated - showing 10 of {length} elements)"
)
sample_records[name] = record
readme_metadata = {}
if repo_id and token:
readme_metadata = self._extract_readme_metadata(repo_id, token)
metadata = {
**readme_metadata,
"size_categories": size_categories_parser(
max(len(dataset) for dataset in self.values())
),
"tags": ["synthetic", "distilabel", "rlaif"],
}
card = DistilabelDatasetCard.from_template(
card_data=DatasetCardData(**metadata),
repo_id=repo_id,
sample_records=sample_records,
include_script=include_script,
filename_py=filename_py,
artifacts=self._get_artifacts_metadata(),
references=self.citations,
)
return card
def _get_artifacts_metadata(self) -> Dict[str, List[Dict[str, Any]]]:
"""Gets a dictionary with the metadata of the artifacts generated by the pipeline steps.
Returns:
A dictionary in which the key is the name of the step and the value is a list
of dictionaries, each of them containing the name and metadata of the step artifact.
"""
if not self.artifacts_path:
return {}
def iterdir_ignore_hidden(path: Path) -> Generator[Path, None, None]:
return (f for f in Path(path).iterdir() if not f.name.startswith("."))
artifacts_metadata = defaultdict(list)
for step_artifacts_dir in iterdir_ignore_hidden(self.artifacts_path):
step_name = step_artifacts_dir.stem
for artifact_dir in iterdir_ignore_hidden(step_artifacts_dir):
artifact_name = artifact_dir.stem
metadata_path = artifact_dir / "metadata.json"
metadata = json.loads(metadata_path.read_text())
artifacts_metadata[step_name].append(
{"name": artifact_name, "metadata": metadata}
)
return dict(artifacts_metadata)
def _extract_readme_metadata(
self, repo_id: str, token: Optional[str]
) -> Dict[str, Any]:
"""Extracts the metadata from the README.md file of the dataset repository.
        We have to download the previous README.md file in the repo, extract the metadata from it,
        and generate a dict again to be passed through to the `DatasetCardData` object.
Args:
repo_id: The ID of the repository to push to, from the `push_to_hub` method.
token: The token to authenticate with the Hugging Face Hub, from the `push_to_hub` method.
Returns:
The metadata extracted from the README.md file of the dataset repository as a dict.
"""
readme_path = Path(
hf_hub_download(repo_id, "README.md", repo_type="dataset", token=token)
)
# Remove the '---' from the metadata
metadata = re.findall(r"---\n(.*?)\n---", readme_path.read_text(), re.DOTALL)[0]
metadata = yaml.safe_load(metadata)
return metadata
def _generate_card(
self,
repo_id: str,
token: str,
include_script: bool = False,
filename_py: Optional[str] = None,
) -> None:
"""Generates a dataset card and pushes it to the Hugging Face Hub, and
if the `pipeline.yaml` path is available in the `Distiset`, uploads that
to the same repository.
Args:
repo_id: The ID of the repository to push to, from the `push_to_hub` method.
token: The token to authenticate with the Hugging Face Hub, from the `push_to_hub` method.
            include_script: Whether to upload the script to the Hugging Face repository.
filename_py: The name of the script. If `include_script` is True, the script will
be uploaded to the repository using this name, otherwise it won't be used.
"""
card = self._get_card(
repo_id=repo_id,
token=token,
include_script=include_script,
filename_py=filename_py,
)
card.push_to_hub(
repo_id,
repo_type="dataset",
token=token,
)
if self.pipeline_path:
# If the pipeline.yaml is available, upload it to the Hugging Face Hub as well.
HfApi().upload_file(
path_or_fileobj=self.pipeline_path,
path_in_repo=PIPELINE_CONFIG_FILENAME,
repo_id=repo_id,
repo_type="dataset",
token=token,
)
if self.log_filename_path:
# The same we had with "pipeline.yaml" but with the log file.
HfApi().upload_file(
path_or_fileobj=self.log_filename_path,
path_in_repo=PIPELINE_LOG_FILENAME,
repo_id=repo_id,
repo_type="dataset",
token=token,
)
def train_test_split(
self,
train_size: float,
shuffle: bool = True,
seed: Optional[int] = None,
) -> Self:
"""Return a `Distiset` whose values will be a `datasets.DatasetDict` with two random train and test subsets.
Splits are created from the dataset according to `train_size` and `shuffle`.
Args:
            train_size:
                Float between `0.0` and `1.0` representing the proportion of the dataset to include in the train split.
                It will be applied to all the datasets in the `Distiset`.
            shuffle: Whether or not to shuffle the data before splitting.
seed:
A seed to initialize the default BitGenerator, passed to the underlying method.
Returns:
The `Distiset` with the train-test split applied to all the datasets.
"""
assert 0 < train_size < 1, "train_size must be a float between 0 and 1"
for name, dataset in self.items():
self[name] = dataset.train_test_split(
train_size=train_size,
shuffle=shuffle,
seed=seed,
)
return self
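    # Usage sketch (assuming `distiset` is a populated `Distiset`): after this
    # call, each value becomes a `datasets.DatasetDict` with "train" and "test"
    # splits:
    #
    #   distiset = distiset.train_test_split(train_size=0.8, seed=42)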
def save_to_disk(
self,
distiset_path: PathLike,
max_shard_size: Optional[Union[str, int]] = None,
num_shards: Optional[int] = None,
num_proc: Optional[int] = None,
storage_options: Optional[dict] = None,
save_card: bool = True,
save_pipeline_config: bool = True,
save_pipeline_log: bool = True,
) -> None:
r"""
Saves a `Distiset` to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.
In case you want to save the `Distiset` in a remote filesystem, you can pass the `storage_options` parameter
as you would do with `datasets`'s `Dataset.save_to_disk` method: [see example](https://huggingface.co/docs/datasets/filesystems#saving-serialized-datasets)
Args:
distiset_path: Path where you want to save the `Distiset`. It can be a local path
(e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`)
            max_shard_size: The maximum size of the dataset shards to be written to disk.
                If expressed as a string, needs to be digits followed by a unit (like `"50MB"`).
                Defaults to `None`.
num_shards: Number of shards to write. By default the number of shards depends on
`max_shard_size` and `num_proc`. Defaults to `None`.
num_proc: Number of processes when downloading and generating the dataset locally.
Multiprocessing is disabled by default. Defaults to `None`.
storage_options: Key/value pairs to be passed on to the file-system backend, if any.
Defaults to `None`.
save_card: Whether to save the dataset card. Defaults to `True`.
save_pipeline_config: Whether to save the pipeline configuration file (aka the `pipeline.yaml` file).
Defaults to `True`.
save_pipeline_log: Whether to save the pipeline log file (aka the `pipeline.log` file).
Defaults to `True`.
Examples:
```python
# Save your distiset in a local folder:
distiset.save_to_disk(distiset_path="my-distiset")
# Save your distiset in a remote storage:
storage_options = {
"key": os.environ["S3_ACCESS_KEY"],
"secret": os.environ["S3_SECRET_KEY"],
"client_kwargs": {
"endpoint_url": os.environ["S3_ENDPOINT_URL"],
"region_name": os.environ["S3_REGION"],
},
}
distiset.save_to_disk(distiset_path="my-distiset", storage_options=storage_options)
```
"""
distiset_path = str(distiset_path)
for name, dataset in self.items():
dataset.save_to_disk(
f"{distiset_path}/{name}",
max_shard_size=max_shard_size,
num_shards=num_shards,
num_proc=num_proc,
storage_options=storage_options,
)
distiset_config_folder = posixpath.join(distiset_path, DISTISET_CONFIG_FOLDER)
fs: fsspec.AbstractFileSystem
fs, _, _ = fsspec.get_fs_token_paths(
distiset_config_folder, storage_options=storage_options
)
fs.makedirs(distiset_config_folder, exist_ok=True)
if self.artifacts_path:
distiset_artifacts_folder = posixpath.join(
distiset_path, DISTISET_ARTIFACTS_FOLDER
)
fs.copy(str(self.artifacts_path), distiset_artifacts_folder, recursive=True)
if save_card:
# NOTE: Currently the card is not the same if we write to disk or push to the HF hub,
# as we aren't generating the README copying/updating the data from the dataset repo.
card = self._get_card(repo_id=Path(distiset_path).stem, token=None)
new_filename = posixpath.join(distiset_config_folder, "README.md")
if storage_options:
# Write the card the same way as DatasetCard.save does:
with fs.open(new_filename, "w", newline="", encoding="utf-8") as f:
f.write(str(card))
else:
card.save(new_filename)
# Write our internal files to the distiset folder by copying them to the distiset folder.
if save_pipeline_config and self.pipeline_path:
new_filename = posixpath.join(
distiset_config_folder, PIPELINE_CONFIG_FILENAME
)
if self.pipeline_path.exists() and (not fs.isfile(new_filename)):
data = yaml.safe_load(self.pipeline_path.read_text())
with fs.open(new_filename, "w", encoding="utf-8") as f:
yaml.dump(data, f, default_flow_style=False)
if save_pipeline_log and self.log_filename_path:
new_filename = posixpath.join(distiset_config_folder, PIPELINE_LOG_FILENAME)
if self.log_filename_path.exists() and (not fs.isfile(new_filename)):
data = self.log_filename_path.read_text()
with fs.open(new_filename, "w", encoding="utf-8") as f:
f.write(data)
@classmethod
def load_from_disk(
cls,
distiset_path: PathLike,
keep_in_memory: Optional[bool] = None,
storage_options: Optional[Dict[str, Any]] = None,
download_dir: Optional[PathLike] = None,
) -> Self:
"""Loads a dataset that was previously saved using `Distiset.save_to_disk` from a dataset
directory, or from a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.
Args:
distiset_path: Path ("dataset/train") or remote URI ("s3://bucket/dataset/train").
            keep_in_memory: Whether to copy the dataset in-memory, see `datasets.Dataset.load_from_disk`
                for more information. Defaults to `None`.
storage_options: Key/value pairs to be passed on to the file-system backend, if any.
Defaults to `None`.
download_dir: Optional directory to download the dataset to. Defaults to None,
in which case it will create a temporary directory.
Returns:
            A `Distiset` loaded from disk; it should have been previously saved using
            `Distiset.save_to_disk`.
"""
original_distiset_path = str(distiset_path)
fs: fsspec.AbstractFileSystem
fs, _, [distiset_path] = fsspec.get_fs_token_paths( # type: ignore
original_distiset_path, storage_options=storage_options
)
dest_distiset_path = distiset_path
assert fs.isdir(
original_distiset_path
), "`distiset_path` must be a `PathLike` object pointing to a folder or a URI of a remote filesystem."
has_config = False
has_artifacts = False
distiset = cls()
if is_remote_filesystem(fs):
src_dataset_path = distiset_path
if download_dir:
dest_distiset_path = download_dir
else:
dest_distiset_path = Dataset._build_local_temp_path(src_dataset_path) # type: ignore
fs.download(src_dataset_path, dest_distiset_path.as_posix(), recursive=True) # type: ignore
# Now we should have the distiset locally, so we can read those files
for folder in Path(dest_distiset_path).iterdir():
if folder.stem == DISTISET_CONFIG_FOLDER:
has_config = True
continue
elif folder.stem == DISTISET_ARTIFACTS_FOLDER:
has_artifacts = True
continue
distiset[folder.stem] = load_from_disk(
str(folder),
keep_in_memory=keep_in_memory,
)
        # From the config folder we just need to point to the files. Once downloaded,
        # we set the paths to wherever they are.
if has_config:
distiset_config_folder = posixpath.join(
dest_distiset_path, DISTISET_CONFIG_FOLDER
)
pipeline_path = posixpath.join(
distiset_config_folder, PIPELINE_CONFIG_FILENAME
)
if Path(pipeline_path).exists():
distiset.pipeline_path = Path(pipeline_path)
log_filename_path = posixpath.join(
distiset_config_folder, PIPELINE_LOG_FILENAME
)
if Path(log_filename_path).exists():
distiset.log_filename_path = Path(log_filename_path)
if has_artifacts:
distiset.artifacts_path = Path(
posixpath.join(dest_distiset_path, DISTISET_ARTIFACTS_FOLDER)
)
return distiset
@property
def pipeline_path(self) -> Union[Path, None]:
"""Returns the path to the `pipeline.yaml` file that generated the `Pipeline`."""
return self._pipeline_path
@pipeline_path.setter
def pipeline_path(self, path: PathLike) -> None:
self._pipeline_path = Path(path)
@property
def artifacts_path(self) -> Union[Path, None]:
"""Returns the path to the directory containing the artifacts generated by the steps
of the pipeline."""
return self._artifacts_path
@artifacts_path.setter
def artifacts_path(self, path: PathLike) -> None:
self._artifacts_path = Path(path)
@property
def log_filename_path(self) -> Union[Path, None]:
"""Returns the path to the `pipeline.log` file that generated the `Pipeline`."""
return self._log_filename_path
@log_filename_path.setter
def log_filename_path(self, path: PathLike) -> None:
self._log_filename_path = Path(path)
@property
def citations(self) -> Union[List[str], None]:
"""Bibtex references to be included in the README."""
return self._citations
@citations.setter
def citations(self, citations_: List[str]) -> None:
self._citations = sorted(set(citations_))
def __repr__(self):
# Copy from `datasets.DatasetDict.__repr__`.
repr = "\n".join([f"{k}: {v}" for k, v in self.items()])
repr = re.sub(r"^", " " * 4, repr, count=0, flags=re.M)
return f"Distiset({{\n{repr}\n}})"
def transform_columns_to_image(self, columns: Union[str, list[str]]) -> Self:
"""Transforms the columns of the dataset to `PIL.Image` objects.
Args:
columns: Column or list of columns to transform.
Returns:
            The `Distiset` with the given columns cast to `PIL.Image` objects, so the
            Hub treats them as images, renders them in the dataset viewer, and decodes
            them automatically when the dataset is downloaded back.
"""
from datasets import Image
from distilabel.models.image_generation.utils import image_from_str
columns = [columns] if isinstance(columns, str) else columns
def cast_to_image(row: dict) -> dict:
for column in columns:
row[column] = image_from_str(row[column])
return row
for name, dataset in self.items():
# In case train_test_split was called
if isinstance(dataset, DatasetDict):
for split, dataset_split in dataset.items():
dataset_split = dataset_split.map(cast_to_image)
for column in columns:
if column in dataset_split.column_names:
dataset_split = dataset_split.cast_column(
column, Image(decode=True)
)
self[name][split] = dataset_split
else:
dataset = dataset.map(cast_to_image)
for column in columns:
if column in dataset.column_names:
dataset = dataset.cast_column(column, Image(decode=True))
self[name] = dataset
return self
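# Usage sketch (the column and repo names are assumptions): cast string-encoded
# image columns before pushing so the Hub renders them in the dataset viewer:
#
#   distiset = distiset.transform_columns_to_image("image")
#   distiset.push_to_hub("my-org/my-distiset")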
def create_distiset( # noqa: C901
data_dir: Path,
pipeline_path: Optional[Path] = None,
log_filename_path: Optional[Path] = None,
enable_metadata: bool = False,
dag: Optional["DAG"] = None,
) -> Distiset:
"""Creates a `Distiset` from the buffer folder.
    This function is intended to be used as a helper to create a `Distiset` from the folder
    where the cached data was written by the `_WriteBuffer`.
Args:
data_dir: Folder where the data buffers were written by the `_WriteBuffer`.
It should correspond to `CacheLocation.data`.
pipeline_path: Optional path to the pipeline.yaml file that generated the dataset.
Internally this will be passed to the `Distiset` object on creation to allow
uploading the `pipeline.yaml` file to the repo upon `Distiset.push_to_hub`.
log_filename_path: Optional path to the pipeline.log file that was generated during the pipeline run.
Internally this will be passed to the `Distiset` object on creation to allow
uploading the `pipeline.log` file to the repo upon `Distiset.push_to_hub`.
enable_metadata: Whether to include the distilabel metadata column in the dataset or not.
Defaults to `False`.
        dag: DAG contained in a `Pipeline`. If provided, it will be used to extract the
            references/citations from it.
Returns:
The dataset created from the buffer folder, where the different leaf steps will
correspond to different configurations of the dataset.
Examples:
```python
from pathlib import Path
distiset = create_distiset(Path.home() / ".cache/distilabel/pipelines/path-to-pipe-hashname")
```
"""
from distilabel.constants import DISTILABEL_METADATA_KEY
logger = logging.getLogger("distilabel.distiset")
steps_outputs_dir = data_dir / STEPS_OUTPUTS_PATH
distiset = Distiset()
for file in steps_outputs_dir.iterdir():
if file.is_file():
continue
files = [str(file) for file in list_files_in_dir(file)]
if files:
try:
ds = load_dataset(
"parquet", name=file.stem, data_files={"train": files}
)
if not enable_metadata and DISTILABEL_METADATA_KEY in ds.column_names:
ds = ds.remove_columns(DISTILABEL_METADATA_KEY)
distiset[file.stem] = ds
except ArrowInvalid:
logger.warning(f"❌ Failed to load the subset from '{file}' directory.")
continue
else:
logger.warning(
f"No output files for step '{file.stem}', can't create a dataset."
" Did the step produce any data?"
)
# If there's only one dataset i.e. one config, then set the config name to `default`
if len(distiset.keys()) == 1:
distiset["default"] = distiset.pop(list(distiset.keys())[0])
# If there's any artifact set the `artifacts_path` so they can be uploaded
steps_artifacts_dir = data_dir / STEPS_ARTIFACTS_PATH
if any(steps_artifacts_dir.rglob("*")):
distiset.artifacts_path = steps_artifacts_dir
# Include `pipeline.yaml` if exists
if pipeline_path:
distiset.pipeline_path = pipeline_path
else:
# If the pipeline path is not provided, try to find it in the parent directory
# and assume that's the wanted file.
pipeline_path = steps_outputs_dir.parent / "pipeline.yaml"
if pipeline_path.exists():
distiset.pipeline_path = pipeline_path
# Include `pipeline.log` if exists
if log_filename_path:
distiset.log_filename_path = log_filename_path
else:
log_filename_path = steps_outputs_dir.parent / "pipeline.log"
if log_filename_path.exists():
distiset.log_filename_path = log_filename_path
if dag:
distiset._citations = _grab_citations(dag)
return distiset
def _grab_citations(dag: "DAG") -> List[str]:
"""Extracts the citations from the steps that form the DAG.
Args:
dag: `DAG` contained in the pipeline that created the `Distiset`.
Returns:
List of citations to add to the `Distiset`.
"""
citations = []
for step_name in dag:
step_info = parse_google_docstring(dag.get_step(step_name)[STEP_ATTR_NAME])
if cites := step_info["citations"]:
citations.extend(cites)
continue
        # If there were no citations but we have references with arxiv URLs, try to extract
        # the bibtex citations from those
if references := step_info["references"]:
bibtex_refs = []
for ref in references.values():
try:
bibtex_refs.append(get_bibtex(ref))
except ValueError:
# No need to inform in this case, it's noise
pass
except AttributeError as e:
print(
f"Couldn't obtain the bibtex format for the ref: '{ref}', error: {e}"
)
except Exception as e:
print(f"Untracked error: {e}")
citations.extend(bibtex_refs)
return citations
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: E402
import warnings
deprecation_message = (
"Importing from 'distilabel.embeddings' is deprecated and will be removed in a version 1.7.0. "
"Import from 'distilabel.models' instead."
)
warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
from distilabel.models.embeddings.base import Embeddings
from distilabel.models.embeddings.sentence_transformers import (
SentenceTransformerEmbeddings,
)
from distilabel.models.embeddings.vllm import vLLMEmbeddings
__all__ = [
"Embeddings",
"SentenceTransformerEmbeddings",
"vLLMEmbeddings",
]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Idea from: https://github.com/vllm-project/vllm/blob/main/vllm/envs.py
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from distilabel import constants
if TYPE_CHECKING:
DISTILABEL_LOG_LEVEL: str = "INFO"
DISTILABEL_PIPELINE_NAME: Optional[str] = None
DISTILABEL_PIPELINE_CACHE_ID: Optional[str] = None
DISTILABEL_CACHE_DIR: Optional[str] = None
ENVIRONMENT_VARIABLES: Dict[str, Callable[[], Any]] = {
# `distilabel` logging level.
"DISTILABEL_LOG_LEVEL": lambda: os.getenv("DISTILABEL_LOG_LEVEL", "INFO").upper(),
# The name of the `distilabel` pipeline currently running.
constants.PIPELINE_NAME_ENV_NAME: lambda: os.getenv(
constants.PIPELINE_NAME_ENV_NAME, None
),
# The cache ID of the `distilabel` pipeline currently running.
constants.PIPELINE_CACHE_ID_ENV_NAME: lambda: os.getenv(
constants.PIPELINE_CACHE_ID_ENV_NAME, None
),
    # The directory where `distilabel` will store its cache.
"DISTILABEL_CACHE_DIR": lambda: os.getenv("DISTILABEL_CACHE_DIR", None),
}
def __getattr__(name: str) -> Any:
# lazy evaluation of environment variables
if name in ENVIRONMENT_VARIABLES:
return ENVIRONMENT_VARIABLES[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__() -> List[str]:
return list(ENVIRONMENT_VARIABLES.keys())
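# Example (a sketch; the `distilabel.envs` module path is an assumption):
# attributes are resolved from the environment at access time via `__getattr__`,
# so values set after import are still picked up:
#
#   os.environ[constants.PIPELINE_NAME_ENV_NAME] = "my-pipeline"
#   from distilabel import envs
#   assert envs.DISTILABEL_PIPELINE_NAME == "my-pipeline"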
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from distilabel.constants import DISTILABEL_DOCS_URL
# The sitemap can be visited for the full list of pages:
# SITEMAP_URL: Final[str] = "https://distilabel.argilla.io/latest/sitemap.xml"
class DistilabelError:
"""A mixin class for common functionality shared by all Distilabel-specific errors.
Attributes:
message: A message describing the error.
        page: An optional path to a page of the documentation, relative to `DISTILABEL_DOCS_URL`,
            with further information about the error.
Examples:
```python
raise DistilabelUserError("This is an error message.")
This is an error message.
raise DistilabelUserError("This is an error message.", page="sections/getting_started/faq/")
This is an error message.
For further information visit 'https://distilabel.argilla.io/latest/sections/getting_started/faq/'
```
"""
def __init__(self, message: str, *, page: Optional[str] = None) -> None:
self.message = message
self.page = page
def __str__(self) -> str:
if self.page is None:
return self.message
else:
return f"{self.message}\n\nFor further information visit '{DISTILABEL_DOCS_URL}{self.page}'"
class DistilabelUserError(DistilabelError, ValueError):
"""ValueError that we can redirect to a given page in the documentation."""
pass
class DistilabelTypeError(DistilabelError, TypeError):
"""TypeError that we can redirect to a given page in the documentation."""
pass
class DistilabelNotImplementedError(DistilabelError, NotImplementedError):
"""NotImplementedError that we can redirect to a given page in the documentation."""
pass
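# For example (a sketch): the mixin appends the documentation URL to the message
# only when a `page` is given.
_err = DistilabelUserError("Invalid step input.", page="sections/getting_started/faq/")
assert str(_err).endswith("sections/getting_started/faq/'")
assert str(DistilabelUserError("Invalid step input.")) == "Invalid step input."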
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
class DistilabelException(Exception):
"""Base exception (can be gracefully handled) for `distilabel` framework."""
pass
class DistilabelGenerationException(DistilabelException):
"""Base exception for `LLM` generation errors."""
pass
class DistilabelOfflineBatchGenerationNotFinishedException(
DistilabelGenerationException
):
"""Exception raised when a batch generation is not finished."""
jobs_ids: Tuple[str, ...]
def __init__(self, jobs_ids: Tuple[str, ...]) -> None:
self.jobs_ids = jobs_ids
super().__init__(f"Batch generation with jobs_ids={jobs_ids} is not finished")