# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
from distilabel.constants import (
RECEIVES_ROUTED_BATCHES_ATTR_NAME,
STEP_ATTR_NAME,
)
from distilabel.pipeline._dag import DAG
from distilabel.pipeline.batch import _Batch
from distilabel.steps.base import _Step
from distilabel.utils.files import list_files_in_dir
from distilabel.utils.serialization import (
StrOrPath,
_check_is_dir,
_Serializable,
read_json,
)
if TYPE_CHECKING:
from distilabel.utils.serialization import StrOrPath
@dataclass
class _BatchManagerStep(_Serializable):
"""A class that will accumulate data for a step from the predecessors and create
batches for the step to process when there is enough data.
Attributes:
step_name: The name of the step that will process the data.
accumulate: A flag to indicate if the data should be accumulated and create a
batch with all the data received from the predecessors instead of creating
batches with the `input_batch_size`.
input_batch_size: The size of the batch to be created for the step to process.
If `None`, then `accumulate` must be `True`. Defaults to `None`.
data: A dictionary with the predecessor step name as the key and a list of
dictionaries (rows) as the value.
        built_batches: A list with the batches that were built and sent to the step queue,
            but that the step did not get to process before being stopped, kept so those
            batches don't get lost. Defaults to an empty list.
seq_no: The sequence number of the next batch to be created. It will be
incremented for each batch created.
last_batch_received: A list with the names of the steps that sent the last
batch of data.
        convergence_step: A flag to indicate if the step is a convergence step. A
            `Step` is a convergence step if all its predecessors are receiving routed
            batches. Defaults to `False`.
        convergence_step_batches_consumed: A dictionary in which the key is the `seq_no`
            of the batch created by step A that was used by steps B and C, obtained from
            the `created_from` field of the batches created by them. It's used to know whether
            all the batches that B and C created from a given batch of A have been consumed
            by D, in order to not mess up the order of the batches. Only used if `convergence_step=True`.
            Defaults to an empty dictionary.
next_expected_created_from_batch_seq_no: The next expected sequence number of the
batch from step A used by steps B and C and obtained from the `created_from`
of the batches created by them. It's used to avoid messing up the order of the
batches. Only used if `convergence_step=True`. Defaults to `0`.
step_signature: The signature that defines a given `Step`. It will be used for the
caching mechanism.
use_cache: Flag from the original `Step` to indicate whether this step should make use of
the cached data.
        step_offset: A dictionary with the predecessor step name as the key and a tuple
            as the value, containing the `seq_no` of the last batch consumed from that
            predecessor and the number of rows that were read from it, respectively. Used
            by the caching mechanism.
"""
step_name: str
accumulate: bool
input_batch_size: Union[int, None] = None
data: Dict[str, List[_Batch]] = field(default_factory=dict)
built_batches: List[_Batch] = field(default_factory=list)
seq_no: int = 0
last_batch_received: List[str] = field(default_factory=list)
convergence_step: bool = False
convergence_step_batches_consumed: Dict[str, Dict[str, int]] = field(
default_factory=dict
)
next_expected_created_from_batch_seq_no: int = 0
next_expected_seq_no: Dict[str, Tuple[int, int]] = field(default_factory=dict)
step_signature: Optional[str] = None
use_cache: bool = False
step_offset: Dict[str, Tuple[int, int]] = field(default_factory=dict)
def add_batch(self, batch: _Batch, prepend: bool = False) -> None:
"""Add a batch of data from `batch.step_name` to the step. It will accumulate the
data and keep track of the last batch received from the predecessors.
Args:
            batch: The output batch of a step to be processed by the step.
            prepend: If `True`, the batch will be added to the `built_batches`
                list. This is done so if a `_Batch` was already built and sent to the step
                queue, and the step is stopped before processing the batch, the batch doesn't
                get lost. Defaults to `False`.
"""
from_step = batch.step_name
if prepend:
self.built_batches.append(batch)
else:
self.data[from_step].append(batch)
self.data[from_step].sort(key=lambda batch: batch.seq_no)
if batch.last_batch:
self.last_batch_received.append(from_step)
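    # Illustrative sketch (hypothetical values): batches are buffered per
    # predecessor and kept sorted by `seq_no`, so out-of-order arrivals are fine.
    #
    #   step.add_batch(_Batch(seq_no=1, step_name="step_a", last_batch=False, data=[[{"x": 2}]]))
    #   step.add_batch(_Batch(seq_no=0, step_name="step_a", last_batch=True, data=[[{"x": 1}]]))
    #   # `step.data["step_a"]` now holds [seq_no=0, seq_no=1] and "step_a" was
    #   # appended to `step.last_batch_received`.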
def get_batch(self) -> Union[_Batch, None]:
"""Create a new batch of data for the step to process. It will return `None` if
there is not enough data to create a batch.
Returns:
A `_Batch` instance if there is enough data to create a batch. Otherwise,
`None`.
"""
# If there are batches in the `built_batches` list, then return the first one
# and remove it from the list.
if self.built_batches:
return self.built_batches.pop(0)
if not self._ready_to_create_batch():
return None
seq_no = self._get_seq_no()
# `_last_batch` must be called before `_get_data`, as `_get_data` will update the
# list of data which is used to determine if the batch to be created is the last one.
last_batch = self._last_batch()
# Get the batch data and the information from which batches of the upstream steps
# the data was taken.
data, created_from, batch_routed_to = self._get_data()
# Update the step offset i.e. which is the last batch and last row index from that
# batch that the step has consumed
self._update_offset(created_from)
return _Batch(
seq_no=seq_no,
step_name=self.step_name,
last_batch=last_batch,
data=data,
accumulated=self.accumulate,
created_from=created_from,
batch_routed_to=batch_routed_to,
)
def empty_buffers(self) -> List[str]:
"""Checks if the input buffer for the step is empty.
Returns:
The name of the previous steps for which the input buffer for this step is
empty.
"""
if self.accumulate:
return [
previous_step
for previous_step in self.data.keys()
if previous_step not in self.last_batch_received
]
return [
previous_step
for previous_step, batches in self.data.items()
if previous_step not in self.last_batch_received
and sum(len(batch.data[0]) for batch in batches) < self.input_batch_size # type: ignore
]
def set_next_expected_seq_no(
self, from_step: str, next_expected_seq_no: int
) -> None:
"""Sets the next expected sequence number of a `_Batch` received by the step coming
from `from_step`.
Args:
            from_step: The name of the predecessor step whose next expected sequence
                number has to be updated.
next_expected_seq_no: the next expected sequence number of a `_Batch` coming
from `from_step`.
"""
if not self.data[from_step] or (
self.data[from_step]
and self.data[from_step][0].seq_no >= next_expected_seq_no
):
self.next_expected_seq_no[from_step] = (
next_expected_seq_no,
next_expected_seq_no,
)
else:
self.next_expected_seq_no[from_step] = (
self.next_expected_seq_no[from_step][0],
next_expected_seq_no,
)
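    # Worked example (hypothetical values): assume `next_expected_seq_no["step_a"]`
    # is `(0, 0)` and the buffer for "step_a" is empty. Calling the method with
    # `next_expected_seq_no=2` (e.g. batches 0 and 1 were routed elsewhere) sets
    # `(2, 2)`. If instead the buffer still holds a batch with `seq_no=1`, only
    # the second element advances, giving `(0, 2)`: the first element is the
    # lowest `seq_no` that can still be consumed, the second the next candidate.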
@classmethod
def from_step(
cls, step: "_Step", predecessors: Iterable[str], convergence_step: bool = False
) -> "_BatchManagerStep":
"""Creates a `_BatchManagerStep` instance from a `_Step` instance and its
predecessors.
Args:
step: The `_Step` instance.
predecessors: The names of the predecessors of the step.
            convergence_step: A flag to indicate if the step is a convergence step. A
                `Step` is a convergence step if all its predecessors are receiving routed
                batches. Defaults to `False`.
Returns:
A `_BatchManagerStep` instance.
"""
return cls(
step_name=step.name, # type: ignore
accumulate=step.is_global,
input_batch_size=getattr(step, "input_batch_size", None),
data={predecessor: [] for predecessor in predecessors},
convergence_step=convergence_step,
next_expected_seq_no={predecessor: (0, 0) for predecessor in predecessors},
step_signature=step.signature,
use_cache=step.use_cache,
step_offset={predecessor: (0, 0) for predecessor in predecessors},
)
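    # Hypothetical usage sketch: building the manager step for a step with two
    # predecessors, where `step` is the `_Step` instance taken from the DAG.
    #
    #   manager_step = _BatchManagerStep.from_step(
    #       step=step, predecessors=["step_a", "step_b"], convergence_step=False
    #   )
    #   # manager_step.data == {"step_a": [], "step_b": []}
    #   # manager_step.next_expected_seq_no == {"step_a": (0, 0), "step_b": (0, 0)}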
def _get_seq_no(self) -> int:
"""Gets the sequence number for the next batch to be created and increments it.
Returns:
The sequence number for the next batch to be created.
"""
seq_no = self.seq_no
self.seq_no += 1
return seq_no
def _get_data(
self,
) -> Tuple[
List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]], List[str]
]:
"""Gets the data needed to create a batch for the step to process. If the step is
accumulating data, then it will return a list with all the data received from the
predecessors. Otherwise, it will return a list of data with the `input_batch_size`
for each predecessor. In addition, it will remove the data used to create the
batch from the step's data.
Returns:
A tuple containing the list of data needed to create a batch for the step to
process, a dictionary with the sequence numbers of the batches that were used
to create the batch and the list of steps to which the batch was routed to if
the step is a normal step.
"""
if self.accumulate:
# Steps accumulating cannot receive routed batches
return self._get_data_for_accumulate() + ([],)
if self.convergence_step:
# Convergence steps will receive routed batches, but we need to clean the
# `batch_routed_to` list
return self._get_data_for_convergence_step() + ([],)
return self._get_data_normal()
def _get_data_for_accumulate(
self,
) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]]]:
"""Gets the data needed to create a batch for the step to process when the step
is accumulating data. It will return a list with all the data received from the
predecessors. In addition, it will remove the data used to create the batch from
the step's data.
Returns:
A tuple containing the list of data needed to create a batch for the step to
process and a dictionary with the sequence numbers of the batches that were
used to create the batch.
"""
data = []
batches_used = {}
for step_name, batches in self.data.items():
batches_used[step_name] = []
for batch in batches:
batches_used[step_name].append((batch.seq_no, batch.size, batch.size))
data.append([row for batch in batches for row in batch.get_data()])
# Reset the data buffer
self.data = {step_name: [] for step_name in self.data}
return data, batches_used
def _get_data_for_convergence_step(
self,
) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]]]:
"""Gets the data needed to create a batch for the step to process when the step is
a convergence step.
Returns:
A tuple containing the list of data needed to create a batch for the step to
process and a dictionary with the sequence numbers of the batches that were
used to create the batch.
"""
grouped_batches = self._group_batches_by_created_from()
seq_no, batches = grouped_batches[0]
str_seq_no = str(seq_no)
remaining_rows_per_step: Dict[str, int] = {
step_name: self.input_batch_size
for step_name in self.data # type: ignore
}
batches_used = defaultdict(list)
data = defaultdict(list)
for batch, batch_size in batches:
remaining_rows = remaining_rows_per_step[batch.step_name]
selected_data = batch.get_data(remaining_rows)
data[batch.step_name].extend(selected_data)
# If A -> [B, C] -> D, then in D (this step) we keep track of the remaining
# rows from the batches of A that B and C used to create the `batches`.
batch_size = self.convergence_step_batches_consumed.setdefault(
str_seq_no, {}
).get(batch.step_name, batch_size)
remaining_rows_in_batch = batch_size - len(selected_data)
self.convergence_step_batches_consumed[str_seq_no].update(
{batch.step_name: remaining_rows_in_batch}
)
# Update the remaining rows
num_rows = len(selected_data)
remaining_rows_per_step[batch.step_name] -= num_rows # type: ignore
# Keep track of the batches used to create the batch
batches_used[batch.step_name].append((batch.seq_no, batch.size, num_rows))
# If the batch was entirely consumed, then remove it from the buffer
if len(batch.data[0]) == 0:
self.data[batch.step_name].remove(batch)
continue
# If all the batches grouped by the `seq_no` in `created_from` were consumed, then
# we can update the `next_expected_created_from_batch_seq_no` to the next one
# to avoid skipping batches.
no_remaining_rows = all(
count == 0
for count in self.convergence_step_batches_consumed[str_seq_no].values()
)
if no_remaining_rows:
self.next_expected_created_from_batch_seq_no += 1
return list(data.values()), dict(batches_used)
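    # Worked example (hypothetical values): with A -> [B, C] -> D and
    # `input_batch_size=5`, suppose batch 0 of A produced one 5-row batch in B and
    # one in C. A call consumes 5 rows from each, leaving
    # `convergence_step_batches_consumed["0"] == {"B": 0, "C": 0}`; as every count
    # is 0, `next_expected_created_from_batch_seq_no` advances to 1, so the batches
    # built from A's batch 1 are consumed next and the order is preserved.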
def _get_data_normal(
self,
) -> Tuple[
List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int, int]]], List[str]
]:
"""Gets the data needed to create a batch for the step to process when the step is
not accumulating data. It will return a list of data with the `input_batch_size`
for each predecessor. In addition, it will remove the data used to create the batch
from the step's data.
Returns:
A tuple containing the list of data needed to create a batch for the step to
process, a dictionary with the sequence numbers of the batches that were used
to create the batch and the list of steps to which the batch was routed to if
the step is a convergence step.
"""
data = []
batches_used = defaultdict(list)
batch_routed_to = []
for step_name in self.data:
            # For each step's batch buffer, we will create a batch with the `input_batch_size`
            # using the data from the buffer. We will remove the consumed batches (no data
            # left) and update the batch data with the remaining data.
step_data = []
idx_drop_batches = []
remaining_rows: int = self.input_batch_size # type: ignore
next_expected_seq_no = None
for idx, batch in enumerate(self.data[step_name]):
if remaining_rows == 0:
break
# Get `remaining_rows` or the remaining rows in the batch and add it to
# the step data that will be used to create the batch
selected_data = batch.get_data(remaining_rows)
step_data.extend(selected_data)
batch_routed_to = batch.batch_routed_to
# Update the remaining rows
num_rows = len(selected_data)
remaining_rows -= num_rows
# Keep track of the batches used to create the batch
batches_used[step_name].append((batch.seq_no, batch.size, num_rows))
next_expected_seq_no = batch.seq_no
# If the batch was entirely consumed, then remove it from the buffer
if len(batch.data[0]) == 0:
next_expected_seq_no += 1
idx_drop_batches.append(idx)
continue
# Remove the batches that were entirely consumed
idx_drop_batches.reverse()
for idx in idx_drop_batches:
self.data[step_name].pop(idx)
            # Update the `next_expected_seq_no` for `step_name`. It can happen that:
            # 1. This step didn't receive a batch because it was routed to other steps,
            #    and the `set_next_expected_seq_no` method was called. If the first element
            #    is not equal to the second, that means there is a potential `next_expected_seq_no`
            #    from `step_name`. If there is no data left, then we set that as the `next_expected_seq_no`.
            # 2. `set_next_expected_seq_no` has not been called, so we set the `next_expected_seq_no`
            #    taking into account the data left in the step.
step_next_expected_seq_no = self.next_expected_seq_no[step_name]
if step_next_expected_seq_no[0] != step_next_expected_seq_no[1] and (
not self.data[step_name]
or (
self.data[step_name]
and self.data[step_name][0].seq_no >= step_next_expected_seq_no[1]
)
):
self.next_expected_seq_no[step_name] = (
step_next_expected_seq_no[1],
step_next_expected_seq_no[1],
)
elif next_expected_seq_no:
self.next_expected_seq_no[step_name] = (
next_expected_seq_no,
next_expected_seq_no
if next_expected_seq_no > step_next_expected_seq_no[1]
else step_next_expected_seq_no[1],
)
data.append(step_data)
return data, dict(batches_used), batch_routed_to
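    # Illustrative sketch (hypothetical values): with `input_batch_size=5` and a
    # buffer for "step_a" holding batch 0 (3 rows) and batch 1 (4 rows), the loop
    # takes all 3 rows from batch 0 (dropped afterwards) and 2 rows from batch 1
    # (kept with 2 rows left), recording
    # `batches_used == {"step_a": [(0, 3, 3), (1, 4, 2)]}`.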
def _ready_to_create_batch(self) -> bool:
"""Checks if there is enough data to create a batch for the step.
Returns:
`True` if there is enough data to create a batch for the step. Otherwise,
`False`.
"""
if self.accumulate:
return self._ready_to_create_batch_accumulate()
if self.convergence_step:
return self._ready_to_create_batch_convergence_step()
return self._ready_to_create_batch_normal()
def _ready_to_create_batch_accumulate(self) -> bool:
"""Checks if there is enough data for an step accumulating data. It will return
`True` if the last batch was received from all the predecessors.
Returns:
`True` if ready to create a batch, `False` otherwise.
"""
return all(
step in self.last_batch_received
and sum(len(batch.data[0]) for batch in batches) > 0
for step, batches in self.data.items()
)
def _ready_to_create_batch_convergence_step(self) -> bool:
"""Checks if there is enough data for creating a batch for an step in which output
batches that were generated by steps that received routed batches are received.
It will return `True`, if all the output batches that were generated from a routed
batch have been received.
Returns:
`True` if ready to create a batch, `False` otherwise.
"""
grouped_batches = self._group_batches_by_created_from()
if not grouped_batches:
return False
seq_no, batches = grouped_batches[0]
# If the `seq_no` from the `created_from` field is not the expected one, then
# we cannot create a batch yet or the order will be messed up
if seq_no != self.next_expected_created_from_batch_seq_no:
return False
        # Check that the output batches from all the steps to which the input batch
        # was routed have been received
batch_routed_to = batches[0][0].batch_routed_to
batches_received_from = {batch.step_name for batch, _ in batches}
if any(step_name not in batches_received_from for step_name in batch_routed_to):
return False
        # Output batches were received from all the steps to which the input batch was
        # routed. Check if there is enough data for creating a batch with `input_batch_size`
rows_per_step = defaultdict(lambda: 0)
for batch, _ in batches:
num_rows = len(batch.data[0])
rows_per_step[batch.step_name] += num_rows
# If there aren't at least `input_batch_size` rows from each step, then there
# isn't enough data to create a batch
if not all(
num_rows >= self.input_batch_size or step_name in self.last_batch_received # type: ignore
for step_name, num_rows in rows_per_step.items()
):
return False
return True
def _ready_to_create_batch_normal(self) -> bool:
"""Checks if there is enough data for creating a batch for a normal step. It will
be `True` it there are at least `input_batch_size` rows from each predecessor step.
Returns:
`True` if ready to create a batch, `False` otherwise.
"""
for step_name, batches in self.data.items():
            # Depending on the number of replicas of the `Step`, it can happen that one
            # replica is faster and sends the batch with `seq_no==1` before another replica
            # sends the batch with `seq_no==0`. We need to check which `seq_no` was expected
            # next to not mess up the ordering of the rows.
next_expected_seq_no = self.next_expected_seq_no[step_name][0]
# `batches` are sorted by `seq_no`
num_rows = 0
is_batch_in_order = True
for batch in batches:
# Need to create batches using the data from batches with sequential `seq_no`
if batch.seq_no != next_expected_seq_no:
is_batch_in_order = False
break
# There are enough rows to create a batch
num_rows += len(batch.data[0])
if self.input_batch_size and num_rows >= self.input_batch_size:
break
next_expected_seq_no += 1
            # If there are no rows but the last batch was already received, then there
            # are no more batches to be created
if num_rows == 0 and step_name in self.last_batch_received:
return False
# If there are not enough rows and the last batch was not received yet, then
# there is not enough data yet to create a batch
# If the last batch was received, the batch preceding it must be in order
if (
self.input_batch_size
and num_rows < self.input_batch_size
and not (step_name in self.last_batch_received and is_batch_in_order)
):
return False
return True
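    # Illustrative sketch (hypothetical values): with two replicas of "step_a",
    # the batch with `seq_no=1` can arrive before `seq_no=0`. If
    # `next_expected_seq_no["step_a"][0] == 0` and only batch 1 is buffered, the
    # first iteration breaks out (`1 != 0`), `num_rows` stays 0 and the method
    # returns `False` until batch 0 arrives and restores the order.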
def _last_batch(self) -> bool:
"""Checks if the batch to be created is the last one i.e. if the last batch was
received from all the predecessors.
Returns:
`True` if the batch to be created is the last one. Otherwise, `False`.
"""
if self.accumulate:
return self._last_batch_accumulate()
if self.convergence_step:
return self._last_batch_convergence_step()
return self._last_batch_normal()
def _update_offset(
self, created_from: Dict[str, List[Tuple[int, int, int]]]
) -> None:
"""Update the offset for the batch buffers of the upstream steps.
Args:
            created_from: A dictionary containing which batches from which steps were used
                to create this batch. The keys are the names of the steps and the values
                are lists for each step containing the `seq_no` of each batch used, the original
                size of the batch used and the number of rows used from the batch to create
                this batch.
"""
for predecessor, seq_no_and_batch in created_from.items():
prev_last_batch_seq_no, prev_last_batch_offset = self.step_offset[
predecessor
]
last_batch_seq_no, _, last_batch_size = seq_no_and_batch[-1]
batch_offset = (
prev_last_batch_offset + last_batch_size
if prev_last_batch_seq_no == last_batch_seq_no
else last_batch_size
)
last_batch_seq_no = (
last_batch_seq_no
if last_batch_seq_no > prev_last_batch_seq_no
else prev_last_batch_seq_no
)
self.step_offset[predecessor] = (last_batch_seq_no, batch_offset)
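    # Worked example (hypothetical values): assume `step_offset["step_a"] == (2, 10)`,
    # i.e. 10 rows were already read from batch 2 of "step_a". If `created_from`
    # reports `{"step_a": [(2, 100, 5)]}` (5 more rows taken from batch 2), the
    # offset becomes `(2, 15)`. If it reports `{"step_a": [(3, 100, 5)]}`, a new
    # batch was started and the offset becomes `(3, 5)`.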
def _last_batch_accumulate(self) -> bool:
"""Checks if the batch to be created is the last one for an step accumulating data.
`True` if the last batch was received from all the predecessors.
Returns:
`True` if the batch to be created is the last one. Otherwise, `False`.
"""
return all(step in self.last_batch_received for step in self.data.keys())
def _last_batch_convergence_step(self) -> bool:
"""Checks if the batch to be created is the last one for a convergence step. `True`
if the last batch of all the steps (`batch_routed_to`) in the last routed batch
have been received.
Returns:
`True` if the batch to be created is the last one. Otherwise, `False`.
"""
grouped_batches = self._group_batches_by_created_from()
if not grouped_batches:
return False
_, batches = grouped_batches[0]
for batch, _ in batches:
if not batch.last_batch:
return False
if len(batch.data[0]) > self.input_batch_size: # type: ignore
return False
return True
def _last_batch_normal(self) -> bool:
"""Checks if the batch to be created is the last one for a normal step. `True` if
there is no more data to be received from the predecessors.
Returns:
`True` if the batch to be created is the last one. Otherwise, `False`.
"""
for step_name, batches in self.data.items():
if step_name not in self.last_batch_received:
return False
num_rows = sum(len(batch.data[0]) for batch in batches)
if self.input_batch_size and num_rows > self.input_batch_size:
return False
return True
def _group_batches_by_created_from(
self,
) -> List[Tuple[int, List[Tuple["_Batch", int]]]]:
"""Group the batches by the first key of `created_from` field. This method is
meant to be used only with a `convergence_step`.
Returns:
A list of the batches grouped by the `seq_no` of the first step name in `created_from`.
The list is sorted by the `seq_no`.
"""
grouped_batches: Dict[int, List[Tuple["_Batch", int]]] = defaultdict(list)
for batches in self.data.values():
for batch in batches:
first_key = next(iter(batch.created_from))
batch_seq_no, batch_size, _ = batch.created_from[first_key][0]
grouped_batches[batch_seq_no].append((batch, batch_size))
return sorted((seq_no, batches) for seq_no, batches in grouped_batches.items())
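    # Illustrative sketch (hypothetical values): if B and C each buffered one batch
    # built from A's batch 0 and one built from A's batch 1, the result would be
    # `[(0, [(b0, size), (c0, size)]), (1, [(b1, size), (c1, size)])]`, i.e. the
    # buffered batches grouped and sorted by the `seq_no` of the batch of A that
    # originated them.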
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"""Dumps the content of the `_BatchManagerStep` to a dictionary.
Args:
obj: Unused, just kept to match the signature of the parent method.
kwargs: Additional arguments that are kept to match the signature of the parent method.
Returns:
Internal representation of the `_BatchManagerStep`.
"""
return {
"step_name": self.step_name,
"accumulate": self.accumulate,
"input_batch_size": self.input_batch_size,
"data": {
step_name: [batch.dump(**kwargs) for batch in batches]
for step_name, batches in self.data.items()
},
"built_batches": [batch.dump(**kwargs) for batch in self.built_batches],
"seq_no": self.seq_no,
"last_batch_received": self.last_batch_received,
"convergence_step": self.convergence_step,
"convergence_step_batches_consumed": self.convergence_step_batches_consumed,
"next_expected_created_from_batch_seq_no": self.next_expected_created_from_batch_seq_no,
"next_expected_seq_no": self.next_expected_seq_no,
"step_signature": self.step_signature,
"use_cache": self.use_cache,
"step_offset": self.step_offset,
}
@property
def signature(self) -> str:
return f"{self.step_name}_{self.step_signature}"
class _BatchManager(_Serializable):
"""Class to manage the batches received from the steps. It keeps track of the
received batches and returns new batches for the steps to process based on their
input batch size and the batches received from the predecessors.
Attributes:
steps: A dictionary with the step name as the key and a `_BatchManagerStep`
instance as the value.
last_batch_received: A dictionary with the step name as the key and a flag to
indicate whether we received the last batch from the step.
"""
def __init__(
self,
steps: Dict[str, _BatchManagerStep],
last_batch_received: Dict[str, Union[_Batch, None]],
last_batch_sent: Dict[str, Union[_Batch, None]],
last_batch_flag_sent_to: List[str],
received_batch_seq_nos: Dict[str, List[int]],
) -> None:
"""Initialize the `_BatchManager` instance.
Args:
            steps: A dictionary with the step name as the key and a `_BatchManagerStep`
                instance as the value.
last_batch_received: A dictionary with the step name as the key and the last
`_Batch` received from the step.
last_batch_sent: A dictionary with the step name as the key and the last
`_Batch` sent to the step.
last_batch_flag_sent_to: A list with the names of the steps to which `LAST_BATCH_SENT_FLAG`
was sent.
            received_batch_seq_nos: a dictionary containing the list of sequence numbers
                of the batches received per step.
"""
self._steps = steps
self._last_batch_received = last_batch_received
self._last_batch_sent = last_batch_sent
self._last_batch_flag_sent_to = last_batch_flag_sent_to
self._received_batch_seq_nos = received_batch_seq_nos
def _missing_seq_no(self, last_batch: _Batch) -> bool:
"""Checks if there's any missing sequence number in the batches received from the
step.
Args:
last_batch: the batch with `last_batch==True` received from the step.
Returns:
`True` if there's any missing sequence number, `False` otherwise.
"""
received_batch_seq_nos = self._received_batch_seq_nos[last_batch.step_name]
for i in range(last_batch.seq_no + 1):
if i not in received_batch_seq_nos:
return True
return False
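    # Illustrative sketch (hypothetical values): if batches 0, 1 and 3 were
    # registered for a step and batch 3 arrives with `last_batch=True`, the check
    # over `range(4)` detects that 2 is missing and returns `True`, so the
    # pipeline keeps waiting for it instead of finishing early.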
def can_generate(self) -> bool:
"""Checks if there are still batches to be processed by the steps.
Returns:
`True` if there are still batches to be processed by the steps. Otherwise,
`False`.
"""
for step_name, batch in self._last_batch_received.items():
if step_name not in self._last_batch_flag_sent_to:
if not batch:
return True
if batch.last_batch and self._missing_seq_no(batch):
return True
if not batch.last_batch:
return True
if not self.get_last_batch_sent(step_name):
return True
return False
def register_batch(
self, batch: _Batch, steps_data_path: Optional["StrOrPath"] = None
) -> None:
"""Method to register a batch received from a step. It will keep track of the
sequence number and the last batch received from the step in the internal maps.
Args:
batch: _Batch from which we will register the sequence number and the last batch received.
            steps_data_path: The path where the outputs of each `Step` (considering its
                signature) will be saved for later reuse in other pipeline executions.
"""
step_name = batch.step_name
seq_no = batch.seq_no
self._received_batch_seq_nos[step_name].append(seq_no)
last_batch = self._last_batch_received[step_name]
if not last_batch or (last_batch and last_batch.seq_no < seq_no):
self._last_batch_received[step_name] = batch
if steps_data_path:
self.write_batch_data(batch, steps_data_path)
def write_batch_data(self, batch: _Batch, steps_data_path: Path) -> None:
"""Writes the batch to the steps data directory.
        Args:
batch: the batch to be written.
steps_data_path: the steps data base directory.
"""
step = self._steps[batch.step_name]
batch_manager_data_dir = Path(steps_data_path) / step.signature
batch_manager_data_dir.mkdir(parents=True, exist_ok=True)
filename = batch_manager_data_dir / f"batch_{batch.seq_no}.json"
if not filename.exists():
self.save(path=filename, format="json", dump=batch.dump())
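    # Resulting layout sketch (hypothetical signature): a batch with `seq_no=3`
    # from a step whose signature is "step_a_8f2e" would be written to
    # `<steps_data_path>/step_a_8f2e/batch_3.json`, and it's never overwritten if
    # the file already exists.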
def get_last_batch(self, step_name: str) -> Union[_Batch, None]:
"""Gets the last batch received from a step.
Args:
step_name: The name of the step.
Returns:
The last batch received from the step or `None` if no batch was received.
"""
return self._last_batch_received.get(step_name)
def add_batch(
self,
to_step: str,
batch: _Batch,
prepend: bool = False,
) -> None:
"""Add an output batch from `batch.step_name` to `to_step`.
Args:
to_step: The name of the step that will process the batch.
            batch: The output batch of a step to be processed by `to_step`.
prepend: If `True`, the content of the batch will be added at the start of
the buffer.
Raises:
ValueError: If `to_step` is not found in the batch manager.
"""
if to_step not in self._steps:
raise ValueError(f"Step '{to_step}' not found in the batch manager.")
step = self._steps[to_step]
step.add_batch(batch, prepend)
def add_batch_to_recover_offline_batch_generation(
self, to_step: str, data: List[List[Dict[str, Any]]]
) -> None:
"""Add a batch to recover pipeline execution from an `_Step` that used an `LLM`
with offline batch generation. It will add the batch to the start of the buffer
of the step and set the last batch received of the step to `None`.
Args:
to_step: The name of the step that will process the batch.
data: The data that was used with the offline batch generation.
"""
self.add_batch(
to_step=to_step,
batch=_Batch(seq_no=0, step_name=to_step, last_batch=True, data=data),
prepend=True,
)
self._last_batch_received[to_step] = None
def get_batch(self, step_name: str) -> Union[_Batch, None]:
"""Get the next batch to be processed by the step.
Args:
step_name: The name of the step that will process the batch.
Returns:
A `_Batch` instance if there is a batch to be processed by the step. Otherwise,
`None`.
"""
if step_name not in self._steps:
raise ValueError(f"Step '{step_name}' not found in the batch manager.")
return self._steps[step_name].get_batch()
def step_empty_buffers(self, step_name: str) -> List[str]:
"""Checks if the input buffer for a step is empty.
Args:
step_name: The name of the step.
Returns:
The name of the previous steps for which the input buffer for this step is
empty.
"""
return self._steps[step_name].empty_buffers()
def set_last_batch_sent(self, batch: "_Batch") -> None:
"""Set the last batch sent to a step.
Args:
batch: The last batch sent to a step.
"""
self._last_batch_sent[batch.step_name] = batch
def get_last_batch_sent(self, step_name: str) -> Union["_Batch", None]:
"""Get the last batch sent to a step.
Args:
step_name: The name of the step.
Returns:
The last batch sent to a step or `None` if no batch was sent.
"""
return self._last_batch_sent.get(step_name, None)
def set_last_batch_flag_sent_to(self, step_name: str) -> None:
"""Set the flag to indicate that the last batch was sent to a step.
Args:
step_name: The name of the step.
"""
self._last_batch_flag_sent_to.append(step_name)
def set_next_expected_seq_no(
self, step_name: str, from_step: str, next_expected_seq_no: int
) -> None:
"""Sets the next expected sequence number of a `_Batch` received by `step` coming
from `from_step`.
Args:
step_name: The step name whose next expected sequence number for `from_step`
has to be updated.
            from_step: The name of the predecessor step whose next expected sequence
                number has to be updated.
next_expected_seq_no: the next expected sequence number of a `_Batch` coming
from `from_step`.
"""
self._steps[step_name].set_next_expected_seq_no(from_step, next_expected_seq_no)
def step_has_finished(self, step_name: str) -> bool:
"""Indicates if the step has finished by checking if it sent a batch with `last_batch==True`
or it was sent the `LAST_BATCH_SENT_FLAG`.
Args:
step_name: the name of the step to be checked.
Returns:
`True` if step has finished generating batches, `False` otherwise.
"""
return step_name in self._last_batch_flag_sent_to or (
self._last_batch_received[step_name] is not None
and self._last_batch_received[step_name].last_batch # type: ignore
)
@classmethod
def from_dag( # noqa: C901
cls, dag: "DAG", use_cache: bool = False, steps_data_path: Optional[Path] = None
) -> "_BatchManager":
"""Create a `_BatchManager` instance from a `DAG` instance.
Args:
dag: The `DAG` instance.
use_cache: whether or not to try loading outputs from steps of previous pipelines
executions. Defaults to `False`.
            steps_data_path: The path where the outputs of each `Step` (considering its
                signature) will be saved for later reuse in other pipeline executions.
Returns:
A `_BatchManager` instance.
"""
steps = {}
last_batch_received = {}
last_batch_sent = {}
last_batch_flag_sent_to = []
received_batch_seq_nos = {}
load_batches = {}
steps_to_load_data_from_previous_executions: Dict[str, Union[Path, None]] = {}
for step_name in dag:
step: "_Step" = dag.get_step(step_name)[STEP_ATTR_NAME]
last_batch_received[step.name] = None
last_batch_sent[step.name] = None
received_batch_seq_nos[step.name] = []
predecessors = list(dag.get_step_predecessors(step_name))
convergence_step = all(
dag.get_step(predecessor).get(RECEIVES_ROUTED_BATCHES_ATTR_NAME, False)
for predecessor in predecessors
)
batch_manager_step = _BatchManagerStep.from_step(
step=step,
predecessors=predecessors,
convergence_step=convergence_step,
)
            all_step_predecessors_use_cache = all(
                dag.get_step(step_name)[STEP_ATTR_NAME].use_cache
                for step_name in predecessors
            )
            if use_cache and step.use_cache and all_step_predecessors_use_cache:
step_data_path = steps_data_path / batch_manager_step.signature
if step_data_path.exists():
steps_to_load_data_from_previous_executions[step_name] = (
step_data_path
)
                    # We only want to load the outputs that are directly needed by the added
                    # steps, so if we need to load the outputs of one step and one of its
                    # predecessors is in the list, then we remove the predecessor.
for predecessor in predecessors:
if predecessor in steps_to_load_data_from_previous_executions:
steps_to_load_data_from_previous_executions[predecessor] = (
None
)
steps[step_name] = batch_manager_step
for (
step_name,
step_outputs_path,
) in steps_to_load_data_from_previous_executions.items():
last_batch_flag_sent_to.append(step_name)
if step_outputs_path is None:
continue
load_batches[step_name] = sorted(
[
_Batch.from_json(batch_file)
for batch_file in step_outputs_path.glob("*.json")
if batch_file.is_file() and batch_file.suffix == ".json"
],
key=lambda x: x.seq_no,
)
last_batch_received[step_name] = load_batches[step_name][-1]
# Load batches from previous steps in batch manager steps
for step_name, batch_manager_step in steps.items():
for predecessor in dag.get_step_predecessors(step_name):
if predecessor in load_batches:
batch_manager_step.data[predecessor] = deepcopy(
load_batches[predecessor]
)
batch_manager_step.last_batch_received.append(predecessor)
return cls(
steps,
last_batch_received,
last_batch_sent,
last_batch_flag_sent_to,
received_batch_seq_nos,
)
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"""Dumps the content of the `_BatchManager` to a dictionary.
Args:
            obj: Unused, just kept to match the signature of the parent method.
            kwargs: Additional arguments that are kept to match the signature of the parent method.
        Returns:
            Internal representation of the `_BatchManager`.
"""
return {
"steps": {name: step.dump(**kwargs) for name, step in self._steps.items()},
"last_batch_received": {
step_name: batch.dump(**kwargs) if batch is not None else None
for step_name, batch in self._last_batch_received.items()
},
"last_batch_sent": {
step_name: batch.dump(**kwargs) if batch is not None else None
for step_name, batch in self._last_batch_sent.items()
},
"last_batch_flag_sent_to": self._last_batch_flag_sent_to,
"received_batch_seq_nos": self._received_batch_seq_nos,
}
def cache(self, path: Path, steps_data_path: Path) -> None: # noqa: C901
"""Cache the `_BatchManager` to a file.
Args:
            path: The path to the file where the `_BatchManager` will be cached.
            steps_data_path: The path where the outputs of each `Step` (considering its
                signature) will be saved for later reuse in other pipeline executions.
"""
def save_batch(
batches_dir: Path, batch_dump: Dict[str, Any], batch_list: List[_Batch]
) -> Path:
seq_no = batch_dump["seq_no"]
data_hash = batch_dump["data_hash"]
batch_file = batches_dir / f"batch_{seq_no}_{data_hash}.json"
# Save the batch if it doesn't exist
if not batch_file.exists():
# Get the data of the batch before saving it
batch = next(batch for batch in batch_list if batch.seq_no == seq_no)
batch_dump["data"] = batch.data
self.save(path=batch_file, format="json", dump=batch_dump)
return batch_file
def remove_files(keep_files: List[str], dir: Path) -> None:
files = list_files_in_dir(dir, key=None)
remove = set(files) - {Path(file) for file in keep_files}
for file in remove:
file.unlink()
path = Path(path)
# Do not include `_Batch` data so `dump` is fast
dump = self.dump(include_batch_data=False)
batch_manager_step_files = {}
# Do this to avoid modifying the dictionary while iterating over it
batch_manager_steps = set(dump["steps"].keys())
for step_name in batch_manager_steps:
step_dump = dump["steps"].pop(step_name)
# Create a directory for each batch manager step to store their batches
batch_manager_step_dir = path.parent / "batch_manager_steps" / step_name
batch_manager_step_dir.mkdir(parents=True, exist_ok=True)
# Store each built `_Batch` in a separate file
built_batches_dir = batch_manager_step_dir / "built_batches"
built_batches_dir.mkdir(parents=True, exist_ok=True)
step_dump["built_batches"] = [
str(
save_batch(
batches_dir=built_batches_dir,
batch_dump=batch_dump,
batch_list=self._steps[step_name].built_batches,
)
)
for batch_dump in step_dump["built_batches"]
]
# Remove built `_Batch`es that were consumed from cache
remove_files(step_dump["built_batches"], built_batches_dir)
# Store the `_BatchManagerStep` info
batch_manager_step_file = str(
path.parent / f"batch_manager_steps/{step_name}/batch_manager_step.json"
)
self.save(path=batch_manager_step_file, format="json", dump=step_dump)
# Store the path to the `_BatchManagerStep` file
batch_manager_step_files[step_name] = batch_manager_step_file
dump["steps"] = batch_manager_step_files
self.save(path=path, format="json", dump=dump)
@classmethod
def load_from_cache(
cls, dag: "DAG", batch_manager_path: "StrOrPath", steps_data_path: "StrOrPath"
) -> "_BatchManager":
"""Loads the `_BatchManager` from a cache file.
        Args:
            dag: The `DAG` instance of the pipeline, used to populate back the data of the steps.
            batch_manager_path: The path to the file where the `_BatchManager` was cached.
            steps_data_path: The path where the outputs of each `Step` were saved for reuse.
"""
_check_is_dir(batch_manager_path)
content = read_json(batch_manager_path)
# Read each `_BatchManagerStep` from file
steps = {}
for step_name, step_file in content["steps"].items():
steps[step_name] = read_json(step_file)
            # When reading back from JSON, the values of `next_expected_seq_no` and
            # `step_offset` are lists (because JSON files do not have tuples).
steps[step_name]["next_expected_seq_no"] = {
k: tuple(v) for k, v in steps[step_name]["next_expected_seq_no"].items()
}
steps[step_name]["step_offset"] = {
k: tuple(v) for k, v in steps[step_name]["step_offset"].items()
}
            # TODO: where are we writing built batches now?
# Read each `_Batch` from file
steps[step_name]["built_batches"] = [
read_json(batch) for batch in steps[step_name]["built_batches"]
]
# Read the batches from the `steps_data` directory to populate back the `_BatchManagerStep`
step_offset = steps[step_name]["step_offset"]
for successor_step_name, offset in step_offset.items():
batch_offset, batch_row_offset = offset
step: "_Step" = dag.get_step(successor_step_name)[STEP_ATTR_NAME]
successor_step_data_path = (
steps_data_path / f"{step.name}_{step.signature}"
)
# read batches from successor step from the step data directory taking into
# account offset
batches = []
for batch_file in successor_step_data_path.glob("*.json"):
if not batch_file.is_file() or batch_file.suffix != ".json":
continue
# If the batch number is lower than the batch offset then we should
# skip it as it has already been processed by the step
batch_no = int(batch_file.stem.split("batch_")[1])
if batch_no < batch_offset:
continue
# read the batch and skip the first N rows of the first batch
batch = read_json(batch_file)
if batch_no == batch_offset:
batch["data"][0] = batch["data"][0][batch_row_offset:]
batches.append(batch)
# sort batches by `seq_no` as it's a requirement for checking if ready to
# create next batch
batches.sort(key=lambda batch: batch["seq_no"])
steps[step_name]["data"][successor_step_name] = batches
content["steps"] = steps
return cls.from_dict(content)
def invalidate_cache_for(
self, step_name: str, dag: "DAG", steps_data_path: Path
) -> None:
"""Invalidates the cache for the given step and its predecessors.
Args:
step_name: the name of the step for which the cache will be invalidated.
dag: the `DAG` of the pipeline containing the steps.
steps_data_path: the path where the output batches of each `Step` were saved
for reuse in another pipeline execution.
"""
invalidate_if_predecessor = []
for sorted_step in dag:
if (sorted_step == step_name) or any(
predecessor in invalidate_if_predecessor
for predecessor in dag.get_step_predecessors(sorted_step)
):
self._reset_batch_manager_for_step(sorted_step, dag)
invalidate_if_predecessor.append(sorted_step)
self._load_predecessor_batches(step_name, dag, steps_data_path)
def _reset_batch_manager_for_step(self, step_name: str, dag: "DAG") -> None:
"""Resets the batch manager state for a given step i.e. creates a new clean `_BatchManagerStep`
for the step and removes the step name from the lists of states of the `BatchManager`.
Args:
step_name: the name of step for which its batch manager state needs to be cleaned.
dag: the `DAG` of the pipeline containing the steps.
"""
predecessors = list(dag.get_step_predecessors(step_name))
convergence_step = dag.is_convergence_step(step_name)
step = dag.get_step(step_name)[STEP_ATTR_NAME]
self._steps[step_name] = _BatchManagerStep.from_step(
step, predecessors=predecessors, convergence_step=convergence_step
)
self._last_batch_received[step_name] = None
self._last_batch_sent[step_name] = None
if step_name in self._last_batch_flag_sent_to:
self._last_batch_flag_sent_to.remove(step_name)
def _load_predecessor_batches(
self, step_name: str, dag: "DAG", steps_data_path: Path
) -> None:
"""Loads the cached batches of the predecessors of the step in its `_BatchManagerStep`.
Args:
step_name: the name of the step whose predecessors' batches will be loaded.
dag: the `DAG` of the pipeline containing the steps.
steps_data_path: the path where the output batches of each `Step` were saved
for reuse in another pipeline execution.
"""
for predecessor in dag.get_step_predecessors(step_name):
step_predecessor = dag.get_step(predecessor)[STEP_ATTR_NAME]
predecessor_step_data_path = (
steps_data_path
/ f"{step_predecessor.name}_{step_predecessor.signature}"
)
batch_files = list_files_in_dir(
predecessor_step_data_path, key=lambda x: int(x.stem.split("_")[-1])
)
for file in batch_files:
batch = _Batch.from_file(file)
if batch.last_batch:
self._steps[step_name].last_batch_received.append(batch.step_name)
self._steps[step_name].data[predecessor].append(batch)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing as mp
import signal
import sys
from multiprocessing.pool import Pool
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Union,
cast,
)
import tblib
from distilabel.constants import SIGINT_HANDLER_CALLED_ENV_NAME
from distilabel.distiset import create_distiset
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from distilabel.pipeline.base import BasePipeline, set_pipeline_running_env_variables
from distilabel.pipeline.ray import RayPipeline
from distilabel.pipeline.step_wrapper import _StepWrapper, _StepWrapperException
from distilabel.utils.logging import setup_logging, stop_logging
from distilabel.utils.ray import script_executed_in_ray_cluster
if TYPE_CHECKING:
import logging
from queue import Queue
from distilabel.distiset import Distiset
from distilabel.steps.base import _Step
from distilabel.typing import InputDataset, LoadGroups
_SUBPROCESS_EXCEPTION: Union[Exception, None] = None
def _init_worker(
log_queue: "Queue[Any]", pipeline_name: str, pipeline_cache_id: str
) -> None:
"""Init function for the child processes that will execute the `Step`s of the `Pipeline`.
Args:
        log_queue: The queue to send the logs to the main process.
        pipeline_name: The name of the pipeline, used to set the pipeline running
            environment variables in the child process.
        pipeline_cache_id: The cache ID of the pipeline, used to set the pipeline
            running environment variables in the child process.
"""
# Register a signal handler for SIGINT to avoid the default behavior of the process
# to terminate when the parent process receives a SIGINT signal. Instead, set an env
# variable when SIGINT is received. Child process can check the value of this env
# variable in sections of the code where they need to stop the execution if SIGINT
# was received (such as offline batch generation polling).
def signal_handler(sig: int, frame: Any) -> None:
import os
os.environ[SIGINT_HANDLER_CALLED_ENV_NAME] = "1"
signal.signal(signal.SIGINT, signal_handler)
set_pipeline_running_env_variables(pipeline_name, pipeline_cache_id)
setup_logging(log_queue)
# We create a custom `Pool` class so the created processes are not daemons, allowing
# them to create child processes if necessary (for example when using `vLLM` with `tensor_parallel_size`)
# https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic
class _NoDaemonProcess(mp.Process):
@property
def daemon(self) -> bool:
return False
@daemon.setter
def daemon(self, value: bool) -> None: # type: ignore
pass
class _NoDaemonContext(type(mp.get_context())):
Process = _NoDaemonProcess
class _NoDaemonPool(Pool):
def __init__(
self,
processes: Union[int, None] = None,
initializer: Union[Callable[..., object], None] = None,
        initargs: Iterable[Any] = (),
maxtasksperchild: Union[int, None] = None,
) -> None:
super().__init__(
processes=processes,
initializer=initializer,
initargs=initargs,
maxtasksperchild=maxtasksperchild,
context=_NoDaemonContext(), # type: ignore
)
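# Illustrative note: a regular `multiprocessing.Pool` raises "daemonic processes
# are not allowed to have children" if a worker tries to spawn subprocesses.
# Hypothetical usage sketch, mirroring how the pool is created in `Pipeline.run`:
#
#   with _NoDaemonPool(4, initializer=_init_worker, initargs=(queue, name, cache_id)) as pool:
#       pool.apply_async(step_wrapper.run)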
class Pipeline(BasePipeline):
"""Local pipeline implementation using `multiprocessing`."""
def ray(
self,
ray_head_node_url: Optional[str] = None,
ray_init_kwargs: Optional[Dict[str, Any]] = None,
) -> RayPipeline:
"""Creates a `RayPipeline` using the init parameters of this pipeline. This is a
convenient method that can be used to "transform" one common `Pipeline` to a `RayPipeline`
and it's mainly used by the CLI.
Args:
ray_head_node_url: The URL that can be used to connect to the head node of
the Ray cluster. Normally, you won't want to use this argument as the
recommended way to submit a job to a Ray cluster is using the [Ray Jobs
CLI](https://docs.ray.io/en/latest/cluster/running-applications/job-submission/index.html#ray-jobs-overview).
Defaults to `None`.
ray_init_kwargs: kwargs that will be passed to the `ray.init` method. Defaults
to `None`.
Returns:
A `RayPipeline` instance.
"""
pipeline = RayPipeline(
name=self.name,
description=self.description,
cache_dir=self._cache_dir,
enable_metadata=self._enable_metadata,
requirements=self.requirements,
ray_head_node_url=ray_head_node_url,
ray_init_kwargs=ray_init_kwargs,
)
pipeline.dag = self.dag
return pipeline
def run(
self,
parameters: Optional[Dict[Any, Dict[str, Any]]] = None,
load_groups: Optional["LoadGroups"] = None,
use_cache: bool = True,
storage_parameters: Optional[Dict[str, Any]] = None,
use_fs_to_pass_data: bool = False,
dataset: Optional["InputDataset"] = None,
dataset_batch_size: int = 50,
logging_handlers: Optional[List["logging.Handler"]] = None,
) -> "Distiset":
"""Runs the pipeline.
Args:
parameters: A dictionary with the step name as the key and a dictionary with
the runtime parameters for the step as the value. Defaults to `None`.
load_groups: A list containing lists of steps that have to be loaded together
and in isolation with respect to the rest of the steps of the pipeline.
This argument also allows passing the following modes:
- "sequential_step_execution": each step will be executed in a stage i.e.
the execution of the steps will be sequential.
Defaults to `None`.
use_cache: Whether to use the cache from previous pipeline runs. Defaults to
`True`.
storage_parameters: A dictionary with the storage parameters (`fsspec` and path)
that will be used to store the data of the `_Batch`es passed between the
steps if `use_fs_to_pass_data` is `True` (for the batches received by a
`GlobalStep` it will be always used). It must have at least the "path" key,
and it can contain additional keys depending on the protocol. By default,
it will use the local file system and a directory in the cache directory.
Defaults to `None`.
use_fs_to_pass_data: Whether to use the file system to pass the data of
the `_Batch`es between the steps. Even if this parameter is `False`, the
`Batch`es received by `GlobalStep`s will always use the file system to
pass the data. Defaults to `False`.
            dataset: If given, it will be used to create a `GeneratorStep` and put it as the
                root step. Convenient when you have already processed the dataset in
                your script and just want to pass it to the pipeline. Defaults to `None`.
            dataset_batch_size: if `dataset` is given, this will be the size of the batches
                yielded by the `GeneratorStep` created using the `dataset`. Defaults to `50`.
logging_handlers: A list of logging handlers that will be used to log the
output of the pipeline. This argument can be useful so the logging messages
can be extracted and used in a different context. Defaults to `None`.
Returns:
The `Distiset` created by the pipeline.
Raises:
RuntimeError: If the pipeline fails to load all the steps.
"""
if script_executed_in_ray_cluster():
print("Script running in Ray cluster... Using `RayPipeline`...")
return self.ray().run(
parameters=parameters,
use_cache=use_cache,
storage_parameters=storage_parameters,
use_fs_to_pass_data=use_fs_to_pass_data,
dataset=dataset,
dataset_batch_size=dataset_batch_size,
)
self._log_queue = cast("Queue[Any]", mp.Queue())
if distiset := super().run(
parameters=parameters,
load_groups=load_groups,
use_cache=use_cache,
storage_parameters=storage_parameters,
use_fs_to_pass_data=use_fs_to_pass_data,
dataset=dataset,
dataset_batch_size=dataset_batch_size,
logging_handlers=logging_handlers,
):
return distiset
num_processes = self.dag.get_total_replica_count()
with (
mp.Manager() as manager,
_NoDaemonPool(
num_processes,
initializer=_init_worker,
initargs=(
self._log_queue,
self.name,
self.signature,
),
) as pool,
):
self._manager = manager
self._pool = pool
self._output_queue = self.QueueClass()
self._load_queue = self.QueueClass()
self._handle_keyboard_interrupt()
# Run the loop for receiving the load status of each step
self._load_steps_thread = self._run_load_queue_loop_in_thread()
# Start a loop to receive the output batches from the steps
self._output_queue_thread = self._run_output_queue_loop_in_thread()
self._output_queue_thread.join()
self._teardown()
if self._exception:
raise self._exception
distiset = create_distiset(
self._cache_location["data"],
pipeline_path=self._cache_location["pipeline"],
log_filename_path=self._cache_location["log_file"],
enable_metadata=self._enable_metadata,
dag=self.dag,
)
stop_logging()
return distiset
@property
def QueueClass(self) -> Callable:
"""The callable used to create the input and output queues.
Returns:
The callable to create a `Queue`.
"""
assert self._manager, "Manager is not initialized"
return self._manager.Queue
def _run_step(self, step: "_Step", input_queue: "Queue[Any]", replica: int) -> None:
"""Runs the `Step` wrapped in a `_ProcessWrapper` in a separate process of the
`Pool`.
Args:
step: The step to run.
input_queue: The input queue to send the data to the step.
replica: The replica ID assigned.
"""
assert self._pool, "Pool is not initialized"
step_wrapper = _StepWrapper(
step=step, # type: ignore
replica=replica,
input_queue=input_queue,
output_queue=self._output_queue,
load_queue=self._load_queue,
dry_run=self._dry_run,
ray_pipeline=False,
)
self._pool.apply_async(step_wrapper.run, error_callback=self._error_callback)
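    # Note: `apply_async` returns immediately; the step keeps running in its own
    # process, and any exception raised there is surfaced through
    # `self._error_callback` rather than at this call site.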
def _error_callback(self, e: BaseException) -> None:
"""Error callback that will be called when an error occurs in a `Step` process.
Args:
e: The exception raised by the process.
"""
global _SUBPROCESS_EXCEPTION
# First we check that the exception is a `_StepWrapperException`, otherwise, we
# print it out and stop the pipeline, since some errors may be unhandled
if not isinstance(e, _StepWrapperException):
self._logger.error(f"❌ Failed with an unhandled exception: {e}")
self._stop()
return
if e.is_load_error:
self._logger.error(f"❌ Failed to load step '{e.step.name}': {e.message}")
_SUBPROCESS_EXCEPTION = e.subprocess_exception
_SUBPROCESS_EXCEPTION.__traceback__ = tblib.Traceback.from_string( # type: ignore
e.formatted_traceback
).as_traceback()
return
# If the step is global, is not in the last trophic level and has no successors,
# then we can ignore the error and continue executing the pipeline
step_name: str = e.step.name # type: ignore
if (
e.step.is_global
and not self.dag.step_in_last_trophic_level(step_name)
and list(self.dag.get_step_successors(step_name)) == []
):
self._logger.error(
f"✋ An error occurred when running global step '{step_name}' with no"
" successors and not in the last trophic level. Pipeline execution can"
f" continue. Error will be ignored."
)
self._logger.error(f"Subprocess traceback:\n\n{e.formatted_traceback}")
return
# Handle tasks using an `LLM` using offline batch generation
if isinstance(
e.subprocess_exception, DistilabelOfflineBatchGenerationNotFinishedException
):
self._logger.info(
f"⏹️ '{e.step.name}' task stopped pipeline execution: LLM offline batch"
" generation in progress. Rerun pipeline with cache to check results and"
" continue execution."
)
self._set_step_for_recovering_offline_batch_generation(e.step, e.data) # type: ignore
with self._stop_called_lock:
if not self._stop_called:
self._stop(acquire_lock=False)
return
# Global step with successors failed
self._logger.error(f"An error occurred in global step '{step_name}'")
self._logger.error(f"Subprocess traceback:\n\n{e.formatted_traceback}")
self._stop()
def _teardown(self) -> None:
"""Clean/release/stop resources reserved to run the pipeline."""
if self._write_buffer:
self._write_buffer.close()
if self._batch_manager:
self._batch_manager = None
self._stop_load_queue_loop()
self._load_steps_thread.join()
if self._pool:
self._pool.terminate()
self._pool.join()
if self._manager:
self._manager.shutdown()
self._manager.join()
def _set_steps_not_loaded_exception(self) -> None:
"""Raises a `RuntimeError` notifying that the steps load has failed.
Raises:
RuntimeError: containing the information and why a step failed to be loaded.
"""
self._exception = RuntimeError(
"Failed to load all the steps. Could not run pipeline."
)
self._exception.__cause__ = _SUBPROCESS_EXCEPTION
def _stop(self, acquire_lock: bool = True) -> None:
"""Stops the pipeline execution. It will first send `None` to the input queues
of all the steps and then wait until the output queue is empty i.e. all the steps
finished processing the batches that were sent before the stop flag. Then it will
send `None` to the output queue to notify the pipeline to stop.
Args:
acquire_lock: Whether to acquire the lock to access the `_stop_called` attribute.
"""
if acquire_lock:
self._stop_called_lock.acquire()
if self._stop_called:
self._stop_calls += 1
if self._stop_calls == 1:
self._logger.warning("🛑 Press again to force the pipeline to stop.")
elif self._stop_calls > 1:
self._logger.warning("🛑 Forcing pipeline interruption.")
if self._pool:
self._pool.terminate()
self._pool.join()
self._pool = None
if self._manager:
self._manager.shutdown()
self._manager.join()
self._manager = None
stop_logging()
sys.exit(1)
return
self._stop_called = True
if acquire_lock:
self._stop_called_lock.release()
self._logger.debug(
f"Steps loaded before calling `stop`: {self._steps_load_status}"
)
self._logger.info(
"🛑 Stopping pipeline. Waiting for steps to finish processing batches..."
)
self._stop_output_queue_loop()
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from distilabel.constants import INPUT_QUEUE_ATTR_NAME, STEP_ATTR_NAME
from distilabel.distiset import create_distiset
from distilabel.errors import DistilabelUserError
from distilabel.models.llms.vllm import vLLM
from distilabel.pipeline.base import BasePipeline, set_pipeline_running_env_variables
from distilabel.pipeline.step_wrapper import _StepWrapper
from distilabel.utils.logging import setup_logging, stop_logging
from distilabel.utils.serialization import TYPE_INFO_KEY
if TYPE_CHECKING:
import logging
from os import PathLike
from queue import Queue
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from distilabel.distiset import Distiset
from distilabel.steps.base import _Step
from distilabel.typing import InputDataset, LoadGroups
class RayPipeline(BasePipeline):
"""Ray pipeline implementation allowing to run a pipeline in a Ray cluster."""
def __init__(
self,
name: str,
description: Optional[str] = None,
cache_dir: Optional[Union[str, "PathLike"]] = None,
enable_metadata: bool = False,
requirements: Optional[List[str]] = None,
ray_head_node_url: Optional[str] = None,
ray_init_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the `RayPipeline` instance.
Args:
name: The name of the pipeline.
description: A description of the pipeline. Defaults to `None`.
cache_dir: A directory where the pipeline will be cached. Defaults to `None`.
enable_metadata: Whether to include the distilabel metadata column for the pipeline
in the final `Distiset`. It contains metadata used by distilabel, for example
the raw outputs of the `LLM` without processing would be here, inside `raw_output_...`
field. Defaults to `False`.
requirements: List of requirements that must be installed to run the Pipeline.
Defaults to `None`, but it can be helpful to inform, in a pipeline meant to be
shared, that these requirements must be installed.
ray_head_node_url: The URL that can be used to connect to the head node of
the Ray cluster. Normally, you won't want to use this argument as the
recommended way to submit a job to a Ray cluster is using the [Ray Jobs
CLI](https://docs.ray.io/en/latest/cluster/running-applications/job-submission/index.html#ray-jobs-overview).
Defaults to `None`.
ray_init_kwargs: kwargs that will be passed to the `ray.init` method. Defaults
to `None`.
"""
super().__init__(name, description, cache_dir, enable_metadata, requirements)
self._ray_head_node_url = ray_head_node_url
self._ray_init_kwargs = ray_init_kwargs or {}
self._ray_node_ids = {}
def run(
self,
parameters: Optional[Dict[str, Dict[str, Any]]] = None,
load_groups: Optional["LoadGroups"] = None,
use_cache: bool = True,
storage_parameters: Optional[Dict[str, Any]] = None,
use_fs_to_pass_data: bool = False,
dataset: Optional["InputDataset"] = None,
dataset_batch_size: int = 50,
logging_handlers: Optional[List["logging.Handler"]] = None,
) -> "Distiset":
"""Runs the pipeline in the Ray cluster.
Args:
parameters: A dictionary with the step name as the key and a dictionary with
the runtime parameters for the step as the value. Defaults to `None`.
load_groups: A list containing lists of steps that have to be loaded together
and in isolation with respect to the rest of the steps of the pipeline.
This argument also allows passing the following modes:
- "sequential_step_execution": each step will be executed in a stage i.e.
the execution of the steps will be sequential.
Defaults to `None`.
use_cache: Whether to use the cache from previous pipeline runs. Defaults to
`True`.
storage_parameters: A dictionary with the storage parameters (`fsspec` and path)
that will be used to store the data of the `_Batch`es passed between the
steps if `use_fs_to_pass_data` is `True` (for the batches received by a
`GlobalStep` it will be always used). It must have at least the "path" key,
and it can contain additional keys depending on the protocol. By default,
it will use the local file system and a directory in the cache directory.
Defaults to `None`.
use_fs_to_pass_data: Whether to use the file system to pass the data of
the `_Batch`es between the steps. Even if this parameter is `False`, the
`Batch`es received by `GlobalStep`s will always use the file system to
pass the data. Defaults to `False`.
dataset: If given, it will be used to create a `GeneratorStep` and put it as the
root step. Convenient method when you have already processed the dataset in
your script and just want to pass it already processed. Defaults to `None`.
dataset_batch_size: if `dataset` is given, this will be the size of the batches
yielded by the `GeneratorStep` created using the `dataset`. Defaults to `50`.
logging_handlers: A list of logging handlers that will be used to log the
output of the pipeline. This argument can be useful so the logging messages
can be extracted and used in a different context. Defaults to `None`.
Returns:
The `Distiset` created by the pipeline.
Raises:
RuntimeError: If the pipeline fails to load all the steps.
"""
self._check_no_llms_using_offline_batch_generation()
self._init_ray()
self._log_queue = self.QueueClass(
actor_options={"name": f"distilabel-{self.name}-log-queue"}
)
if distiset := super().run(
parameters=parameters,
load_groups=load_groups,
use_cache=use_cache,
storage_parameters=storage_parameters,
use_fs_to_pass_data=use_fs_to_pass_data,
dataset=dataset,
dataset_batch_size=dataset_batch_size,
logging_handlers=logging_handlers,
):
return distiset
self._logger.info(f"Ray nodes GPUs: {self._ray_node_ids}")
self._output_queue = self.QueueClass(
actor_options={"name": f"distilabel-{self.name}-output-queue"}
)
self._load_queue = self.QueueClass(
actor_options={"name": f"distilabel-{self.name}-load-queue"}
)
self._handle_keyboard_interrupt()
# Run the loop for receiving the load status of each step
self._load_steps_thread = self._run_load_queue_loop_in_thread()
# Start a loop to receive the output batches from the steps
self._output_queue_thread = self._run_output_queue_loop_in_thread()
self._output_queue_thread.join()
self._teardown()
if self._exception:
stop_logging()
raise self._exception
distiset = create_distiset(
self._cache_location["data"],
pipeline_path=self._cache_location["pipeline"],
log_filename_path=self._cache_location["log_file"],
enable_metadata=self._enable_metadata,
dag=self.dag,
)
stop_logging()
return distiset
def _check_no_llms_using_offline_batch_generation(self) -> None:
"""Checks if there are any `LLM` steps using the `offline_batch_generate` method
and raises an exception if so. This method is not supported in the Ray pipeline."""
for step_name in self.dag:
step: "_Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME]
if not hasattr(step, "llm"):
continue
if step.llm.use_offline_batch_generation: # type: ignore
raise DistilabelUserError(
f"Step '{step_name}' uses an `LLM` with offline batch generation because"
"`use_offline_batch_generation=True`. `LLM`s using this method are not"
" supported in the Ray pipeline.",
page="sections/how_to_guides/advanced/offline-batch-generation",
)
def _init_ray(self) -> None:
"""Inits or connects to a Ray cluster."""
try:
import ray
except ImportError as ie:
raise ImportError(
"ray is not installed. Please install it using `pip install 'distilabel[ray]'`."
) from ie
if self._ray_head_node_url:
ray.init(
self._ray_head_node_url,
runtime_env={"pip": self.requirements},
**self._ray_init_kwargs,
)
elif not ray.is_initialized():
# Init a local Ray cluster
ray.init(**self._ray_init_kwargs)
self._ray_node_ids = self._get_ray_gpus_per_node()
def _get_ray_gpus_per_node(self) -> Dict[str, int]:
"""Gets the number of GPUs per node in the Ray cluster.
Returns:
A dictionary in which the keys are the node IDs and the values the number of
GPUs per node.
"""
import ray
gpus_per_node = {}
for node in ray.nodes():
node_id = node["NodeID"]
gpus = int(node["Resources"].get("GPU", 0))
gpus_per_node[node_id] = gpus
return gpus_per_node
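# Illustrative return value (fake node IDs): {"e4b0c4...": 8, "a1f9d2...": 0}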
@property
def QueueClass(self) -> Callable:
from ray.util.queue import Queue
return Queue
def _create_step_input_queue(self, step_name: str) -> "Queue[Any]":
"""Creates an input queue for a step. Override to set actor name.
Args:
step_name: The name of the step.
Returns:
The input queue created.
"""
input_queue = self.QueueClass(
actor_options={"name": f"distilabel-{self.name}-input-queue-{step_name}"}
)
self.dag.set_step_attr(step_name, INPUT_QUEUE_ATTR_NAME, input_queue)
return input_queue
def _run_step(self, step: "_Step", input_queue: "Queue[Any]", replica: int) -> None:
"""Creates a replica of an `Step` using a Ray Actor.
Args:
step: The step to run.
input_queue: The input queue to send the data to the step.
replica: The replica ID assigned.
"""
import ray
@ray.remote
class _StepWrapperRay:
def __init__(
self,
step_wrapper: _StepWrapper,
log_queue: "Queue[Any]",
pipeline_name: str,
pipeline_cache_id: str,
) -> None:
self._step_wrapper = step_wrapper
self._log_queue = log_queue
self._pipeline_name = pipeline_name
self._pipeline_cache_id = pipeline_cache_id
def run(self) -> str:
setup_logging(log_queue=self._log_queue)
set_pipeline_running_env_variables(
self._pipeline_name, self._pipeline_cache_id
)
return self._step_wrapper.run()
resources: Dict[str, Any] = {
"name": f"distilabel-{self.name}-{step.name}-{replica}"
}
if hasattr(step, "llm") and isinstance(step.llm, vLLM): # type: ignore
resources["scheduling_strategy"] = self._create_vllm_placement_group(step)
else:
if step.resources.cpus is not None:
resources["num_cpus"] = step.resources.cpus
if step.resources.gpus is not None:
resources["num_gpus"] = step.resources.gpus
if step.resources.memory is not None:
resources["memory"] = step.resources.memory
if step.resources.resources is not None:
resources["resources"] = step.resources.resources
_StepWrapperRay = _StepWrapperRay.options(**resources) # type: ignore
self._logger.debug(
f"Creating Ray actor for '{step.name}' (replica ID: {replica}) with resources:"
f" {resources}"
)
step_wrapper = _StepWrapperRay.remote(
step_wrapper=_StepWrapper(
step=step, # type: ignore
replica=replica,
input_queue=input_queue,
output_queue=self._output_queue,
load_queue=self._load_queue,
dry_run=self._dry_run,
ray_pipeline=True,
),
log_queue=self._log_queue,
pipeline_name=self.name,
pipeline_cache_id=self.signature,
)
self._logger.debug(
f"Executing remote `run` method of Ray actor for '{step.name}' (replica ID:"
f" {replica})..."
)
step_wrapper.run.remote()
def _create_vllm_placement_group(
self, step: "_Step"
) -> "PlacementGroupSchedulingStrategy":
"""Creates a Ray placement group with as many GPU bundles as `tensor_parallel_size`
specified in the `vLLM` initialisation. The created placement group uses the `STRICT_PACK`
strategy if the `pipeline_parallel_size` is less or equal to 1, otherwise it uses
`SPREAD` (placement group with GPU bundles in several nodes). In addition, the created
placement group is targeted to be created in a specific node. This avoids having
`vLLM` raising the exception `Ray does not allocate any GPUs on the driver node...`,
as it assures that the driver `_StepWrapperRay` actor created resides in the same
node as the ray actors created by `vLLM` for the distributed inference.
Args:
step: the step which uses `vLLM`.
Returns:
A `PlacementGroupSchedulingStrategy` using the created `PlacementGroup`.
"""
import ray
llm = step.llm # type: ignore
tensor_parallel_size = llm.extra_kwargs.get("tensor_parallel_size", 1) # type: ignore
pipeline_parallel_size = llm.extra_kwargs.get("pipeline_parallel_size", 1) # type: ignore
# Calculate total GPUs needed
total_gpus_needed = tensor_parallel_size * pipeline_parallel_size
# Count available GPUs across all nodes
total_available_gpus = sum(self._ray_node_ids.values())
self._logger.info(
f"`vLLM` placement group for '{step.name}' step requires {total_gpus_needed}"
f" GPUs. Total available GPUs: {total_available_gpus}."
)
if total_available_gpus < total_gpus_needed:
raise ValueError(
f"Ray cluster does not allocate enough GPUs to create the placement group"
f" required by the `vLLM` instance of the step '{step.name}'."
f" Needed: {total_gpus_needed}, Available: {total_available_gpus}"
)
# Select the target node and update the available GPU count per node
selected_node_id = None
gpus_left_needed = total_gpus_needed
for node_id in self._ray_node_ids:
gpus_to_allocate = min(self._ray_node_ids[node_id], gpus_left_needed)
self._ray_node_ids[node_id] -= gpus_to_allocate
gpus_left_needed -= gpus_to_allocate
if gpus_left_needed == 0:
if pipeline_parallel_size == 1:
selected_node_id = node_id
break
# Create a placement group
pg = ray.util.placement_group(
# Create `total_gpus_needed` GPU bundles and at least one CPU bundle
# so the actors can be scheduled and executed (1 CPU bundle can have infinite actors):
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html#schedule-tasks-and-actors-to-placement-groups-use-reserved-resources
bundles=[{"CPU": 1.0}] + [{"GPU": 1.0}] * total_gpus_needed,
strategy="SPREAD" if pipeline_parallel_size > 1 else "STRICT_PACK",
_soft_target_node_id=selected_node_id,
)
self._logger.info(
f"Step '{step.name}' uses `vLLM`. Created a Ray placement group with bundle"
f" specs: {pg.bundle_specs}"
)
return ray.util.scheduling_strategies.PlacementGroupSchedulingStrategy( # type: ignore
placement_group=pg,
)
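# Worked example (illustrative numbers): with `extra_kwargs={"tensor_parallel_size": 4,
# "pipeline_parallel_size": 2}`, 4 * 2 = 8 GPU bundles are requested. As
# `pipeline_parallel_size > 1`, the placement group is created with the `SPREAD`
# strategy and no target node; with `pipeline_parallel_size=1` it would use
# `STRICT_PACK` targeting the node where the greedy allocation completed.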
def _teardown(self) -> None:
"""Clean/release/stop resources reserved to run the pipeline."""
if self._write_buffer:
self._write_buffer.close()
if self._batch_manager:
self._batch_manager = None
self._stop_load_queue_loop()
self._load_steps_thread.join()
def _set_steps_not_loaded_exception(self) -> None:
pass
def _stop(self) -> None:
"""Stops the pipeline execution. It will first send `None` to the input queues
of all the steps and then wait until the output queue is empty i.e. all the steps
finished processing the batches that were sent before the stop flag. Then it will
send `None` to the output queue to notify the pipeline to stop."""
with self._stop_called_lock:
if self._stop_called:
self._stop_calls += 1
if self._stop_calls == 1:
self._logger.warning(
"🛑 Press again to force the pipeline to stop."
)
elif self._stop_calls > 1:
self._logger.warning("🛑 Forcing pipeline interruption.")
stop_logging()
sys.exit(1)
return
self._stop_called = True
self._logger.debug(
f"Steps loaded before calling `stop`: {self._steps_load_status}"
)
self._logger.info(
"🛑 Stopping pipeline. Waiting for steps to finish processing batches..."
)
self._stop_output_queue_loop()
def dump(self, **kwargs: Any) -> Dict[str, Any]:
"""Dumps the pipeline information. Override to hardcode the type info to `Pipeline`,
as we don't want to create a `RayPipeline` directly but create it using `Pipeline.ray`
method.
Returns:
The pipeline dump.
"""
from distilabel.pipeline import Pipeline
dict_ = super().dump()
dict_["pipeline"][TYPE_INFO_KEY] = {
"module": Pipeline.__module__,
"name": Pipeline.__name__,
}
return dict_
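if __name__ == "__main__":
    # A minimal creation sketch (the repo id is an illustrative assumption): per the
    # `dump` override above, a `RayPipeline` is meant to be obtained through
    # `Pipeline.ray` rather than instantiated directly, so serialized pipelines
    # always deserialize as `Pipeline`.
    from distilabel.pipeline import Pipeline
    from distilabel.steps import LoadDataFromHub

    with Pipeline(name="ray-example") as pipeline:
        LoadDataFromHub(repo_id="my-org/my-dataset")

    ray_pipeline = pipeline.ray()  # returns a `RayPipeline` over the same DAG
    print(type(ray_pipeline).__name__)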
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import random
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from pydantic import BaseModel, PrivateAttr
from typing_extensions import Self
from distilabel.errors import DistilabelUserError
from distilabel.utils.serialization import (
TYPE_INFO_KEY,
_get_module_attr,
_Serializable,
)
if TYPE_CHECKING:
from distilabel.pipeline.batch import _Batch
from distilabel.steps.base import _Step
from distilabel.typing import DownstreamConnectableSteps
RoutingBatchFunc = Callable[[List[str]], List[str]]
"""Type alias for a routing batch function. It takes a list of all the downstream steps and
returns a list with the names of the steps that should receive the batch."""
class RoutingBatchFunction(BaseModel, _Serializable):
"""A thin wrapper around a routing batch function that can be used to route batches
from one upstream step to specific downstream steps.
Attributes:
routing_function: The routing function that takes a list of all the downstream steps
and returns a list with the names of the steps that should receive the batch.
description: An optional description of the routing batch function.
_step: The upstream step that is connected to the routing batch function.
_routed_batch_registry: A dictionary that keeps track of the batches that have been
routed to specific downstream steps.
"""
routing_function: RoutingBatchFunc
description: Optional[str] = None
_step: Union["_Step", None] = PrivateAttr(default=None)
_routed_batch_registry: Dict[str, Dict[int, List[str]]] = PrivateAttr(
default_factory=dict
)
_factory_function_module: Union[str, None] = PrivateAttr(default=None)
_factory_function_name: Union[str, None] = PrivateAttr(default=None)
_factory_function_kwargs: Union[Dict[str, Any], None] = PrivateAttr(default=None)
def route_batch(self, batch: "_Batch", steps: List[str]) -> List[str]:
"""Returns a list of selected downstream steps from `steps` to which the `batch`
should be routed.
Args:
batch: The batch that should be routed.
steps: A list of all the downstream steps that can receive the batch.
Returns:
A list with the names of the steps that should receive the batch.
"""
routed_steps = self.routing_function(steps)
self._register_routed_batch(batch, routed_steps)
return routed_steps
def set_factory_function(
self,
factory_function_module: str,
factory_function_name: str,
factory_function_kwargs: Dict[str, Any],
) -> None:
"""Sets the factory function that was used to create the `routing_batch_function`.
Args:
factory_function_module: The module name where the factory function is defined.
factory_function_name: The name of the factory function that was used to create
the `routing_batch_function`.
factory_function_kwargs: The keyword arguments that were used when calling the
factory function.
"""
self._factory_function_module = factory_function_module
self._factory_function_name = factory_function_name
self._factory_function_kwargs = factory_function_kwargs
def __call__(self, batch: "_Batch", steps: List[str]) -> List[str]:
"""Returns a list of selected downstream steps from `steps` to which the `batch`
should be routed.
Args:
batch: The batch that should be routed.
steps: A list of all the downstream steps that can receive the batch.
Returns:
A list with the names of the steps that should receive the batch.
"""
return self.route_batch(batch, steps)
def _register_routed_batch(self, batch: "_Batch", routed_steps: List[str]) -> None:
"""Registers a batch that has been routed to specific downstream steps.
Args:
batch: The batch that has been routed.
routed_steps: The list of downstream steps that have been selected to receive
the batch.
"""
upstream_step = batch.step_name
batch_seq_no = batch.seq_no
self._routed_batch_registry.setdefault(upstream_step, {}).setdefault(
batch_seq_no, routed_steps
)
def __rshift__(
self, other: List["DownstreamConnectableSteps"]
) -> List["DownstreamConnectableSteps"]:
"""Connects a list of dowstream steps to the upstream step of the routing batch
function.
Args:
other: A list of downstream steps that should be connected to the upstream step
of the routing batch function.
Returns:
The list of downstream steps that have been connected to the upstream step of the
routing batch function.
"""
if not isinstance(other, list):
raise DistilabelUserError(
f"Can only set a `routing_batch_function` for a list of steps. Got: {other}."
" Please, review the right-hand side of the `routing_batch_function >> other`"
" expression. It should be"
" `upstream_step >> routing_batch_function >> [downstream_step_1, dowstream_step_2, ...]`.",
page="sections/how_to_guides/basic/pipeline/?h=routing#routing-batches-to-specific-downstream-steps",
)
if not self._step:
raise DistilabelUserError(
"Routing batch function doesn't have an upstream step. Cannot connect downstream"
" steps before connecting the upstream step. Connect this routing batch"
" function to an upstream step using the `>>` operator. For example:"
" `upstream_step >> routing_batch_function >> [downstream_step_1, downstream_step_2, ...]`.",
page="sections/how_to_guides/basic/pipeline/?h=routing#routing-batches-to-specific-downstream-steps",
)
for step in other:
self._step.connect(step)
return other
def dump(self, **kwargs: Any) -> Dict[str, Any]:
"""Dumps the routing batch function to a dictionary, and the information of the
factory function used to create this routing batch function.
Args:
**kwargs: Additional keyword arguments that should be included in the dump.
Returns:
A dictionary with the routing batch function information and the factory function
information.
"""
dump_info: Dict[str, Any] = {"step": self._step.name} # type: ignore
if self.description:
dump_info["description"] = self.description
if type_info := self._get_type_info():
dump_info[TYPE_INFO_KEY] = type_info
return dump_info
def _get_type_info(self) -> Dict[str, Any]:
"""Returns the information of the factory function used to create the routing batch
function.
Returns:
A dictionary with the factory function information.
"""
type_info = {}
if self._factory_function_module:
type_info["module"] = self._factory_function_module
if self._factory_function_name:
type_info["name"] = self._factory_function_name
if self._factory_function_kwargs:
type_info["kwargs"] = self._factory_function_kwargs
return type_info
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Self:
"""Loads a routing batch function from a dictionary. It must contain the information
of the factory function used to create the routing batch function.
Args:
data: A dictionary with the routing batch function information and the factory
function information.
"""
type_info = data.get(TYPE_INFO_KEY)
if not type_info:
step = data.get("step")
raise ValueError(
f"The routing batch function for step '{step}' was created without a factory"
" function, and it cannot be reconstructed."
)
module = type_info.get("module")
name = type_info.get("name")
kwargs = type_info.get("kwargs")
if not module or not name or not kwargs:
raise ValueError(
"The routing batch function was created with a factory function, but the"
" information is incomplete. Cannot reconstruct the routing batch function."
)
routing_batch_function = _get_module_attr(module=module, name=name)(**kwargs)
routing_batch_function.description = data.get("description")
routing_batch_function.set_factory_function(
factory_function_module=module,
factory_function_name=name,
factory_function_kwargs=kwargs,
)
return routing_batch_function
def routing_batch_function(
description: Optional[str] = None,
) -> Callable[[RoutingBatchFunc], RoutingBatchFunction]:
"""Creates a routing batch function that can be used to route batches from one upstream
step to specific downstream steps.
Args:
description: An optional description for the routing batch function.
Returns:
A `RoutingBatchFunction` instance that can be used with the `>>` operators and with
the `Pipeline.connect` method when defining the pipeline.
Example:
```python
import random
from typing import List

from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, routing_batch_function
from distilabel.steps import LoadDataFromHub, GroupColumns
from distilabel.steps.tasks import TextGeneration

@routing_batch_function()
def random_routing_batch(steps: List[str]) -> List[str]:
return random.sample(steps, 2)
with Pipeline(name="routing-batch-function") as pipeline:
load_data = LoadDataFromHub()
generations = []
for llm in (
OpenAILLM(model="gpt-4-0125-preview"),
MistralLLM(model="mistral-large-2402"),
VertexAILLM(model="gemini-1.5-pro"),
):
task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm)
generations.append(task)
combine_columns = GroupColumns(columns=["generation", "model_name"])
load_data >> random_routing_batch >> generations >> combine_columns
```
"""
def decorator(func: RoutingBatchFunc) -> RoutingBatchFunction:
factory_function_name, factory_function_module, factory_function_kwargs = (
None,
None,
None,
)
# Check if `routing_batch_function` was created using a factory function from an installed package
stack = inspect.stack()
if len(stack) > 2:
factory_function_frame_info = stack[1]
# Function factory path
if factory_function_frame_info.function != "<module>":
factory_function_name = factory_function_frame_info.function
factory_function_module = inspect.getmodule(
factory_function_frame_info.frame
).__name__ # type: ignore
# Function factory kwargs
factory_function_kwargs = factory_function_frame_info.frame.f_locals
routing_batch_function = RoutingBatchFunction(
routing_function=func,
description=description,
)
if (
factory_function_module
and factory_function_name
and factory_function_kwargs
):
routing_batch_function.set_factory_function(
factory_function_module=factory_function_module,
factory_function_name=factory_function_name,
factory_function_kwargs=factory_function_kwargs,
)
return routing_batch_function
return decorator
def sample_n_steps(n: int) -> RoutingBatchFunction:
"""A simple function that creates a routing batch function that samples `n` steps from
the list of all the downstream steps.
Args:
n: The number of steps to sample from the list of all the downstream steps.
Returns:
A `RoutingBatchFunction` instance that can be used with the `>>` operators and with
the `Pipeline.connect` method when defining the pipeline.
Example:
```python
from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, sample_n_steps
from distilabel.steps import LoadDataFromHub, GroupColumns
from distilabel.steps.tasks import TextGeneration
random_routing_batch = sample_n_steps(2)
with Pipeline(name="routing-batch-function") as pipeline:
load_data = LoadDataFromHub()
generations = []
for llm in (
OpenAILLM(model="gpt-4-0125-preview"),
MistralLLM(model="mistral-large-2402"),
VertexAILLM(model="gemini-1.5-pro"),
):
task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm)
generations.append(task)
combine_columns = GroupColumns(columns=["generation", "model_name"])
load_data >> random_routing_batch >> generations >> combine_columns
```
"""
@routing_batch_function(
description=f"Sample {n} steps from the list of downstream steps."
)
def sample_n(steps: List[str]) -> List[str]:
return random.sample(steps, n)
return sample_n
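if __name__ == "__main__":
    # Round-trip sketch (the "load_data" step name is illustrative): routing batch
    # functions created through a factory such as `sample_n_steps` record their
    # factory module/name/kwargs, so they can be rebuilt from a pipeline dump.
    routing = RoutingBatchFunction.from_dict(
        {
            "step": "load_data",
            TYPE_INFO_KEY: {
                "module": "distilabel.pipeline.routing_batch_function",
                "name": "sample_n_steps",
                "kwargs": {"n": 2},
            },
        }
    )
    print(routing.routing_function(["step_a", "step_b", "step_c"]))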
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from queue import Queue
from typing import Any, Dict, List, Optional, Union, cast
from distilabel.constants import LAST_BATCH_SENT_FLAG
from distilabel.errors import DISTILABEL_DOCS_URL
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.pipeline.batch import _Batch
from distilabel.steps.base import GeneratorStep, Step, _Step
from distilabel.typing import StepLoadStatus
class _StepWrapper:
"""Wrapper to run the `Step`.
Attributes:
step: The step to run.
replica: The replica ID assigned.
input_queue: The queue to receive the input data.
output_queue: The queue to send the output data.
load_queue: The queue used to notify the main process that the step has been loaded,
has been unloaded or has failed to load.
"""
def __init__(
self,
step: Union["Step", "GeneratorStep"],
replica: int,
input_queue: "Queue[_Batch]",
output_queue: "Queue[_Batch]",
load_queue: "Queue[Union[StepLoadStatus, None]]",
dry_run: bool = False,
ray_pipeline: bool = False,
) -> None:
"""Initializes the `_ProcessWrapper`.
Args:
step: The step to run.
input_queue: The queue to receive the input data.
output_queue: The queue to send the output data.
load_queue: The queue used to notify the main process that the step has been
loaded, has been unloaded or has failed to load.
dry_run: Flag to force every generated batch to be marked as the last one, so
only a single batch gets processed. Defaults to `False`.
ray_pipeline: Whether the step is running a `RayPipeline` or not.
"""
self.step = step
self.replica = replica
self.input_queue = input_queue
self.output_queue = output_queue
self.load_queue = load_queue
self.dry_run = dry_run
self.ray_pipeline = ray_pipeline
self._init_cuda_device_placement()
def _init_cuda_device_placement(self) -> None:
"""Sets the LLM identifier and the number of desired GPUs of the `CudaDevicePlacementMixin`"""
def _init_cuda_device_placement_mixin(attr: CudaDevicePlacementMixin) -> None:
if self.ray_pipeline:
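# Assumed rationale: in a Ray pipeline, GPU assignment is handled by Ray
# itself, so the mixin's manual CUDA device placement is disabled.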
attr.disable_cuda_device_placement = True
else:
desired_num_gpus = self.step.resources.gpus or 1
attr._llm_identifier = f"{self.step.name}-replica-{self.replica}"
attr._desired_num_gpus = desired_num_gpus
for field_name in self.step.model_fields_set:
attr = getattr(self.step, field_name)
if isinstance(attr, CudaDevicePlacementMixin):
_init_cuda_device_placement_mixin(attr)
if isinstance(self.step, CudaDevicePlacementMixin):
_init_cuda_device_placement_mixin(self.step)
def run(self) -> str:
"""The target function executed by the process. This function will also handle
the step lifecycle, executing first the `load` function of the `Step` and then
waiting to receive a batch from the `input_queue` that will be handled by the
`process` method of the `Step`.
Returns:
The name of the step that was executed.
"""
try:
self.step.load()
self.step._logger.debug(f"Step '{self.step.name}' loaded!")
except Exception as e:
self.step.unload()
self._notify_load_failed()
raise _StepWrapperException.create_load_error(
message=f"Step load failed: {e}",
step=self.step,
subprocess_exception=e,
) from e
self._notify_load()
if self.step.is_generator:
self._generator_step_process_loop()
else:
self._non_generator_process_loop()
# Just in case `None` sentinel was sent
# try:
# self.input_queue.get(block=False)
# except Exception:
# pass
self.step.unload()
self._notify_unload()
self.step._logger.info(
f"🏁 Finished running step '{self.step.name}' (replica ID: {self.replica})"
)
return self.step.name # type: ignore
def _notify_load(self) -> None:
"""Notifies that the step has finished executing its `load` function successfully."""
self.step._logger.debug(
f"Notifying load of step '{self.step.name}' (replica ID {self.replica})..."
)
self.load_queue.put({"name": self.step.name, "status": "loaded"}) # type: ignore
def _notify_unload(self) -> None:
"""Notifies that the step has been unloaded."""
self.step._logger.debug(
f"Notifying unload of step '{self.step.name}' (replica ID {self.replica})..."
)
self.load_queue.put({"name": self.step.name, "status": "unloaded"}) # type: ignore
def _notify_load_failed(self) -> None:
"""Notifies that the step failed to load."""
self.step._logger.debug(
f"Notifying load failed of step '{self.step.name}' (replica ID {self.replica})..."
)
self.load_queue.put({"name": self.step.name, "status": "load_failed"}) # type: ignore
def _generator_step_process_loop(self) -> None:
"""Runs the process loop for a generator step. It will call the `process` method
of the step and send the output data to the `output_queue` and block until the next
batch request is received (i.e. receiving an empty batch from the `input_queue`).
If the `last_batch` attribute of the batch is `True`, the loop will stop and the
process will finish.
Raises:
_StepWrapperException: If an error occurs during the execution of the
`process` method.
"""
step = cast("GeneratorStep", self.step)
try:
if (batch := self.input_queue.get()) is None:
self.step._logger.info(
f"🛑 Stopping yielding batches from step '{self.step.name}'"
)
return
offset = batch.seq_no * step.batch_size # type: ignore
self.step._logger.info(
f"🚰 Starting yielding batches from generator step '{self.step.name}'."
f" Offset: {offset}"
)
for data, last_batch in step.process_applying_mappings(offset=offset):
batch.set_data([data])
batch.last_batch = self.dry_run or last_batch
self._send_batch(batch)
if batch.last_batch:
return
self.step._logger.debug(
f"Step '{self.step.name}' waiting for next batch request..."
)
if (batch := self.input_queue.get()) is None:
self.step._logger.info(
f"🛑 Stopping yielding batches from step '{self.step.name}'"
)
return
except Exception as e:
raise _StepWrapperException(str(e), self.step, 2, e) from e
def _non_generator_process_loop(self) -> None:
"""Runs the process loop for a non-generator step. It will call the `process`
method of the step and send the output data to the `output_queue` and block until
the next batch is received from the `input_queue`. If the `last_batch` attribute
of the batch is `True`, the loop will stop and the process will finish.
If an error occurs during the execution of the `process` method and the step is
global, the process will raise a `_StepWrapperException`. If the step is not
global, the process will log the error and send an empty batch to the `output_queue`.
Raises:
_StepWrapperException: If an error occurs during the execution of the
`process` method and the step is global.
"""
step = cast("Step", self.step)
while True:
if (batch := self.input_queue.get()) is None:
self.step._logger.info(
f"🛑 Stopping processing batches from step '{self.step.name}' (replica"
f" ID: {self.replica})"
)
break
if batch == LAST_BATCH_SENT_FLAG:
self.step._logger.debug("Received `LAST_BATCH_SENT_FLAG`. Stopping...")
break
self.step._logger.info(
f"📦 Processing batch {batch.seq_no} in '{batch.step_name}' (replica ID: {self.replica})"
)
if batch.data_path is not None:
self.step._logger.debug(f"Reading batch data from '{batch.data_path}'")
batch.read_batch_data_from_fs()
result = []
try:
if self.step.has_multiple_inputs:
result = next(step.process_applying_mappings(*batch.data))
else:
result = next(step.process_applying_mappings(batch.data[0]))
except Exception as e:
if self.step.is_global:
self.step.unload()
self._notify_unload()
data = (
batch.data
if isinstance(
e, DistilabelOfflineBatchGenerationNotFinishedException
)
else None
)
raise _StepWrapperException(str(e), self.step, 2, e, data) from e
# Impute step outputs columns with `None`
result = self._impute_step_outputs(batch)
# if the step is not global then we can skip the batch which means sending
# an empty batch to the output queue
self.step._logger.warning(
f"⚠️ Processing batch {batch.seq_no} with step '{self.step.name}' failed."
" Sending empty batch filled with `None`s..."
)
self.step._logger.warning(
f"Subprocess traceback:\n\n{traceback.format_exc()}"
)
finally:
batch.set_data([result])
self._send_batch(batch)
if batch.last_batch:
break
def _impute_step_outputs(self, batch: "_Batch") -> List[Dict[str, Any]]:
"""Imputes the step outputs columns with `None` in the batch data.
Args:
batch: The batch to impute.
"""
return self.step.impute_step_outputs(batch.data[0])
def _send_batch(self, batch: _Batch) -> None:
"""Sends a batch to the `output_queue`."""
if batch.data_path is not None:
self.step._logger.debug(f"Writing batch data to '{batch.data_path}'")
batch.write_batch_data_to_fs()
self.step._logger.info(
f"📨 Step '{batch.step_name}' sending batch {batch.seq_no} to output queue"
)
self.output_queue.put(batch)
class _StepWrapperException(Exception):
"""Exception to be raised when an error occurs in the `_StepWrapper` class.
Attributes:
message: The error message.
step: The `Step` that raised the error.
code: The error code.
subprocess_exception: The exception raised by the subprocess.
data: The data that caused the error. Defaults to `None`.
"""
def __init__(
self,
message: str,
step: "_Step",
code: int,
subprocess_exception: Exception,
data: Optional[List[List[Dict[str, Any]]]] = None,
) -> None:
self.message = f"{message}\n\nFor further information visit '{DISTILABEL_DOCS_URL}api/pipeline/step_wrapper'"
self.step = step
self.code = code
self.subprocess_exception = subprocess_exception
self.formatted_traceback = "".join(
traceback.format_exception(
type(subprocess_exception),
subprocess_exception,
subprocess_exception.__traceback__,
)
)
self.data = data
@classmethod
def create_load_error(
cls,
message: str,
step: "_Step",
subprocess_exception: Optional[Exception] = None,
) -> "_StepWrapperException":
"""Creates a `_StepWrapperException` for a load error.
Args:
message: The error message.
step: The `Step` that raised the error.
subprocess_exception: The exception raised by the subprocess. Defaults to `None`.
Returns:
The `_StepWrapperException` instance.
"""
return cls(message, step, 1, subprocess_exception, None)
@property
def is_load_error(self) -> bool:
"""Whether the error is a load error.
Returns:
`True` if the error is a load error, `False` otherwise.
"""
return self.code == 1
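if __name__ == "__main__":
    # Handling sketch (the step object is a hypothetical stand-in): load errors are
    # built with code `1` via `create_load_error`, so callers such as the pipeline
    # error callback can branch on `is_load_error`.
    class _DummyStep:
        name = "dummy_step"

    try:
        raise _StepWrapperException.create_load_error(
            message="Step load failed: boom",
            step=_DummyStep(),  # type: ignore
            subprocess_exception=RuntimeError("boom"),
        )
    except _StepWrapperException as e:
        print(e.is_load_error, e.message)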
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .instruction import InstructionResponsePipeline # noqa: F401
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from distilabel.distiset import Distiset
from distilabel.llms import LLM, InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import MagpieGenerator
MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
class InstructionResponsePipeline:
"""Generates instructions and responses for a given system prompt.
This example pipeline can be used to build a Supervised Fine-Tuning dataset which you
could use to train or evaluate a model. The pipeline generates instructions and
responses for a given system prompt using the `MagpieGenerator`, and then keeps only
the instruction, response, and model_name columns.
References:
- [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464)
Example:
Generate instructions and responses for a given system prompt:
```python
from distilabel.pipeline import InstructionResponsePipeline
pipeline = InstructionResponsePipeline()
distiset = pipeline.run()
```
Customizing the pipeline further:
```python
from distilabel.pipeline import InstructionResponsePipeline
pipeline = InstructionResponsePipeline(
system_prompt="You are a creative AI Assistant for writing science fiction.",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.2-3B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.2-3B-Instruct",
generation_kwargs={"max_new_tokens": 512, "temperature": 0.7},
),
num_rows=500,
batch_size=2,
n_turns=2,
)
```
"""
def __init__(
self,
llm: Optional[LLM] = None,
system_prompt: str = "You are a creative AI Assistant writer.",
hf_token: Optional[str] = None,
n_turns: int = 1,
num_rows: int = 10,
batch_size: int = 1,
) -> None:
if llm is None:
self.llm: LLM = InferenceEndpointsLLM(
model_id=MODEL,
tokenizer_id=MODEL,
magpie_pre_query_template="llama3",
generation_kwargs={
"temperature": 0.9,
"do_sample": True,
"max_new_tokens": 2048,
"stop_sequences": [
"<|eot_id|>",
"<|start_header_id|>",
"assistant",
" \n\n",
],
},
api_key=hf_token,
)
else:
self.llm = llm
self.pipeline: Pipeline = self._get_magpie_pipeline(
system_prompt=system_prompt,
n_turns=n_turns,
num_rows=num_rows,
batch_size=batch_size,
)
def run(self, **kwargs) -> Distiset:
"""Runs the pipeline and returns a Distiset."""
return self.pipeline.run(**kwargs)
def _get_magpie_pipeline(
self, system_prompt: str, n_turns: int, num_rows: int, batch_size: int
) -> Pipeline:
"""Returns a pipeline that generates instructions and responses for a given system prompt."""
with Pipeline(name="sft") as pipeline:
MagpieGenerator(
llm=self.llm,
n_turns=n_turns,
num_rows=num_rows,
batch_size=batch_size,
system_prompt=system_prompt,
)
return pipeline
def _get_output_columns(self, n_turns: int) -> list:
"""Returns the output mappings for the pipeline."""
if n_turns == 1:
return ["instruction", "response", "model_name"]
else:
return ["instruction", "conversation", "model_name"]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
import pyarrow as pa
import pyarrow.parquet as pq
from distilabel.pipeline.batch import _Batch
from distilabel.utils.dicts import flatten_dict
from distilabel.utils.files import list_files_in_dir
class _WriteBuffer:
"""Class in charge of sending the batched contents to a buffer and writing
those to files under a given folder.
As batches are received, they are added to the buffer and once each buffer
is full, the content is written to a parquet file.
"""
def __init__(
self,
path: "PathLike",
leaf_steps: Set[str],
steps_cached: Optional[Dict[str, bool]] = None,
) -> None:
"""
Args:
path: Folder where the files will be written; the idea is for this path to
be inside the cache folder, under `/data`.
leaf_steps: Leaf steps of the DAG of the `Pipeline`.
steps_cached: Dictionary mapping each step name to its `use_cache` flag. It
is used to determine whether a previously written parquet table has to be
read and concatenated before saving the cached datasets.
Raises:
ValueError: If the path is not a directory.
"""
self._path = Path(path)
if not self._path.exists():
self._path.mkdir(parents=True, exist_ok=True)
for step in leaf_steps:
(self._path / step).mkdir(parents=True, exist_ok=True)
if not self._path.is_dir():
raise ValueError(f"The path should be a directory, not a file: {path}")
self._buffers: Dict[str, List[Dict[str, Any]]] = {
step: [] for step in leaf_steps
}
# TODO: make this configurable
self._buffers_dump_batch_size: Dict[str, int] = {
step: 50 for step in leaf_steps
}
self._buffer_last_schema = {}
self._buffers_last_file: Dict[str, int] = {step: 1 for step in leaf_steps}
self._steps_cached = steps_cached or {}
self._logger = logging.getLogger("distilabel.write_buffer")
def _get_filename(self, step_name: str) -> Path:
"""Creates the filename for the step.
Args:
step_name: Name of the step to which the data belongs to.
Returns:
Filename for the step.
"""
return self._path / f"{step_name}.parquet"
def is_full(self, step_name: str) -> bool:
"""Checks the buffers that are full so that those can be written to the file.
Returns:
Whether the buffer is full.
"""
return len(self._buffers[step_name]) >= self._buffers_dump_batch_size[step_name]
def add_batch(self, batch: "_Batch") -> None:
"""Adds a batch to the buffer and writes the buffer to the file if it's full.
Args:
batch: batch to add to the buffer.
"""
step_name = batch.step_name
data = batch.data[0]
self._buffers[step_name].extend(data)
self._logger.debug(
f"Added batch to write buffer for step '{step_name}' with {len(data)} rows."
)
if self.is_full(step_name):
self._logger.debug(
f"Buffer for step '{step_name}' is full (rows: {len(self._buffers[step_name])},"
f" full: {self._buffers_dump_batch_size[step_name]}), writing to file..."
)
self._write(step_name)
def _write(self, step_name: str) -> None:
"""Writes the content to the file and cleans the buffer.
Args:
step_name (str): Name of the step to which the data pertains.
"""
step_parquet_dir = Path(self._path, step_name)
if not step_parquet_dir.exists():
self._logger.debug(
f"Creating directory for step '{step_name}' parquet files..."
)
step_parquet_dir.mkdir()
try:
table = pa.Table.from_pylist(self._buffers[step_name])
except pa.lib.ArrowInvalid as pae:
if (
repr(pae)
!= "ArrowInvalid('cannot mix struct and non-struct, non-null values')"
):
raise pae
flattened_buffers = [flatten_dict(buf) for buf in self._buffers[step_name]]
table = pa.Table.from_pylist(flattened_buffers)
last_schema = self._buffer_last_schema.get(step_name)
if last_schema is None:
self._buffer_last_schema[step_name] = table.schema
else:
if not last_schema.equals(table.schema):
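# Two reconciliation paths: same columns in a different order -> reorder to
# match the last schema; different columns -> unify both schemas, remember
# the result, and cast the table so old and new files agree.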
if set(last_schema.names) == set(table.schema.names):
table = table.select(last_schema.names)
else:
new_schema = pa.unify_schemas([last_schema, table.schema])
self._buffer_last_schema[step_name] = new_schema
table = table.cast(new_schema)
next_file_number = self._buffers_last_file[step_name]
self._buffers_last_file[step_name] = next_file_number + 1
parquet_file = step_parquet_dir / f"{str(next_file_number).zfill(5)}.parquet"
if parquet_file.exists():
# If the file already exists, due to some error in a pipeline that was cached
prev_table = pq.read_table(parquet_file)
# If some columns differ, it means one of the steps changed, so we won't load the previous table
# NOTE: If any step has use_cache=False, we cannot assume the previous parquet file is
# valid, so we will overwrite the previous parquet file. Is this the best option?
use_cache = False not in self._steps_cached.values()
if prev_table.column_names == table.column_names and use_cache:
table = pa.concat_tables([prev_table, table])
pq.write_table(table, parquet_file)
self._logger.debug(f"Written to file '{parquet_file}'")
self._clean_buffer(step_name)
def _clean_buffer(self, step_name: str) -> None:
"""Cleans the buffer by setting it's content to `None`.
Args:
step_name: The name of the buffer to clean.
"""
self._buffers[step_name] = []
def close(self) -> None:
"""Closes the buffer by writing the remaining content to the file."""
for step_name in self._buffers:
if self._buffers[step_name]:
self._write(step_name)
# We need to read the parquet files and write them again to ensure the schema
# is correct. Otherwise, the first parquets won't have the last schema and
# then we will have issues when reading them.
for file in list_files_in_dir(self._path / step_name):
if step_name in self._buffer_last_schema:
table = pq.read_table(
file, schema=self._buffer_last_schema[step_name]
)
pq.write_table(table, file)
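if __name__ == "__main__":
    # Internal usage sketch (assumes the `_Batch` fields defined in
    # `distilabel.pipeline.batch`; "leaf_step" is an illustrative name): rows are
    # buffered per leaf step and flushed to numbered parquet files on `close`.
    buffer = _WriteBuffer(path="./data", leaf_steps={"leaf_step"})
    batch = _Batch(
        seq_no=0,
        step_name="leaf_step",
        last_batch=True,
        data=[[{"instruction": "hi", "response": "hello"}]],
    )
    buffer.add_batch(batch)
    buffer.close()  # writes ./data/leaf_step/00001.parquet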
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.steps.argilla.preference import PreferenceToArgilla
from distilabel.steps.argilla.text_generation import TextGenerationToArgilla
from distilabel.steps.base import (
GeneratorStep,
GlobalStep,
Step,
StepInput,
StepResources,
)
from distilabel.steps.clustering.dbscan import DBSCAN
from distilabel.steps.clustering.text_clustering import TextClustering
from distilabel.steps.clustering.umap import UMAP
from distilabel.steps.columns.combine import CombineOutputs
from distilabel.steps.columns.expand import ExpandColumns
from distilabel.steps.columns.group import GroupColumns
from distilabel.steps.columns.keep import KeepColumns
from distilabel.steps.columns.merge import MergeColumns
from distilabel.steps.decorator import step
from distilabel.steps.deita import DeitaFiltering
from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration
from distilabel.steps.embeddings.nearest_neighbour import FaissNearestNeighbour
from distilabel.steps.filtering.embedding import EmbeddingDedup
from distilabel.steps.filtering.minhash import MinHashDedup
from distilabel.steps.formatting.conversation import ConversationTemplate
from distilabel.steps.formatting.dpo import (
FormatChatGenerationDPO,
FormatTextGenerationDPO,
)
from distilabel.steps.formatting.sft import (
FormatChatGenerationSFT,
FormatTextGenerationSFT,
)
from distilabel.steps.generators.data import LoadDataFromDicts
from distilabel.steps.generators.data_sampler import DataSampler
from distilabel.steps.generators.huggingface import (
LoadDataFromDisk,
LoadDataFromFileSystem,
LoadDataFromHub,
)
from distilabel.steps.generators.utils import make_generator_step
from distilabel.steps.globals.huggingface import PushToHub
from distilabel.steps.reward_model import RewardModelScore
from distilabel.steps.truncate import TruncateTextColumn
from distilabel.typing import GeneratorStepOutput, StepOutput
__all__ = [
"DBSCAN",
"UMAP",
"CombineOutputs",
"ConversationTemplate",
"DataSampler",
"DeitaFiltering",
"EmbeddingDedup",
"EmbeddingGeneration",
"ExpandColumns",
"FaissNearestNeighbour",
"FormatChatGenerationDPO",
"FormatChatGenerationSFT",
"FormatTextGenerationDPO",
"FormatTextGenerationSFT",
"GeneratorStep",
"GeneratorStepOutput",
"GlobalStep",
"GroupColumns",
"KeepColumns",
"LoadDataFromDicts",
"LoadDataFromDisk",
"LoadDataFromFileSystem",
"LoadDataFromHub",
"MergeColumns",
"MinHashDedup",
"PreferenceToArgilla",
"PushToHub",
"RewardModelScore",
"Step",
"StepInput",
"StepOutput",
"StepResources",
"TextClustering",
"TextGenerationToArgilla",
"TruncateTextColumn",
"make_generator_step",
"step",
]
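if __name__ == "__main__":
    # Usage sketch of the re-exported `step` decorator (column names are
    # illustrative): defines a custom `Step` and wires it inside a `Pipeline`.
    from distilabel.pipeline import Pipeline

    @step(inputs=["instruction"], outputs=["shouted_instruction"])
    def ShoutInstruction(inputs: StepInput) -> StepOutput:
        for row in inputs:
            row["shouted_instruction"] = row["instruction"].upper()
        yield inputs

    with Pipeline(name="shout-example") as pipeline:
        load = LoadDataFromDicts(data=[{"instruction": "hi there"}])
        load >> ShoutInstruction()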
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.