Commit 4d4d8f59 authored by chenzk

v1.0

Pipeline #2741 canceled with stages
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from pydantic import Field, PrivateAttr, SecretStr
try:
import argilla as rg
except ImportError:
pass
from distilabel.errors import DistilabelUserError
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
from argilla import Argilla, Dataset
from distilabel.typing import StepColumns, StepOutput
_ARGILLA_API_URL_ENV_VAR_NAME = "ARGILLA_API_URL"
_ARGILLA_API_KEY_ENV_VAR_NAME = "ARGILLA_API_KEY"
class ArgillaBase(Step, ABC):
"""Abstract step that provides a class to subclass from, that contains the boilerplate code
required to interact with Argilla, as well as some extra validations on top of it. It also defines
the abstract methods that need to be implemented in order to add a new dataset type as a step.
Note:
This class is not intended to be instantiated directly, but rather subclassed.
Attributes:
dataset_name: The name of the dataset in Argilla where the records will be added.
dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to
`None`, which means it will be created in the default workspace.
api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from
the `ARGILLA_API_URL` environment variable.
api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will
be read from the `ARGILLA_API_KEY` environment variable.
Runtime parameters:
- `dataset_name`: The name of the dataset in Argilla where the records will be
added.
- `dataset_workspace`: The workspace where the dataset will be created in Argilla.
Defaults to `None`, which means it will be created in the default workspace.
- `api_url`: The base URL to use for the Argilla API requests.
- `api_key`: The API key to authenticate the requests to the Argilla API.
Input columns:
- dynamic, based on the `inputs` value provided
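Examples:
A minimal sketch of how a subclass is expected to look; the class and column names
are illustrative, not part of the library:
```python
from typing import List

from distilabel.steps import StepInput
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.typing import StepOutput


class CustomToArgilla(ArgillaBase):
    def load(self) -> None:
        super().load()
        # Create or retrieve `self._dataset` here using `self._client`.

    @property
    def inputs(self) -> List[str]:
        return ["instruction"]

    def process(self, inputs: StepInput) -> StepOutput:
        # Push `inputs` to `self._dataset` as records here.
        yield inputs
```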
"""
dataset_name: RuntimeParameter[str] = Field(
default=None, description="The name of the dataset in Argilla."
)
dataset_workspace: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The workspace where the dataset will be created in Argilla. Defaults "
"to `None` which means it will be created in the default workspace.",
)
api_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv(_ARGILLA_API_URL_ENV_VAR_NAME),
description="The base URL to use for the Argilla API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_ARGILLA_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the Argilla API.",
)
_client: Optional["Argilla"] = PrivateAttr(...)
_dataset: Optional["Dataset"] = PrivateAttr(...)
def model_post_init(self, __context: Any) -> None:
"""Checks that the Argilla Python SDK is installed, and then filters the Argilla warnings."""
super().model_post_init(__context)
if importlib.util.find_spec("argilla") is None:
raise ImportError(
"Argilla is not installed. Please install it using `pip install 'distilabel[argilla]'`."
)
def _client_init(self) -> None:
"""Initializes the Argilla API client with the provided `api_url` and `api_key`."""
try:
self._client = rg.Argilla( # type: ignore
api_url=self.api_url,
api_key=self.api_key.get_secret_value(), # type: ignore
headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
if isinstance(self.api_url, str)
and "hf.space" in self.api_url
and "HF_TOKEN" in os.environ
else {},
)
except Exception as e:
raise DistilabelUserError(
f"Failed to initialize the Argilla API: {e}",
page="sections/how_to_guides/advanced/argilla/",
) from e
@property
def _dataset_exists_in_workspace(self) -> bool:
"""Checks if the dataset already exists in Argilla in the provided workspace if any.
Returns:
`True` if the dataset exists, `False` otherwise.
"""
return (
self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self.dataset_workspace,
)
is not None
)
@property
def outputs(self) -> "StepColumns":
"""The outputs of the step is an empty list, since the steps subclassing from this one, will
always be leaf nodes and won't propagate the inputs neither generate any outputs.
"""
return []
def load(self) -> None:
"""Method to perform any initialization logic before the `process` method is
called. For example, to load an LLM, establish a connection to a database, etc.
"""
super().load()
if self.api_url is None or self.api_key is None:
raise DistilabelUserError(
"`Argilla` step requires the `api_url` and `api_key` to be provided. Please,"
" provide those at step instantiation, via environment variables `ARGILLA_API_URL`"
" and `ARGILLA_API_KEY`, or as `Step` runtime parameters via `pipeline.run(parameters={...})`.",
page="sections/how_to_guides/advanced/argilla/",
)
self._client_init()
@property
@abstractmethod
def inputs(self) -> "StepColumns": ...
@abstractmethod
def process(self, *inputs: StepInput) -> "StepOutput": ...
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from typing import TYPE_CHECKING, Any, Dict, List, Union
from pydantic import PrivateAttr
from typing_extensions import override
try:
import argilla as rg
except ImportError:
pass
from distilabel.errors import DistilabelUserError
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.base import StepInput
if TYPE_CHECKING:
from argilla import RatingQuestion, Suggestion, TextField, TextQuestion
from distilabel.typing import StepOutput
class PreferenceToArgilla(ArgillaBase):
"""Creates a preference dataset in Argilla.
Step that creates a dataset in Argilla during the load phase, and then pushes the input
batches into it as records. This dataset is a preference dataset, where there's one field
for the instruction and one extra field for each generation within the same record, and then
a rating question for each of the generation fields. The rating question asks the annotator to
set a rating from 1 to 5 for each of the provided generations.
Note:
This step is meant to be used in conjunction with the `UltraFeedback` step, or any other step
that generates both ratings and rationales for a given instruction and its set of generations.
Alternatively, it can also be used with any other task or step generating only the `instruction`
and `generations`, as the `ratings` and `rationales` are optional.
Attributes:
num_generations: The number of generations to include in the dataset.
dataset_name: The name of the dataset in Argilla.
dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to
`None`, which means it will be created in the default workspace.
api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from
the `ARGILLA_API_URL` environment variable.
api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will
be read from the `ARGILLA_API_KEY` environment variable.
Runtime parameters:
- `api_url`: The base URL to use for the Argilla API requests.
- `api_key`: The API key to authenticate the requests to the Argilla API.
Input columns:
- instruction (`str`): The instruction that was used to generate the completion.
- generations (`List[str]`): The completions that were generated based on the input instruction.
- ratings (`List[str]`, optional): The ratings for the generations. If not provided, the
generated ratings won't be pushed to Argilla.
- rationales (`List[str]`, optional): The rationales for the ratings. If not provided, the
generated rationales won't be pushed to Argilla.
Examples:
Push a preference dataset to an Argilla instance:
```python
from distilabel.steps import PreferenceToArgilla
to_argilla = PreferenceToArgilla(
num_generations=2,
api_url="https://dibt-demo-argilla-space.hf.space/",
api_key="api.key",
dataset_name="argilla_dataset",
dataset_workspace="my_workspace",
)
to_argilla.load()
result = next(
to_argilla.process(
[
{
"instruction": "instruction",
"generations": ["first_generation", "second_generation"],
}
],
)
)
# >>> result
# [{'instruction': 'instruction', 'generations': ['first_generation', 'second_generation']}]
```
It can also include ratings and rationales:
```python
result = next(
to_argilla.process(
[
{
"instruction": "instruction",
"generations": ["first_generation", "second_generation"],
"ratings": ["4", "5"],
"rationales": ["rationale for 4", "rationale for 5"],
}
],
)
)
# >>> result
# [
# {
# 'instruction': 'instruction',
# 'generations': ['first_generation', 'second_generation'],
# 'ratings': ['4', '5'],
# 'rationales': ['rationale for 4', 'rationale for 5']
# }
# ]
```
"""
num_generations: int
_id: str = PrivateAttr(default="id")
_instruction: str = PrivateAttr(...)
_generations: str = PrivateAttr(...)
_ratings: str = PrivateAttr(...)
_rationales: str = PrivateAttr(...)
def load(self) -> None:
"""Sets the `_instruction` and `_generations` attributes based on the `inputs_mapping`, otherwise
uses the default values; and then uses those values to create a `FeedbackDataset` suited for
the text-generation scenario. And then it pushes it to Argilla.
"""
super().load()
# Both `instruction` and `generations` will be used as the fields of the dataset
self._instruction = self.input_mappings.get("instruction", "instruction")
self._generations = self.input_mappings.get("generations", "generations")
# Both `ratings` and `rationales` will be used as suggestions to the default questions of the dataset
self._ratings = self.input_mappings.get("ratings", "ratings")
self._rationales = self.input_mappings.get("rationales", "rationales")
if self._dataset_exists_in_workspace:
_dataset = self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self.dataset_workspace, # type: ignore
)
for field in _dataset.fields:
if not isinstance(field, rg.TextField):
continue
if (
field.name
not in [self._id, self._instruction] # type: ignore
+ [
f"{self._generations}-{idx}"
for idx in range(self.num_generations)
]
and field.required
):
raise DistilabelUserError(
f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
f" already exists, but contains at least a required field that is"
f" neither `{self._id}`, `{self._instruction}`, nor `{self._generations}`"
f" (one per generation starting from 0 up to {self.num_generations - 1}).",
page="components-gallery/steps/preferencetoargilla/",
)
self._dataset = _dataset
else:
_settings = rg.Settings( # type: ignore
fields=[
rg.TextField(name=self._id, title=self._id), # type: ignore
rg.TextField(name=self._instruction, title=self._instruction), # type: ignore
*self._generation_fields(), # type: ignore
],
questions=self._rating_rationale_pairs(), # type: ignore
)
_dataset = rg.Dataset( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
settings=_settings,
client=self._client,
)
self._dataset = _dataset.create()
def _generation_fields(self) -> List["TextField"]:
"""Method to generate the fields for each of the generations.
Returns:
A list containing `TextField`s for each text generation.
"""
return [
rg.TextField( # type: ignore
name=f"{self._generations}-{idx}",
title=f"{self._generations}-{idx}",
required=True if idx == 0 else False,
)
for idx in range(self.num_generations)
]
def _rating_rationale_pairs(
self,
) -> List[Union["RatingQuestion", "TextQuestion"]]:
"""Method to generate the rating and rationale questions for each of the generations.
Returns:
A list of questions containing a `RatingQuestion` and `TextQuestion` pair for
each text generation.
"""
questions = []
for idx in range(self.num_generations):
questions.extend(
[
rg.RatingQuestion( # type: ignore
name=f"{self._generations}-{idx}-rating",
title=f"Rate {self._generations}-{idx} given {self._instruction}.",
description=f"Ignore this question if the corresponding `{self._generations}-{idx}` field is not available."
if idx != 0
else None,
values=[1, 2, 3, 4, 5],
required=True if idx == 0 else False,
),
rg.TextQuestion( # type: ignore
name=f"{self._generations}-{idx}-rationale",
title=f"Specify the rationale for {self._generations}-{idx}'s rating.",
description=f"Ignore this question if the corresponding `{self._generations}-{idx}` field is not available."
if idx != 0
else None,
required=False,
),
]
)
return questions
@property
def inputs(self) -> List[str]:
"""The inputs for the step are the `instruction` and the `generations`. Optionally, one could also
provide the `ratings` and the `rationales` for the generations."""
return ["instruction", "generations"]
@property
def optional_inputs(self) -> List[str]:
"""The optional inputs for the step are the `ratings` and the `rationales` for the generations."""
return ["ratings", "rationales"]
def _add_suggestions_if_any(self, input: Dict[str, Any]) -> List["Suggestion"]:
"""Method to generate the suggestions for the `rg.Record` based on the input.
Returns:
A list of `Suggestion`s for the rating and rationales questions.
"""
# Since the `suggestions` i.e. the answers to the `questions` are optional, default to an empty list
suggestions = []
# If `ratings` is in `input`, then add those as suggestions
if self._ratings in input:
suggestions.extend(
[
rg.Suggestion( # type: ignore
value=rating,
question_name=f"{self._generations}-{idx}-rating",
)
for idx, rating in enumerate(input[self._ratings])
if rating is not None
and isinstance(rating, int)
and rating in [1, 2, 3, 4, 5]
],
)
# If `rationales` is in `input`, then add those as suggestions
if self._rationales in input:
suggestions.extend(
[
rg.Suggestion( # type: ignore
value=rationale,
question_name=f"{self._generations}-{idx}-rationale",
)
for idx, rationale in enumerate(input[self._rationales])
if rationale is not None and isinstance(rationale, str)
],
)
return suggestions
@override
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Creates and pushes the records as `rg.Record`s to the Argilla dataset.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Returns:
A list of Python dictionaries with the outputs of the task.
"""
records = []
for input in inputs:
# Generate the SHA-256 hash of the instruction to use it as the `id` field of the record
instruction_id = hashlib.sha256(
input["instruction"].encode("utf-8") # type: ignore
).hexdigest()
generations = {
f"{self._generations}-{idx}": generation
for idx, generation in enumerate(input["generations"]) # type: ignore
}
records.append( # type: ignore
rg.Record( # type: ignore
fields={
"id": instruction_id,
"instruction": input["instruction"], # type: ignore
**generations,
},
suggestions=self._add_suggestions_if_any(input), # type: ignore
)
)
self._dataset.records.log(records) # type: ignore
yield inputs
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from typing import TYPE_CHECKING, List
from pydantic import PrivateAttr
from typing_extensions import override
try:
import argilla as rg
except ImportError:
pass
from distilabel.errors import DistilabelUserError
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.base import StepInput
if TYPE_CHECKING:
from distilabel.typing import StepOutput
class TextGenerationToArgilla(ArgillaBase):
"""Creates a text generation dataset in Argilla.
`Step` that creates a dataset in Argilla during the load phase, and then pushes the input
batches into it as records. This dataset is a text-generation dataset, where there's one field
for each input, and then a label question to rate the quality of the completion as either bad
(represented with 👎) or good (represented with 👍).
Note:
This step is meant to be used in conjunction with a `TextGeneration` step and no column mapping
is needed, as it will use the default values for the `instruction` and `generation` columns.
Attributes:
dataset_name: The name of the dataset in Argilla.
dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to
`None`, which means it will be created in the default workspace.
api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from
the `ARGILLA_API_URL` environment variable.
api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will
be read from the `ARGILLA_API_KEY` environment variable.
Runtime parameters:
- `api_url`: The base URL to use for the Argilla API requests.
- `api_key`: The API key to authenticate the requests to the Argilla API.
Input columns:
- instruction (`str`): The instruction that was used to generate the completion.
- generation (`str` or `List[str]`): The completions that were generated based on the input instruction.
Examples:
Push a text generation dataset to an Argilla instance:
```python
from distilabel.steps import TextGenerationToArgilla
to_argilla = TextGenerationToArgilla(
api_url="https://dibt-demo-argilla-space.hf.space/",
api_key="api.key",
dataset_name="argilla_dataset",
dataset_workspace="my_workspace",
)
to_argilla.load()
result = next(
to_argilla.process(
[
{
"instruction": "instruction",
"generation": "generation",
}
],
)
)
# >>> result
# [{'instruction': 'instruction', 'generation': 'generation'}]
```
"""
_id: str = PrivateAttr(default="id")
_instruction: str = PrivateAttr(...)
_generation: str = PrivateAttr(...)
def load(self) -> None:
"""Sets the `_instruction` and `_generation` attributes based on the `inputs_mapping`, otherwise
uses the default values; and then uses those values to create a `FeedbackDataset` suited for
the text-generation scenario. And then it pushes it to Argilla.
"""
super().load()
self._instruction = self.input_mappings.get("instruction", "instruction")
self._generation = self.input_mappings.get("generation", "generation")
if self._dataset_exists_in_workspace:
_dataset = self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self.dataset_workspace, # type: ignore
)
for field in _dataset.fields:
if not isinstance(field, rg.TextField): # type: ignore
continue
if (
field.name not in [self._id, self._instruction, self._generation]
and field.required
):
raise DistilabelUserError(
f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
f" already exists, but contains at least a required field that is"
f" neither `{self._id}`, `{self._instruction}`, nor `{self._generation}`,"
" so it cannot be reused for this dataset.",
page="components-gallery/steps/textgenerationtoargilla/",
)
self._dataset = _dataset
else:
_settings = rg.Settings( # type: ignore
fields=[
rg.TextField(name=self._id, title=self._id), # type: ignore
rg.TextField(name=self._instruction, title=self._instruction), # type: ignore
rg.TextField(name=self._generation, title=self._generation), # type: ignore
],
questions=[
rg.LabelQuestion( # type: ignore
name="quality",
title=f"What's the quality of the {self._generation} for the given {self._instruction}?",
labels={"bad": "👎", "good": "👍"}, # type: ignore
)
],
)
_dataset = rg.Dataset( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
settings=_settings,
client=self._client,
)
self._dataset = _dataset.create()
@property
def inputs(self) -> List[str]:
"""The inputs for the step are the `instruction` and the `generation`."""
return ["instruction", "generation"]
@override
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Creates and pushes the records as FeedbackRecords to the Argilla dataset.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Returns:
A list of Python dictionaries with the outputs of the task.
"""
records = []
for input in inputs:
# Generate the SHA-256 hash of the instruction to use it as the `id` field of the record
instruction_id = hashlib.sha256(
input["instruction"].encode("utf-8")
).hexdigest()
generations = input["generation"]
# If the `generation` is not a list, then convert it into a list
if not isinstance(generations, list):
generations = [generations]
# Create a `generations_set` to avoid adding duplicates
generations_set = set()
for generation in generations:
# If the generation is already in the set, then skip it
if generation in generations_set:
continue
# Otherwise, add it to the set
generations_set.add(generation)
records.append(
rg.Record( # type: ignore
fields={
self._id: instruction_id,
self._instruction: input["instruction"],
self._generation: generation,
},
),
)
self._dataset.records.log(records) # type: ignore
yield inputs
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import re
from abc import ABC, abstractmethod
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
overload,
)
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, PrivateAttr
from typing_extensions import Annotated, Self
from distilabel.errors import DistilabelTypeError, DistilabelUserError
from distilabel.mixins.requirements import RequirementsMixin
from distilabel.mixins.runtime_parameters import (
RuntimeParameter,
RuntimeParametersMixin,
)
from distilabel.mixins.signature import SignatureMixin
from distilabel.utils.serialization import _Serializable, write_json
from distilabel.utils.typing_ import is_parameter_annotated_with
if TYPE_CHECKING:
from logging import Logger
from distilabel.pipeline.base import BasePipeline
from distilabel.pipeline.routing_batch_function import RoutingBatchFunction
from distilabel.typing import (
DownstreamConnectable,
DownstreamConnectableSteps,
GeneratorStepOutput,
StepColumns,
StepOutput,
UpstreamConnectableSteps,
)
DEFAULT_INPUT_BATCH_SIZE = 50
_STEP_INPUT_ANNOTATION = "distilabel_step_input"
StepInput = Annotated[List[Dict[str, Any]], _STEP_INPUT_ANNOTATION]
"""StepInput is just an `Annotated` alias of the typing `List[Dict[str, Any]]` with
extra metadata that allows `distilabel` to perform validations over the `process` step
method defined in each `Step`"""
# Pattern to convert PascalCase to snake_case
PATTERN_PASCAL_NAME = re.compile(r"(?<!^)(?=[A-Z])")
def _infer_step_name(
step_cls_name: str, pipeline: Optional["BasePipeline"] = None
) -> str:
"""Infer the name of the step based on the class name and the pipeline.
If a `Pipeline` is given (the general case), it will check if the name already exists
in the steps of the `DAG`, to add a number at the end of the name.
Args:
step_cls_name: The step class name, as obtained by `type(cls).__name__`.
pipeline: The `Pipeline` the step belongs to, can be `None` if the step is created
outside of a `Pipeline`.
Returns:
A name for the step.
Example:
```python
>>> _infer_step_name("StepWithOnePreviousStep", None)
'step_with_one_previous_step'
```
"""
name = re.sub(PATTERN_PASCAL_NAME, "_", step_cls_name).lower() + "_0"
if pipeline:
# Check the name doesn't already exist in the pipeline
step_names = set(pipeline.dag.G)
parts = name.split("_")
base_name = "_".join(parts[:-1])
while name in step_names:
idx = int(name.split("_")[-1])
name = f"{base_name}_{idx+1}"
return name
class StepResources(RuntimeParametersMixin, BaseModel):
"""A class to define the resources assigned to a `_Step`.
Attributes:
replicas: The number of replicas for the step.
cpus: The number of CPUs assigned to each step replica.
gpus: The number of GPUs assigned to each step replica.
memory: The memory in bytes required for each step replica.
resources: A dictionary containing the number of custom resources required for
each step replica.
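Examples:
A sketch of assigning resources when instantiating a step; `KeepColumns` is only used
here as an example carrier, since any `Step` accepts the `resources` argument:
```python
from distilabel.steps import KeepColumns
from distilabel.steps.base import StepResources

# Request two replicas of this step, each one with a single GPU, for when the
# pipeline is run:
keep_columns = KeepColumns(
    columns=["instruction", "generation"],
    resources=StepResources(replicas=2, gpus=1),
)
```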
"""
replicas: RuntimeParameter[PositiveInt] = Field(
default=1, description="The number of replicas for the step."
)
cpus: Optional[RuntimeParameter[PositiveInt]] = Field(
default=None, description="The number of CPUs assigned to each step replica."
)
gpus: Optional[RuntimeParameter[PositiveInt]] = Field(
default=None, description="The number of GPUs assigned to each step replica."
)
memory: Optional[RuntimeParameter[PositiveInt]] = Field(
default=None, description="The memory in bytes required for each step replica."
)
resources: Optional[RuntimeParameter[Dict[str, int]]] = Field(
default=None,
description="A dictionary containing names of custom resources and the"
" number of those resources required for each step replica.",
)
class _Step(
RuntimeParametersMixin,
RequirementsMixin,
SignatureMixin,
BaseModel,
_Serializable,
ABC,
):
"""Base class for the steps that can be included in a `Pipeline`.
A `Step` is a class defining some processing logic. The input and outputs for this
processing logic are lists of dictionaries with the same keys:
```python
[
{"column1": "value1", "column2": "value2", ...},
{"column1": "value1", "column2": "value2", ...},
{"column1": "value1", "column2": "value2", ...},
]
```
The processing logic is defined in the `process` method, which depending on the
number of previous steps, can receive more than one list of dictionaries, each with
the output of the previous steps. In order to make `distilabel` know where the outputs
from the previous steps are, the `process` function from each `Step` must have an argument
or positional argument annotated with `StepInput`.
```python
class StepWithOnePreviousStep(Step):
def process(self, inputs: StepInput) -> StepOutput:
yield [...]
class StepWithSeveralPreviousStep(Step):
# mind the * to indicate that the argument is a list of StepInput
def process(self, *inputs: StepInput) -> StepOutput:
yield [...]
```
In order to perform static validations and to check that the chaining of the steps
in the pipeline is valid, a `Step` must also define the `inputs` and `outputs`
properties:
- `inputs`: a list of strings with the names of the columns that the step needs as
input. It can be an empty list if the step is a generator step.
- `outputs`: a list of strings with the names of the columns that the step will
produce as output.
Optionally, a `Step` can override the `load` method to perform any initialization
logic before the `process` method is called. For example, to load an LLM, establish a
connection to a database, etc.
Finally, the `Step` class inherits from `pydantic.BaseModel`, so attributes can be easily
defined, validated, serialized and included in the `__init__` method of the step.
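Examples:
Putting the pieces above together, a minimal custom `Step` could look as follows; the
class and column names are illustrative:
```python
from typing import List

from distilabel.steps import Step, StepInput
from distilabel.typing import StepOutput


class RenameInstruction(Step):
    @property
    def inputs(self) -> List[str]:
        return ["instruction"]

    @property
    def outputs(self) -> List[str]:
        return ["prompt"]

    def process(self, inputs: StepInput) -> StepOutput:
        # Copy the `instruction` column into the new `prompt` output column.
        for row in inputs:
            row["prompt"] = row["instruction"]
        yield inputs


step = RenameInstruction()
step.load()
result = next(step.process([{"instruction": "Write a haiku about distilabel."}]))
# >>> result
# [{'instruction': 'Write a haiku about distilabel.', 'prompt': 'Write a haiku about distilabel.'}]
```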
"""
model_config = ConfigDict(
arbitrary_types_allowed=True,
validate_default=True,
validate_assignment=True,
extra="forbid",
)
name: Optional[str] = Field(default=None, pattern=r"^[a-zA-Z0-9_-]+$")
resources: StepResources = StepResources()
pipeline: Any = Field(default=None, exclude=True, repr=False)
input_mappings: Dict[str, str] = {}
output_mappings: Dict[str, str] = {}
use_cache: bool = True
_pipeline_artifacts_path: Path = PrivateAttr(None)
_built_from_decorator: bool = PrivateAttr(default=False)
_logger: "Logger" = PrivateAttr(None)
def model_post_init(self, __context: Any) -> None:
from distilabel.pipeline.base import _GlobalPipelineManager
super().model_post_init(__context)
if self.pipeline is None:
self.pipeline = _GlobalPipelineManager.get_pipeline()
if self.pipeline is None:
_logger = logging.getLogger(f"distilabel.step.{self.name}")
_logger.warning(
f"Step '{self.name}' hasn't received a pipeline, and it hasn't been"
" created within a `Pipeline` context. Please, use"
" `with Pipeline() as pipeline:` and create the step within the context."
)
if not self.name:
# This must be done before the check for repeated names, but assuming
# we are passing the pipeline from the _GlobalPipelineManager, should
# be done after that.
self.name = _infer_step_name(type(self).__name__, self.pipeline)
if self.pipeline is not None:
# If not set an error will be raised in `Pipeline.run` parent
self.pipeline._add_step(self)
def connect(
self,
*steps: "_Step",
routing_batch_function: Optional["RoutingBatchFunction"] = None,
) -> None:
"""Connects the current step to another step in the pipeline, which means that
the output of this step will be the input of the other step.
Args:
steps: The steps to connect to the current step.
routing_batch_function: A function that receives a list of steps and returns
a list of steps to which the output batch generated by this step should be
routed. It should be used to define the routing logic of the pipeline. If
not provided, the output batch will be routed to all the connected steps.
Defaults to `None`.
"""
assert self.pipeline is not None
if routing_batch_function:
self._set_routing_batch_function(routing_batch_function)
for step in steps:
self.pipeline._add_edge(from_step=self.name, to_step=step.name) # type: ignore
def _set_routing_batch_function(
self, routing_batch_function: "RoutingBatchFunction"
) -> None:
"""Sets a routing batch function for the batches generated by this step, so they
get routed to specific downstream steps.
Args:
routing_batch_function: The routing batch function that will be used to route
the batches generated by this step.
"""
self.pipeline._add_routing_batch_function(
step_name=self.name, # type: ignore
routing_batch_function=routing_batch_function,
)
routing_batch_function._step = self
@overload
def __rshift__(self, other: "RoutingBatchFunction") -> "RoutingBatchFunction": ...
@overload
def __rshift__(
self, other: List["DownstreamConnectableSteps"]
) -> List["DownstreamConnectableSteps"]: ...
@overload
def __rshift__(self, other: "DownstreamConnectable") -> "DownstreamConnectable": ...
def __rshift__(
self,
other: Union[
"DownstreamConnectable",
"RoutingBatchFunction",
List["DownstreamConnectableSteps"],
],
) -> Union[
"DownstreamConnectable",
"RoutingBatchFunction",
List["DownstreamConnectableSteps"],
]:
"""Allows using the `>>` operator to connect steps in the pipeline.
Args:
other: The step to connect, a list of steps to connect to or a routing batch
function to be set for the step.
Returns:
The connected step, the list of connected steps or the routing batch function.
Example:
```python
step1 >> step2
# Would be equivalent to:
step1.connect(step2)
# It also allows to connect a list of steps
step1 >> [step2, step3]
```
"""
# Here to avoid circular imports
from distilabel.pipeline.routing_batch_function import RoutingBatchFunction
if isinstance(other, list):
self.connect(*other)
return other
if isinstance(other, RoutingBatchFunction):
self._set_routing_batch_function(other)
return other
self.connect(other)
return other
def __rrshift__(self, other: List["UpstreamConnectableSteps"]) -> Self:
"""Allows using the [step1, step2] >> step3 operator to connect a list of steps in the pipeline
to a single step, as the list doesn't have the __rshift__ operator.
Args:
other: The step to connect to.
Returns:
The connected step
Example:
```python
[step2, step3] >> step1
# Would be equivalent to:
step2.connect(step1)
step3.connect(step1)
```
"""
for o in other:
o.connect(self)
return self
def load(self) -> None:
"""Method to perform any initialization logic before the `process` method is
called. For example, to load an LLM, establish a connection to a database, etc.
"""
self._logger = logging.getLogger(f"distilabel.step.{self.name}")
def unload(self) -> None:
"""Method to perform any cleanup logic after the `process` method is called. For
example, to close a connection to a database, etc.
"""
self._logger.debug("Executing step unload logic.")
@property
def is_generator(self) -> bool:
"""Whether the step is a generator step or not.
Returns:
`True` if the step is a generator step, `False` otherwise.
"""
return isinstance(self, GeneratorStep)
@property
def is_global(self) -> bool:
"""Whether the step is a global step or not.
Returns:
`True` if the step is a global step, `False` otherwise.
"""
return isinstance(self, GlobalStep)
@property
def is_normal(self) -> bool:
"""Whether the step is a normal step or not.
Returns:
`True` if the step is a normal step, `False` otherwise.
"""
return not self.is_generator and not self.is_global
@property
def inputs(self) -> "StepColumns":
"""List of strings with the names of the mandatory columns that the step needs as
input or dictionary in which the keys are the input columns of the step and the
values are booleans indicating whether the column is required or not.
Returns:
List of strings with the names of the columns that the step needs as input.
"""
return []
@property
def outputs(self) -> "StepColumns":
"""List of strings with the names of the columns that the step will produce as
output or dictionary in which the keys are the output columns of the step and the
values are booleans indicating whether the column is required or not.
Returns:
List of strings with the names of the columns that the step will produce as
output.
"""
return []
@cached_property
def process_parameters(self) -> List[inspect.Parameter]:
"""Returns the parameters of the `process` method of the step.
Returns:
The parameters of the `process` method of the step.
"""
return list(inspect.signature(self.process).parameters.values()) # type: ignore
def has_multiple_inputs(self) -> bool:
"""Whether the `process` method of the step receives more than one input or not
i.e. has a `*` argument annotated with `StepInput`.
Returns:
`True` if the `process` method of the step receives more than one input,
`False` otherwise.
"""
return any(
param.kind == param.VAR_POSITIONAL for param in self.process_parameters
)
def get_process_step_input(self) -> Union[inspect.Parameter, None]:
"""Returns the parameter of the `process` method of the step annotated with
`StepInput`.
Returns:
The parameter of the `process` method of the step annotated with `StepInput`,
or `None` if there is no parameter annotated with `StepInput`.
Raises:
TypeError: If the step has more than one parameter annotated with `StepInput`.
"""
step_input_parameter = None
for parameter in self.process_parameters:
if is_parameter_annotated_with(parameter, _STEP_INPUT_ANNOTATION):
if step_input_parameter is not None:
raise DistilabelTypeError(
f"Step '{self.name}' should have only one parameter with type"
" hint `StepInput`.",
page="sections/how_to_guides/basic/step/#defining-custom-steps",
)
step_input_parameter = parameter
return step_input_parameter
def verify_inputs_mappings(self) -> None:
"""Verifies that the `inputs_mappings` of the step are valid i.e. the input
columns exist in the inputs of the step.
Raises:
ValueError: If the `inputs_mappings` of the step are not valid.
"""
if not self.input_mappings:
return
for input in self.input_mappings:
if input not in self.inputs:
raise DistilabelUserError(
f"The input column '{input}' doesn't exist in the inputs of the"
f" step '{self.name}'. Inputs of the step are: {self.inputs}."
" Please, review the `inputs_mappings` argument of the step.",
page="sections/how_to_guides/basic/step/#arguments",
)
def verify_outputs_mappings(self) -> None:
"""Verifies that the `outputs_mappings` of the step are valid i.e. the output
columns exist in the outputs of the step.
Raises:
ValueError: If the `outputs_mappings` of the step are not valid.
"""
if not self.output_mappings:
return
for output in self.output_mappings:
if output not in self.outputs:
raise DistilabelUserError(
f"The output column '{output}' doesn't exist in the outputs of the"
f" step '{self.name}'. Outputs of the step are: {self.outputs}."
" Please, review the `outputs_mappings` argument of the step.",
page="sections/how_to_guides/basic/step/#arguments",
)
def get_inputs(self) -> Dict[str, bool]:
"""Gets the inputs of the step after the `input_mappings`. This method is meant
to be used to run validations on the inputs of the step.
Returns:
The inputs of the step after the `input_mappings` and if they are required or
not.
"""
if isinstance(self.inputs, list):
return {
self.input_mappings.get(input, input): True for input in self.inputs
}
return {
self.input_mappings.get(input, input): required
for input, required in self.inputs.items()
}
def get_outputs(self) -> Dict[str, bool]:
"""Gets the outputs of the step after the `outputs_mappings`. This method is
meant to be used to run validations on the outputs of the step.
Returns:
The outputs of the step after the `outputs_mappings` and if they are required
or not.
"""
if isinstance(self.outputs, list):
return {
self.output_mappings.get(output, output): True
for output in self.outputs
}
return {
self.output_mappings.get(output, output): required
for output, required in self.outputs.items()
}
def set_pipeline_artifacts_path(self, path: Path) -> None:
"""Sets the `_pipeline_artifacts_path` attribute. This method is meant to be used
by the `Pipeline` once the cache location is known.
Args:
path: the path where the artifacts generated by the pipeline steps should be
saved.
"""
self._pipeline_artifacts_path = path
@property
def artifacts_directory(self) -> Union[Path, None]:
"""Gets the path of the directory where the step should save its generated artifacts.
Returns:
The path of the directory where the step should save the generated artifacts,
or `None` if `_pipeline_artifacts_path` is not set.
"""
if self._pipeline_artifacts_path is None:
return None
return self._pipeline_artifacts_path / self.name # type: ignore
def save_artifact(
self,
name: str,
write_function: Callable[[Path], None],
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Saves an artifact generated by the `Step`.
Args:
name: the name of the artifact.
write_function: a function that will receive the path where the artifact should
be saved.
metadata: the artifact metadata. Defaults to `None`.
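Examples:
A sketch of how a custom step could call this method from within its `process`
method; the artifact name and contents are illustrative:
```python
import json

# `write_function` receives the directory where the artifact must be written.
self.save_artifact(
    name="statistics",
    write_function=lambda path: (path / "stats.json").write_text(
        json.dumps({"num_rows": 1234})
    ),
    metadata={"format": "json"},
)
```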
"""
if self.artifacts_directory is None:
self._logger.warning(
f"Cannot save artifact with '{name}' as `_pipeline_artifacts_path` is not"
" set. This is normal if the `Step` is being executed as a standalone component."
)
return
artifact_directory_path = self.artifacts_directory / name
artifact_directory_path.mkdir(parents=True, exist_ok=True)
self._logger.info(f"🏺 Storing '{name}' generated artifact...")
self._logger.debug(
f"Calling `write_function` to write artifact in '{artifact_directory_path}'..."
)
write_function(artifact_directory_path)
metadata_path = artifact_directory_path / "metadata.json"
self._logger.debug(
f"Calling `write_json` to write artifact metadata in '{metadata_path}'..."
)
write_json(filename=metadata_path, data=metadata or {})
def impute_step_outputs(
self, step_output: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Imputes the output columns of the step that are not present in the step output.
"""
result = []
for row in step_output:
data = row.copy()
for output in self.get_outputs().keys():
data[output] = None
result.append(data)
return result
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
dump = super()._model_dump(obj, **kwargs)
dump["runtime_parameters_info"] = self.get_runtime_parameters_info()
return dump
class Step(_Step, ABC):
"""Base class for the steps that can be included in a `Pipeline`.
Attributes:
input_batch_size: The number of rows that will contain the batches processed by
the step. Defaults to `50`.
Runtime parameters:
- `input_batch_size`: The number of rows that will contain the batches processed
by the step. Defaults to `50`.
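Examples:
A sketch showing how `input_batch_size` can be set at instantiation and overridden as
a runtime parameter; the pipeline and data are illustrative:
```python
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, LoadDataFromDicts

with Pipeline(name="input-batch-size-demo") as pipeline:
    loader = LoadDataFromDicts(data=[{"instruction": "hello", "extra": 1}] * 100)
    keep_columns = KeepColumns(columns=["instruction"], input_batch_size=10)
    loader >> keep_columns

# The same value can also be overridden when running the pipeline:
distiset = pipeline.run(parameters={keep_columns.name: {"input_batch_size": 25}})
```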
"""
input_batch_size: RuntimeParameter[PositiveInt] = Field(
default=DEFAULT_INPUT_BATCH_SIZE,
description="The number of rows that will contain the batches processed by the"
" step.",
)
@abstractmethod
def process(self, *inputs: StepInput) -> "StepOutput":
"""Method that defines the processing logic of the step. It should yield the
output rows.
Args:
*inputs: An argument used to receive the outputs of the previous steps. The
number of arguments depends on the number of previous steps. It doesn't
need to be an `*args` argument, it can be a regular argument annotated
with `StepInput` if the step has only one previous step.
"""
pass
def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput":
"""Runs the `process` method of the step applying the `input_mappings` to the input
rows and the `outputs_mappings` to the output rows. This is the function that
should be used to run the processing logic of the step.
Yields:
The output rows.
"""
inputs, overriden_inputs = (
self._apply_input_mappings(args)
if self.input_mappings
else (args, [{} for _ in range(len(args[0]))])
)
# If the `Step` was built using the `@step` decorator, then we need to pass
# the runtime parameters as kwargs, so they can be used within the processing
# function
generator = (
self.process(*inputs)
if not self._built_from_decorator
else self.process(*inputs, **self._runtime_parameters)
)
for output_rows in generator:
restored = []
for i, row in enumerate(output_rows):
# Correct the index here because we don't know the num_generations from the llm
# ahead of time. For example, if we have `len(overriden_inputs)==5` and `len(row)==10`,
# from `num_generations==2` and `group_generations=False` in the LLM:
# The loop will use indices 0, 1, 2, 3, 4, 0, 1, 2, 3, 4
ntimes_i = i % len(overriden_inputs)
restored.append(
self._apply_mappings_and_restore_overriden(
row, overriden_inputs[ntimes_i]
)
)
yield restored
def _apply_input_mappings(
self, inputs: Tuple[List[Dict[str, Any]], ...]
) -> Tuple[Tuple[List[Dict[str, Any]], ...], List[Dict[str, Any]]]:
"""Applies the `input_mappings` to the input rows.
Args:
inputs: The input rows.
Returns:
The input rows with the `input_mappings` applied and the overriden values
that were replaced by the `input_mappings`.
"""
reverted_input_mappings = {v: k for k, v in self.input_mappings.items()}
renamed_inputs = []
overriden_inputs = []
for i, row_inputs in enumerate(inputs):
renamed_row_inputs = []
for row in row_inputs:
overriden_keys = {}
renamed_row = {}
for k, v in row.items():
renamed_key = reverted_input_mappings.get(k, k)
if renamed_key not in renamed_row or k != renamed_key:
renamed_row[renamed_key] = v
if k != renamed_key and renamed_key in row and len(inputs) == 1:
overriden_keys[renamed_key] = row[renamed_key]
if i == 0:
overriden_inputs.append(overriden_keys)
renamed_row_inputs.append(renamed_row)
renamed_inputs.append(renamed_row_inputs)
return tuple(renamed_inputs), overriden_inputs
def _apply_mappings_and_restore_overriden(
self, row: Dict[str, Any], overriden: Dict[str, Any]
) -> Dict[str, Any]:
"""Reverts the `input_mappings` applied to the input rows and applies the `output_mappings`
to the output rows. In addition, it restores the overriden values that were replaced
by the `input_mappings`.
Args:
row: The output row.
overriden: The overriden values that were replaced by the `input_mappings`.
Returns:
The output row with the `output_mappings` applied and the overriden values
restored.
"""
result = {}
for k, v in row.items():
mapped_key = (
self.output_mappings.get(k, None)
or self.input_mappings.get(k, None)
or k
)
result[mapped_key] = v
# Restore overriden values
for k, v in overriden.items():
if k not in result:
result[k] = v
return result
class GeneratorStep(_Step, ABC):
"""A special kind of `Step` that is able to generate data i.e. it doesn't receive
any input from the previous steps.
Attributes:
batch_size: The number of rows that will contain the batches generated by the
step. Defaults to `50`.
Runtime parameters:
- `batch_size`: The number of rows that will contain the batches generated by
the step. Defaults to `50`.
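Examples:
A minimal sketch of a generator step; the class and column names are illustrative:
```python
from typing import List

from distilabel.steps import GeneratorStep
from distilabel.typing import GeneratorStepOutput


class InstructionLoader(GeneratorStep):
    instructions: List[str]

    @property
    def outputs(self) -> List[str]:
        return ["instruction"]

    def process(self, offset: int = 0) -> GeneratorStepOutput:
        rows = [{"instruction": instruction} for instruction in self.instructions]
        # Yield all the remaining rows and flag this as the last batch.
        yield rows[offset:], True


loader = InstructionLoader(instructions=["Write a haiku.", "Write a limerick."])
loader.load()
result = next(loader.process())
# >>> result
# ([{'instruction': 'Write a haiku.'}, {'instruction': 'Write a limerick.'}], True)
```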
"""
batch_size: RuntimeParameter[int] = Field(
default=50,
description="The number of rows that will contain the batches generated by the"
" step.",
)
@abstractmethod
def process(self, offset: int = 0) -> "GeneratorStepOutput":
"""Method that defines the generation logic of the step. It should yield the
output rows and a boolean indicating if it's the last batch or not.
Args:
offset: The offset to start the generation from. Defaults to 0.
Yields:
The output rows and a boolean indicating if it's the last batch or not.
"""
pass
def process_applying_mappings(self, offset: int = 0) -> "GeneratorStepOutput":
"""Runs the `process` method of the step applying the `outputs_mappings` to the
output rows. This is the function that should be used to run the generation logic
of the step.
Args:
offset: The offset to start the generation from. Defaults to 0.
Yields:
The output rows and a boolean indicating if it's the last batch or not.
"""
# If the `Step` was built using the `@step` decorator, then we need to pass
# the runtime parameters as `kwargs`, so they can be used within the processing
# function
generator = (
self.process(offset=offset)
if not self._built_from_decorator
else self.process(offset=offset, **self._runtime_parameters)
)
for output_rows, last_batch in generator:
yield (
[
{self.output_mappings.get(k, k): v for k, v in row.items()}
for row in output_rows
],
last_batch,
)
class GlobalStep(Step, ABC):
"""A special kind of `Step` which it's `process` method receives all the data processed
by their previous steps at once, instead of receiving it in batches. This kind of steps
are useful when the processing logic requires to have all the data at once, for example
to train a model, to perform a global aggregation, etc.
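Examples:
A minimal sketch of a global aggregation step; the class and column names are
illustrative:
```python
from typing import List

from distilabel.steps import GlobalStep, StepInput
from distilabel.typing import StepOutput


class AddTotalRows(GlobalStep):
    @property
    def inputs(self) -> List[str]:
        return ["instruction"]

    @property
    def outputs(self) -> List[str]:
        return ["total_rows"]

    def process(self, inputs: StepInput) -> StepOutput:
        # All the rows produced by the previous steps arrive here in a single batch.
        for row in inputs:
            row["total_rows"] = len(inputs)
        yield inputs
```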
"""
@property
def inputs(self) -> "StepColumns":
return []
@property
def outputs(self) -> "StepColumns":
return []
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
from typing import TYPE_CHECKING, Any, List, Optional
import numpy as np
from pydantic import Field, PrivateAttr
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps import (
GlobalStep,
StepInput,
)
if TYPE_CHECKING:
from sklearn.cluster import DBSCAN as _DBSCAN
from distilabel.typing import StepOutput
class DBSCAN(GlobalStep):
r"""DBSCAN (Density-Based Spatial Clustering of Applications with Noise) finds core
samples in regions of high density and expands clusters from them. This algorithm
is good for data which contains clusters of similar density.
This is a `GlobalStep` that clusters the embeddings using the DBSCAN algorithm
from `sklearn`. Visit `TextClustering` step for an example of use.
The trained model is saved as an artifact when creating a distiset
and pushing it to the Hugging Face Hub.
Input columns:
- projection (`List[float]`): Vector representation of the text to cluster,
normally the output from the `UMAP` step.
Output columns:
- cluster_label (`int`): Integer representing the label of a given cluster. -1
means it wasn't clustered.
Categories:
- clustering
- text-classification
References:
- [`DBSCAN demo of sklearn`](https://scikit-learn.org/stable/auto_examples/cluster/plot_dbscan.html#demo-of-dbscan-clustering-algorithm)
- [`sklearn dbscan`](https://scikit-learn.org/stable/modules/clustering.html#dbscan)
Attributes:
- eps: The maximum distance between two samples for one to be considered as in the
neighborhood of the other. This is not a maximum bound on the distances of
points within a cluster. This is the most important DBSCAN parameter to
choose appropriately for your data set and distance function.
- min_samples: The number of samples (or total weight) in a neighborhood for a point
to be considered as a core point. This includes the point itself. If `min_samples`
is set to a higher value, DBSCAN will find denser clusters, whereas if it is set
to a lower value, the found clusters will be more sparse.
- metric: The metric to use when calculating distance between instances in a feature
array. If metric is a string or callable, it must be one of the options allowed
by `sklearn.metrics.pairwise_distances` for its metric parameter.
- n_jobs: The number of parallel jobs to run.
Runtime parameters:
- `eps`: The maximum distance between two samples for one to be considered as in the
neighborhood of the other. This is not a maximum bound on the distances of
points within a cluster. This is the most important DBSCAN parameter to
choose appropriately for your data set and distance function.
- `min_samples`: The number of samples (or total weight) in a neighborhood for a point
to be considered as a core point. This includes the point itself. If `min_samples`
is set to a higher value, DBSCAN will find denser clusters, whereas if it is set
to a lower value, the found clusters will be more sparse.
- `metric`: The metric to use when calculating distance between instances in a feature
array. If metric is a string or callable, it must be one of the options allowed
by `sklearn.metrics.pairwise_distances` for its metric parameter.
- `n_jobs`: The number of parallel jobs to run.
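Examples:
A minimal sketch of running the step standalone (assumes `scikit-learn` is installed;
the projections are illustrative):
```python
from distilabel.steps import DBSCAN

dbscan = DBSCAN(eps=0.5, min_samples=2, metric="euclidean")
dbscan.load()
result = next(
    dbscan.process(
        [
            {"projection": [0.1, -0.4]},
            {"projection": [0.12, -0.38]},
            {"projection": [5.0, 5.0]},
        ]
    )
)
# Each row now contains a `cluster_label`; -1 means the point was not clustered.
```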
"""
eps: Optional[RuntimeParameter[float]] = Field(
default=0.3,
description=(
"The maximum distance between two samples for one to be considered "
"as in the neighborhood of the other. This is not a maximum bound "
"on the distances of points within a cluster. This is the most "
"important DBSCAN parameter to choose appropriately for your data set "
"and distance function."
),
)
min_samples: Optional[RuntimeParameter[int]] = Field(
default=30,
description=(
"The number of samples (or total weight) in a neighborhood for a point to "
"be considered as a core point. This includes the point itself. If "
"`min_samples` is set to a higher value, DBSCAN will find denser clusters, "
"whereas if it is set to a lower value, the found clusters will be more "
"sparse."
),
)
metric: Optional[RuntimeParameter[str]] = Field(
default="euclidean",
description=(
"The metric to use when calculating distance between instances in a "
"feature array. If metric is a string or callable, it must be one of "
"the options allowed by `sklearn.metrics.pairwise_distances` for "
"its metric parameter."
),
)
n_jobs: Optional[RuntimeParameter[int]] = Field(
default=8, description="The number of parallel jobs to run."
)
_clusterer: Optional["_DBSCAN"] = PrivateAttr(None)
def load(self) -> None:
super().load()
if importlib.util.find_spec("sklearn") is None:
raise ImportError(
"`sklearn` package is not installed. Please install it using `pip install 'distilabel[text-clustering]'`."
)
from sklearn.cluster import DBSCAN as _DBSCAN
self._clusterer = _DBSCAN(
eps=self.eps,
min_samples=self.min_samples,
metric=self.metric,
n_jobs=self.n_jobs,
)
def unload(self) -> None:
self._clusterer = None
@property
def inputs(self) -> List[str]:
return ["projection"]
@property
def outputs(self) -> List[str]:
return ["cluster_label"]
def _save_model(self, model: Any) -> None:
import joblib
def save_model(path):
with open(str(path / "DBSCAN.joblib"), "wb") as f:
joblib.dump(model, f)
self.save_artifact(
name="DBSCAN_model",
write_function=lambda path: save_model(path),
metadata={
"eps": self.eps,
"min_samples": self.min_samples,
"metric": self.metric,
},
)
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
projections = np.array([input["projection"] for input in inputs])
self._logger.info("🏋️‍♀️ Start training DBSCAN...")
fitted_clusterer = self._clusterer.fit(projections)
cluster_labels = fitted_clusterer.labels_
# Sets the cluster labels for each input, -1 means it wasn't clustered
for input, cluster_label in zip(inputs, cluster_labels):
input["cluster_label"] = cluster_label
self._logger.info(f"DBSCAN labels assigned: {len(set(cluster_labels))}")
self._save_model(fitted_clusterer)
yield inputs
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import json
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from pydantic import Field
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps import StepInput
from distilabel.steps.tasks import TextClassification
from distilabel.steps.tasks.base import GlobalTask
from distilabel.utils.itertools import batched
if TYPE_CHECKING:
from distilabel.typing import StepOutput
class TextClustering(TextClassification, GlobalTask):
"""Task that clusters a set of texts and generates summary labels for each cluster.
This is a `GlobalTask` that inherits from `TextClassification`, which means that all
the attributes from that class are available here. Also, in this case we deal
with all the inputs at once, instead of using batches. The `input_batch_size` is
used here to send the examples to the LLM in batches (a subtle difference with the
more common `Task` definitions).
The task looks in each cluster for a given number of representative examples (the number
is set by the `samples_per_cluster` attribute), and sends them to the LLM to get a label/s
that represent the cluster. The labels are then assigned to each text in the cluster.
The clusters and projections used in the step are assumed to be obtained from the `UMAP`
+ `DBSCAN` steps, but could be generated by similar steps, as long as they represent the
same concepts.
This step runs a pipeline like the one in this repository:
https://github.com/huggingface/text-clustering
Input columns:
- text (`str`): The reference text we want to obtain labels for.
- projection (`List[float]`): Vector representation of the text to cluster,
normally the output from the `UMAP` step.
- cluster_label (`int`): Integer representing the label of a given cluster. -1
means it wasn't clustered.
Output columns:
- summary_label (`str`): The label or list of labels for the text.
- model_name (`str`): The name of the model used to generate the label/s.
Categories:
- clustering
- text-classification
References:
- [`text-clustering repository`](https://github.com/huggingface/text-clustering)
Attributes:
- savefig: Whether to generate and save a figure with the clustering of the texts.
- samples_per_cluster: The number of examples to use in the LLM as a sample of the cluster.
Examples:
Generate labels for a set of texts using clustering:
```python
        from datasets import load_dataset
        from distilabel.models import InferenceEndpointsLLM
        from distilabel.steps import UMAP, DBSCAN, TextClustering, make_generator_step
        from distilabel.pipeline import Pipeline
ds_name = "argilla-warehouse/personahub-fineweb-edu-4-clustering-100k"
with Pipeline(name="Text clustering dataset") as pipeline:
batch_size = 500
ds = load_dataset(ds_name, split="train").select(range(10000))
loader = make_generator_step(ds, batch_size=batch_size, repo_id=ds_name)
umap = UMAP(n_components=2, metric="cosine")
dbscan = DBSCAN(eps=0.3, min_samples=30)
text_clustering = TextClustering(
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
n=3, # 3 labels per example
query_title="Examples of Personas",
samples_per_cluster=10,
context=(
"Describe the main themes, topics, or categories that could describe the "
"following types of personas. All the examples of personas must share "
"the same set of labels."
),
default_label="None",
savefig=True,
input_batch_size=8,
input_mappings={"text": "persona"},
use_default_structured_output=True,
)
loader >> umap >> dbscan >> text_clustering
```
"""
savefig: Optional[RuntimeParameter[bool]] = Field(
default=True,
description="Whether to generate and save a figure with the clustering of the texts.",
)
samples_per_cluster: int = Field(
default=10,
description="The number of examples to use in the LLM as a sample of the cluster.",
)
@property
def inputs(self) -> List[str]:
"""The input for the task are the same as those for `TextClassification` plus
the `projection` and `cluster_label` columns (which can be obtained from
UMAP + DBSCAN steps).
"""
return super().inputs + ["projection", "cluster_label"]
@property
def outputs(self) -> List[str]:
"""The output for the task is the `summary_label` and the `model_name`."""
return ["summary_label", "model_name"]
def load(self) -> None:
super().load()
if self.savefig and (importlib.util.find_spec("matplotlib") is None):
raise ImportError(
"`matplotlib` package is not installed. Please install it using `pip install matplotlib`."
)
def _save_figure(
self,
data: pd.DataFrame,
        cluster_centers: Dict[int, Tuple[float, float]],
cluster_summaries: Dict[int, str],
) -> None:
"""Saves the figure starting from the dataframe, using matplotlib.
Args:
data: pd.DataFrame with the columns 'X', 'Y' and 'labels' representing
the projections and the label of each text respectively.
            cluster_centers: Dictionary mapping each cluster label to the center of its
                cluster, to help with the placement of the annotations.
cluster_summaries: The summaries of the clusters, obtained from the LLM.
"""
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(12, 8), dpi=300)
unique_labels = data["labels"].unique()
        # Map of colors for each label (-1, the unclustered points, is made fully transparent)
colormap = dict(
zip(unique_labels, plt.cm.Spectral(np.linspace(0, 1, len(unique_labels))))
)
colormap[-1] = np.array([0, 0, 0, 0])
data["color"] = data["labels"].map(colormap)
data.plot(
kind="scatter",
x="X",
y="Y",
c="color",
s=0.75,
alpha=0.8,
linewidth=0.4,
ax=ax,
colorbar=False,
)
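        # Annotate each cluster (except the noise label -1) with its LLM-generated summary at the cluster center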
for label in cluster_summaries.keys():
if label == -1:
continue
summary = str(cluster_summaries[label]) # These are obtained from the LLM
position = cluster_centers[label]
t = ax.text(
position[0],
position[1],
summary,
horizontalalignment="center",
verticalalignment="center",
fontsize=4,
)
t.set_bbox(
{
"facecolor": "white",
"alpha": 0.9,
"linewidth": 0,
"boxstyle": "square,pad=0.1",
}
)
ax.set_axis_off()
# Save the plot as an artifact of the step
self.save_artifact(
name="Text clusters",
write_function=lambda path: fig.savefig(path / "figure_clustering.png"),
metadata={"type": "image", "library": "matplotlib"},
)
plt.close()
def _create_figure(
self,
inputs: StepInput,
label2docs: Dict[int, List[str]],
cluster_summaries: Dict[int, str],
) -> None:
"""Creates a figure of the clustered texts and save it as an artifact.
Args:
inputs: The inputs of the step, as we will extract information from them again.
label2docs: Map from each label to the list of documents (texts) that belong to that cluster.
cluster_summaries: The summaries of the clusters, obtained from the LLM.
"""
self._logger.info("🖼️ Creating figure for the clusters...")
labels = []
projections = []
        for input in inputs:
            label = input["cluster_label"]
            labels.append(label)
            projections.append(input["projection"])
projections = np.array(projections)
# Contains the placement of the cluster centers in the figure
        cluster_centers: Dict[int, Tuple[float, float]] = {}
for label in label2docs.keys():
x = np.mean([projections[doc, 0] for doc in label2docs[label]])
y = np.mean([projections[doc, 1] for doc in label2docs[label]])
cluster_centers[label] = (x, y)
df = pd.DataFrame(
data={
"X": projections[:, 0],
"Y": projections[:, 1],
"labels": labels,
}
)
self._save_figure(
df, cluster_centers=cluster_centers, cluster_summaries=cluster_summaries
)
def _prepare_input_texts(
self,
inputs: StepInput,
label2docs: Dict[int, List[int]],
        unique_labels: int,
) -> List[Dict[str, Union[str, int]]]:
"""Prepares a batch of inputs to send to the LLM, with the examples of each cluster.
Args:
inputs: Inputs from the step.
label2docs: Map from each label to the list of documents (texts) that
belong to that cluster.
            unique_labels: The number of unique cluster labels, excluding the noise label (-1).
Returns:
The input texts to send to the LLM, with the examples of each cluster
prepared to be used in the prompt, and an additional key to store the
labels (that will be needed to find the data after the batches are
returned from the LLM).
"""
input_texts = []
for label in range(unique_labels): # The label -1 is implicitly excluded
            # Get the ids, removing possible duplicates; duplicates become more likely as the
            # number of requested examples grows and the cluster gets smaller
ids = set(
np.random.choice(label2docs[label], size=self.samples_per_cluster)
) # Grab the number of examples
examples = [inputs[i]["text"] for i in ids]
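            # Each record is prompt-ready, e.g. (illustrative):
            # {"text": "Example 1:\n<text a>\n\nExample 2:\n<text b>", "__LABEL": 3}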
input_text = {
"text": "\n\n".join(
[f"Example {i}:\n{t}" for i, t in enumerate(examples, start=1)]
),
"__LABEL": label,
}
input_texts.append(input_text)
return input_texts
def process(self, inputs: StepInput) -> "StepOutput":
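        """Clusters the texts by sending representative examples of each cluster to the LLM
        and assigning the returned summary label to every text in the cluster."""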
labels = [input["cluster_label"] for input in inputs]
        # Subtract 1 because -1 is the label for the unclustered texts
unique_labels = len(set(labels)) - 1
# This will be the output of the LLM, the set of labels for each cluster
cluster_summaries: Dict[int, str] = {-1: self.default_label}
# Map from label to list of documents, will use them to select examples from each cluster
label2docs = defaultdict(list)
for i, label in enumerate(labels):
label2docs[label].append(i)
input_texts = self._prepare_input_texts(inputs, label2docs, unique_labels)
# Send the texts in batches to the LLM, and get the labels for each cluster
for i, batched_inputs in enumerate(batched(input_texts, self.input_batch_size)):
self._logger.info(f"📦 Processing internal batch of inputs {i}...")
results = super().process(batched_inputs)
for result in next(results): # Extract the elements from the generator
cluster_summaries[result["__LABEL"]] = result["labels"]
# Assign the labels to each text
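        # `json.dumps` stores single labels and lists of labels alike as one string column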
for input in inputs:
input["summary_label"] = json.dumps(
cluster_summaries[input["cluster_label"]]
)
if self.savefig:
self._create_figure(inputs, label2docs, cluster_summaries)
yield inputs
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
from typing import TYPE_CHECKING, Any, List, Optional
import numpy as np
from pydantic import Field, PrivateAttr
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps import (
GlobalStep,
StepInput,
)
if TYPE_CHECKING:
from umap import UMAP as _UMAP
from distilabel.typing import StepOutput
class UMAP(GlobalStep):
r"""UMAP is a general purpose manifold learning and dimension reduction algorithm.
This is a `GlobalStep` that reduces the dimensionality of the embeddings using. Visit
the `TextClustering` step for an example of use. The trained model is saved as an artifact
when creating a distiset and pushing it to the Hugging Face Hub.
Input columns:
        - embedding (`List[float]`): The original embeddings whose dimensionality we want to reduce.
Output columns:
        - projection (`List[float]`): Embedding reduced to the specified number of components;
            the size of the new embeddings is determined by `n_components`.
Categories:
- clustering
- text-classification
References:
- [`UMAP repository`](https://github.com/lmcinnes/umap/tree/master)
- [`UMAP documentation`](https://umap-learn.readthedocs.io/en/latest/)
Attributes:
- n_components: The dimension of the space to embed into. This defaults to 2 to
provide easy visualization (that's probably what you want), but can
reasonably be set to any integer value in the range 2 to 100.
- metric: The metric to use to compute distances in high dimensional space.
Visit UMAP's documentation for more information. Defaults to `euclidean`.
- n_jobs: The number of parallel jobs to run. Defaults to `8`.
- random_state: The random state to use for the UMAP algorithm.
Runtime parameters:
- `n_components`: The dimension of the space to embed into. This defaults to 2 to
provide easy visualization (that's probably what you want), but can
reasonably be set to any integer value in the range 2 to 100.
- `metric`: The metric to use to compute distances in high dimensional space.
Visit UMAP's documentation for more information. Defaults to `euclidean`.
- `n_jobs`: The number of parallel jobs to run. Defaults to `8`.
- `random_state`: The random state to use for the UMAP algorithm.
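    Examples:
        Project a set of embeddings down to two dimensions (a minimal sketch; the dataset
        name and its `embedding` column are illustrative assumptions, and the
        `text-clustering` extra must be installed):
        ```python
        from datasets import load_dataset
        from distilabel.pipeline import Pipeline
        from distilabel.steps import UMAP, make_generator_step
        # Hypothetical dataset with an "embedding" column (one List[float] per row)
        ds = load_dataset("my-org/my-embeddings", split="train")
        with Pipeline(name="UMAP projections") as pipeline:
            loader = make_generator_step(ds, batch_size=500)
            umap = UMAP(n_components=2, metric="cosine")
            loader >> umap
        if __name__ == "__main__":
            distiset = pipeline.run(use_cache=False)
        ```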
Citations:
```
@misc{mcinnes2020umapuniformmanifoldapproximation,
title={UMAP: Uniform Manifold Approximation and Projection for Dimension Reduction},
author={Leland McInnes and John Healy and James Melville},
year={2020},
eprint={1802.03426},
archivePrefix={arXiv},
primaryClass={stat.ML},
url={https://arxiv.org/abs/1802.03426},
}
```
"""
n_components: Optional[RuntimeParameter[int]] = Field(
default=2,
description=(
"The dimension of the space to embed into. This defaults to 2 to "
"provide easy visualization, but can reasonably be set to any "
"integer value in the range 2 to 100."
),
)
metric: Optional[RuntimeParameter[str]] = Field(
default="euclidean",
description=(
"The metric to use to compute distances in high dimensional space. "
"Visit UMAP's documentation for more information."
),
)
n_jobs: Optional[RuntimeParameter[int]] = Field(
default=8, description="The number of parallel jobs to run."
)
random_state: Optional[RuntimeParameter[int]] = Field(
default=None, description="The random state to use for the UMAP algorithm."
)
_umap: Optional["_UMAP"] = PrivateAttr(None)
def load(self) -> None:
super().load()
if importlib.util.find_spec("umap") is None:
raise ImportError(
"`umap` package is not installed. Please install it using `pip install 'distilabel[text-clustering]'`."
)
from umap import UMAP as _UMAP
self._umap = _UMAP(
n_components=self.n_components,
metric=self.metric,
n_jobs=self.n_jobs,
random_state=self.random_state,
)
def unload(self) -> None:
self._umap = None
@property
def inputs(self) -> List[str]:
return ["embedding"]
@property
def outputs(self) -> List[str]:
return ["projection"]
def _save_model(self, model: Any) -> None:
import joblib
def save_model(path):
with open(str(path / "UMAP.joblib"), "wb") as f:
joblib.dump(model, f)
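        # Persist the fitted UMAP mapper as a step artifact along with its main hyperparameters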
self.save_artifact(
name="UMAP_model",
write_function=lambda path: save_model(path),
metadata={
"n_components": self.n_components,
"metric": self.metric,
},
)
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
# Shape of the embeddings is (n_samples, n_features)
embeddings = np.array([input["embedding"] for input in inputs])
self._logger.info("🏋️‍♀️ Start UMAP training...")
mapper = self._umap.fit(embeddings)
# Shape of the projection will be (n_samples, n_components)
for input, projection in zip(inputs, mapper.embedding_):
input["projection"] = projection
self._save_model(mapper)
yield inputs