Commit 4d4d8f59 authored by chenzk

v1.0
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
from pydantic import Field, PrivateAttr
from typing_extensions import override
from distilabel.constants import DISTILABEL_METADATA_KEY
from distilabel.errors import DistilabelUserError
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.image_generation.base import ImageGenerationModel
from distilabel.models.llms.base import LLM
from distilabel.steps.base import (
GeneratorStep,
GlobalStep,
Step,
StepInput,
_Step,
)
from distilabel.utils.dicts import group_dicts
if TYPE_CHECKING:
from distilabel.typing import (
ChatType,
FormattedInput,
GenerateOutput,
LLMStatistics,
StepOutput,
)
class _Task(_Step, ABC):
"""_Task is an abstract class that implements the `_Step` interface and adds the
`format_input` and `format_output` methods to format the inputs and outputs of the
task. It also adds an `llm` attribute to be used as the LLM to generate the outputs.
Attributes:
llm: the `LLM` to be used to generate the outputs of the task.
group_generations: whether to group the `num_generations` generated per input in
a list or create a row per generation. Defaults to `False`.
add_raw_output: whether to include a field with the raw output of the LLM in the
`distilabel_metadata` field of the output. Can be helpful to not lose data
with `Tasks` that need to format the output of the `LLM`. Defaults to `True`.
num_generations: The number of generations to be produced per input.
"""
llm: LLM
group_generations: bool = False
add_raw_output: RuntimeParameter[bool] = Field(
default=True,
description=(
"Whether to include the raw output of the LLM in the key `raw_output_<TASK_NAME>`"
" of the `distilabel_metadata` dictionary output column"
),
)
add_raw_input: RuntimeParameter[bool] = Field(
default=True,
description=(
"Whether to include the raw input of the LLM in the key `raw_input_<TASK_NAME>`"
" of the `distilabel_metadata` dictionary column"
),
)
num_generations: RuntimeParameter[int] = Field(
default=1, description="The number of generations to be produced per input."
)
use_default_structured_output: bool = False
_can_be_used_with_offline_batch_generation: bool = PrivateAttr(False)
def model_post_init(self, __context: Any) -> None:
if (
self.llm.use_offline_batch_generation
and not self._can_be_used_with_offline_batch_generation
):
raise DistilabelUserError(
f"`{self.__class__.__name__}` task cannot be used with offline batch generation"
" feature.",
page="sections/how_to_guides/advanced/offline-batch-generation",
)
super().model_post_init(__context)
@property
def is_global(self) -> bool:
"""Extends the `is_global` property to return `True` if the task is using the
offline batch generation feature, otherwise it returns the value of the parent
class property. `offline_batch_generation` requires receiving all the inputs
at once, so for the `_BatchManager` this is a global step.
Returns:
Whether the task is a global step or not.
"""
if self.llm.use_offline_batch_generation:
return True
return super().is_global
def load(self) -> None:
"""Loads the LLM via the `LLM.load()` method."""
super().load()
self._set_default_structured_output()
self.llm.load()
@override
def unload(self) -> None:
"""Unloads the LLM."""
self._logger.debug("Executing task unload logic.")
self.llm.unload()
@override
def impute_step_outputs(
self, step_output: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Imputes the outputs of the task in case the LLM failed to generate a response.
"""
result = []
for row in step_output:
data = row.copy()
for output in self.get_outputs().keys():
data[output] = None
data = self._create_metadata(
data,
None,
None,
add_raw_output=self.add_raw_output,
add_raw_input=self.add_raw_input,
)
result.append(data)
return result
@abstractmethod
def format_output(
self,
output: Union[str, None],
input: Union[Dict[str, Any], None] = None,
) -> Dict[str, Any]:
"""Abstract method to format the outputs of the task. It needs to receive an output
as a string, and generates a Python dictionary with the outputs of the task. In
addition the `input` used to generate the output is also received just in case it's
needed to be able to parse the output correctly.
"""
pass
def _format_outputs(
self,
outputs: "GenerateOutput",
input: Union[Dict[str, Any], None] = None,
) -> List[Dict[str, Any]]:
"""Formats the outputs of the task using the `format_output` method. If the output
is `None` (i.e. the LLM failed to generate a response), then the outputs will be
set to `None` as well.
Args:
outputs: The outputs (`n` generations) for the provided `input`.
input: The input used to generate the output.
Returns:
A list containing a dictionary with the outputs of the task for each input.
"""
inputs = [None] if input is None else [input]
formatted_outputs = []
repeated_inputs = len(outputs.get("generations"))
outputs = normalize_statistics(outputs)
for (output, stats, extra), input in zip(
iterate_generations_with_stats(outputs), inputs * repeated_inputs
): # type: ignore
try:
# Extract the generations, and move the statistics to the distilabel_metadata,
# to keep everything clean
formatted_output = self.format_output(output, input)
formatted_output = self._create_metadata(
output=formatted_output,
raw_output=output,
input=input,
add_raw_output=self.add_raw_output, # type: ignore
add_raw_input=self.add_raw_input, # type: ignore
statistics=stats,
)
formatted_output = self._create_extra(
output=formatted_output, extra=extra
)
formatted_outputs.append(formatted_output)
except Exception as e:
self._logger.warning( # type: ignore
f"Task '{self.name}' failed to format output: {e}. Saving raw response." # type: ignore
)
formatted_outputs.append(self._output_on_failure(output, input))
return formatted_outputs
def _output_on_failure(
self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
"""In case of failure to format the output, this method will return a dictionary including
a new field `distilabel_meta` with the raw output of the LLM.
"""
# Create a dictionary with the outputs of the task (every output set to None)
outputs = {output: None for output in self.outputs}
outputs["model_name"] = self.llm.model_name # type: ignore
outputs = self._create_metadata(
outputs,
output,
input,
add_raw_output=self.add_raw_output, # type: ignore
add_raw_input=self.add_raw_input, # type: ignore
)
return outputs
def _create_metadata(
self,
output: Dict[str, Any],
raw_output: Union[str, None],
input: Union[Dict[str, Any], None] = None,
add_raw_output: bool = True,
add_raw_input: bool = True,
statistics: Optional["LLMStatistics"] = None,
) -> Dict[str, Any]:
"""Adds the raw output and or the formatted input of the LLM to the output dictionary
if `add_raw_output` is True or `add_raw_input` is True.
Args:
output:
The output dictionary after formatting the output from the LLM,
to add the raw output and or raw input.
raw_output: The raw output of the `LLM`.
input: The input used to generate the output.
add_raw_output: Whether to add the raw output to the output dictionary.
add_raw_input: Whether to add the raw input to the output dictionary.
statistics: The statistics generated by the LLM, which should contain at least
the number of input and output tokens.
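Examples:
```python
# A sketch of the resulting metadata for a task named "my_task" (hypothetical
# values), with both flags enabled and statistics provided:
# {
#     "distilabel_metadata": {
#         "raw_output_my_task": "...",
#         "raw_input_my_task": [{"role": "user", "content": "..."}],
#         "statistics_my_task": {"input_tokens": 10, "output_tokens": 20},
#     },
#     ...,  # the formatted output keys
# }
```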
"""
meta = output.get(DISTILABEL_METADATA_KEY, {})
if add_raw_output:
meta[f"raw_output_{self.name}"] = raw_output
if add_raw_input:
meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None
if statistics:
meta[f"statistics_{self.name}"] = statistics
if meta:
output[DISTILABEL_METADATA_KEY] = meta
return output
def _create_extra(
self, output: Dict[str, Any], extra: Dict[str, Any]
) -> Dict[str, Any]:
column_name_prefix = f"llm_{self.name}_"
for key, value in extra.items():
column_name = column_name_prefix + key
output[column_name] = value
return output
def _set_default_structured_output(self) -> None:
"""Prepares the structured output to be set in the selected `LLM`.
If the method `get_structured_output` returns None (the default), there's no need
to set anything, as it doesn't apply.
If `use_default_structured_output` is set and no structured output has been set
by hand, then the type of structured output to use is decided depending on the
`LLM` provider.
"""
schema = self.get_structured_output()
if not schema:
return
if self.use_default_structured_output and not self.llm.structured_output:
# In case the default structured output is required, we have to set it before
# the LLM is loaded
from distilabel.models.llms import InferenceEndpointsLLM
from distilabel.models.llms.base import AsyncLLM
def check_dependency(module_name: str) -> None:
if not importlib.util.find_spec(module_name):
raise ImportError(
f"`{module_name}` is not installed and is needed for the structured generation with this LLM."
f" Please install it using `pip install {module_name}`."
)
dependency = "outlines"
structured_output = {"schema": schema}
if isinstance(self.llm, InferenceEndpointsLLM):
structured_output.update({"format": "json"})
# To determine instructor or outlines format
elif isinstance(self.llm, AsyncLLM) and not isinstance(
self.llm, InferenceEndpointsLLM
):
dependency = "instructor"
structured_output.update({"format": "json"})
check_dependency(dependency)
self.llm.structured_output = structured_output
def get_structured_output(self) -> Union[Dict[str, Any], None]:
"""Returns the structured output for a task that implements one by default,
must be overriden by subclasses of `Task`. When implemented, should be a json
schema that enforces the response from the LLM so that it's easier to parse.
"""
return None
def _sample_input(self) -> "ChatType":
"""Returns a sample input to be used in the `print` method.
Tasks that don't adhere to a format input that returns a map of the type
str -> str should override this method to return a sample input.
"""
return self.format_input(
{input: f"<PLACEHOLDER_{input.upper()}>" for input in self.inputs}
)
def print(self, sample_input: Optional["ChatType"] = None) -> None:
"""Prints a sample input to the console using the `rich` library.
Helper method to visualize the prompt of the task.
Args:
sample_input: A sample input to be printed. If not provided, a default will be
generated using the `_sample_input` method, which can be overridden by
subclasses. This should correspond to the same example you could pass to
the `format_input` method.
The variables will be named `<PLACEHOLDER_VARIABLE_NAME>` by default.
Examples:
Print the URIAL prompt:
```python
from distilabel.steps.tasks import URIAL
from distilabel.models.llms.huggingface import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
urial = URIAL(
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
),
)
urial.load()
urial.print()
╭─────────────────────────────────────── Prompt: URIAL ────────────────────────────────────────╮
│ ╭────────────────────────────────────── User Message ───────────────────────────────────────╮ │
│ │ # Instruction │ │
│ │ │ │
│ │ Below is a list of conversations between a human and an AI assistant (you). │ │
│ │ Users place their queries under "# User:", and your responses are under "# Assistant:". │ │
│ │ You are a helpful, respectful, and honest assistant. │ │
│ │ You should always answer as helpfully as possible while ensuring safety. │ │
│ │ Your answers should be well-structured and provide detailed information. They should also │ │
│ │ have an engaging tone. │ │
│ │ Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, │ │
│ │ dangerous, or illegal content, even if it may be helpful. │ │
│ │ Your response must be socially responsible, and thus you can refuse to answer some │ │
│ │ controversial topics. │ │
│ │ │ │
│ │ │ │
│ │ # User: │ │
│ │ │ │
│ │ <PLACEHOLDER_INSTRUCTION> │ │
│ │ │ │
│ │ # Assistant: │ │
│ ╰───────────────────────────────────────────────────────────────────────────────────────────╯ │
╰───────────────────────────────────────────────────────────────────────────────────────────────╯
```
"""
from rich.console import Console, Group
from rich.panel import Panel
from rich.text import Text
console = Console()
sample_input = sample_input or self._sample_input()
panels = []
for item in sample_input:
content = Text.assemble((item.get("content", ""),))
panel = Panel(
content,
title=f"[bold][magenta]{item.get('role', '').capitalize()} Message[/magenta][/bold]",
border_style="light_cyan3",
)
panels.append(panel)
# Create a group of panels
# Wrap the group in an outer panel
outer_panel = Panel(
Group(*panels),
title=f"[bold][magenta]Prompt: {type(self).__name__} [/magenta][/bold]",
border_style="light_cyan3",
expand=False,
)
console.print(outer_panel)
class Task(_Task, Step):
"""Task is a class that implements the `_Task` abstract class and adds the `Step`
interface to be used as a step in the pipeline.
Attributes:
llm: the `LLM` to be used to generate the outputs of the task.
group_generations: whether to group the `num_generations` generated per input in
a list or create a row per generation. Defaults to `False`.
num_generations: The number of generations to be produced per input.
"""
@abstractmethod
def format_input(self, input: Dict[str, Any]) -> "FormattedInput":
"""Abstract method to format the inputs of the task. It needs to receive an input
as a Python dictionary, and generates an OpenAI chat-like list of dicts."""
pass
def _format_inputs(self, inputs: List[Dict[str, Any]]) -> List["FormattedInput"]:
"""Formats the inputs of the task using the `format_input` method.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Returns:
A list containing the formatted inputs, which are `ChatType`-like following
the OpenAI formatting.
"""
return [self.format_input(input) for input in inputs]
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Processes the inputs of the task and generates the outputs using the LLM.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Yields:
A list of Python dictionaries with the outputs of the task.
"""
formatted_inputs = self._format_inputs(inputs)
# `outputs` is a dict containing the LLM outputs in the `generations`
# key and the statistics in the `statistics` key
outputs = self.llm.generate_outputs(
inputs=formatted_inputs,
num_generations=self.num_generations, # type: ignore
**self.llm.get_generation_kwargs(), # type: ignore
)
task_outputs = []
for input, input_outputs in zip(inputs, outputs):
formatted_outputs = self._format_outputs(input_outputs, input)
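# When `group_generations=True`, all `num_generations` outputs generated for
# this input are merged into a single row; otherwise a row is created per generation.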
if self.group_generations:
combined = group_dicts(*formatted_outputs)
task_outputs.append(
{**input, **combined, "model_name": self.llm.model_name}
)
continue
# Create a row per generation
for formatted_output in formatted_outputs:
task_outputs.append(
{**input, **formatted_output, "model_name": self.llm.model_name}
)
yield task_outputs
class GeneratorTask(_Task, GeneratorStep):
"""`GeneratorTask` is a class that implements the `_Task` abstract class and adds the
`GeneratorStep` interface to be used as a step in the pipeline.
Attributes:
llm: the `LLM` to be used to generate the outputs of the task.
group_generations: whether to group the `num_generations` generated per input in
a list or create a row per generation. Defaults to `False`.
num_generations: The number of generations to be produced per input.
"""
pass
class GlobalTask(_Task, GlobalStep):
"""`GlobalTask` is a class that implements the `_Task` abstract class and adds the
`GlobalStep` interface to be used as a step in the pipeline. It's generally used in
combination with `LLM`s that can be used for offline batched inference.
"""
pass
class ImageTask(_Task, Step):
"""`ImageTask` is a class that implements the `_Task` abstract class and adds the `Step`
interface to be used as a step in the pipeline. It differs from the `Task` in that it's
expected to work with `ImageGenerationModel`s instead of `LLM`s.
Attributes:
image_generation_model: the `ImageGenerationModel` to be used to generate the outputs.
llm: This attribute is here to respect the `_Task` interface, but it's used internally only.
group_generations: whether to group the `num_generations` generated per input in
a list or create a row per generation. Defaults to `False`.
num_generations: The number of generations to be produced per input.
"""
llm: Union[LLM, ImageGenerationModel, None] = None
image_generation_model: ImageGenerationModel
def model_post_init(self, __context: Any) -> None:
assert self.llm is None, (
"`ImageTask` cannot use an `LLM` attribute given by the user, pass "
"the `image_generation_model` attribute instead."
)
self.llm = self.image_generation_model
# Call the post init from the Step, as we don't want to call specific behaviour
# from the task, that may need to deal with specific attributes from the LLM
# not in the ImageGenerationModel
super(Step, self).model_post_init(__context)
@abstractmethod
def format_input(self, input: Dict[str, Any]) -> str:
"""Abstract method to format the inputs of the task. It needs to receive an input
as a Python dictionary, and generate a string to be used as the prompt for the model."""
pass
def _format_inputs(self, inputs: List[Dict[str, Any]]) -> List["FormattedInput"]:
"""Formats the inputs of the task using the `format_input` method.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Returns:
A list containing the formatted inputs, which in this case are the prompt
strings to be passed to the `ImageGenerationModel`.
"""
return [self.format_input(input) for input in inputs]
def _format_outputs(
self,
outputs: List[Union[str, None]],
input: Union[Dict[str, Any], None] = None,
) -> List[Dict[str, Any]]:
"""Formats the outputs of the task using the `format_output` method. If the output
is `None` (i.e. the LLM failed to generate a response), then the outputs will be
set to `None` as well.
Args:
outputs: The outputs (`n` generations) for the provided `input`.
input: The input used to generate the output.
Returns:
A list containing a dictionary with the outputs of the task for each input.
"""
inputs = [None] if input is None else [input]
formatted_outputs = []
for output, input in zip(outputs, inputs): # type: ignore
try:
formatted_output = self.format_output(output, input)
formatted_output = self._create_metadata(
formatted_output,
output,
input,
add_raw_output=self.add_raw_output, # type: ignore
add_raw_input=self.add_raw_input, # type: ignore
statistics=None,
)
formatted_outputs.append(formatted_output)
except Exception as e:
self._logger.warning( # type: ignore
f"Task '{self.name}' failed to format output: {e}. Saving raw response." # type: ignore
)
formatted_outputs.append(self._output_on_failure(output, input))
return formatted_outputs
@abstractmethod
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Processes the inputs of the task and generates the outputs using the `ImageGenerationModel`.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Yields:
A list of Python dictionaries with the outputs of the task.
"""
pass
def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput":
"""Transforms the GenerateOutput statistics to have the same length as the generations.
Args:
output: A generate output that possibly has different lengths of statistics
vs generations (e.g. `num_generations=3` returns 3 generations, but the
input tokens are only counted once).
Returns:
Normalized statistics according to the generations length.
Examples:
```python
data = {
"generations": ["text1", "text2", "text3"],
"statistics": {"input_tokens": [1], "output_tokens": [1, 2, 3]}
}
normalize_statistics(data)
data = {
"generations": ["text1", "text2", "text3"],
"statistics": {"input_tokens": [1, 1, 1], "output_tokens": [1, 2, 3]}
}
```
"""
statistics = output.get("statistics")
if not statistics:
return output
gen_length = len(output["generations"])
for stat_key, stat_values in output["statistics"].items():
current_length = len(stat_values) # type: ignore
if current_length > 0 and current_length < gen_length:
# Calculate how many times to repeat the tokens
repeats = gen_length // current_length
remainder = gen_length % current_length
# Create new list with repeated values
new_values = stat_values * repeats + stat_values[:remainder] # type: ignore
output["statistics"][stat_key] = new_values
return output
def iterate_generations_with_stats(
outputs: "GenerateOutput",
) -> Generator[Tuple[Union[str, None], "LLMStatistics", Dict[str, Any]], None, None]:
"""Helper function to iterate together generations and statistics while processing
them inside `_format_outputs`.
Args:
outputs: outputs from the `LLM.generate_outputs` method.
Yields:
Iterator of generation, generation statistics and extra data generated by the `LLM`.
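Examples:
```python
# A minimal sketch with hypothetical values, showing how each generation is
# paired with its per-index statistics and any extra keys returned by the LLM:
outputs = {
    "generations": ["a", "b"],
    "statistics": {"output_tokens": [1, 2]},
    "logprobs": [[-0.1], [-0.2]],
}
for generation, stats, extra in iterate_generations_with_stats(outputs):
    print(generation, stats, extra)
# a {'output_tokens': 1} {'logprobs': [-0.1]}
# b {'output_tokens': 2} {'logprobs': [-0.2]}
```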
"""
extra_keys = [
key for key in outputs.keys() if key not in ("generations", "statistics")
]
for i, generation in enumerate(outputs["generations"]):
# Create a new dictionary with the statistics for this index
stats = {
key: values[i] # type: ignore
for key, values in outputs["statistics"].items()
if values
}
# Extra keys returned by the `LLM`
extra = {key: outputs[key][i] for key in extra_keys}
yield generation, stats, extra
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.resources as importlib_resources
from typing import TYPE_CHECKING, Any, Dict, Final, Union
from jinja2 import Template
from pydantic import PrivateAttr
from distilabel.steps.tasks.base import Task
if TYPE_CHECKING:
from distilabel.typing import ChatType, StepColumns
SYSTEM_PROMPT: Final[str] = (
"You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."
)
class CLAIR(Task):
r"""Contrastive Learning from AI Revisions (CLAIR).
CLAIR uses an AI system to minimally revise a solution A→A' such that the resulting
preference pair A `preferred` A' is much more contrastive and precise.
Input columns:
- task (`str`): The task or instruction.
- student_solution (`str`): An answer to the task that is to be revised.
Output columns:
- revision (`str`): The revised text.
- rational (`str`): The rationale for the provided revision.
- model_name (`str`): The name of the model used to generate the revision and rationale.
Categories:
- preference
- text-generation
References:
- [`Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment`](https://arxiv.org/abs/2408.06266v1)
- [`APO and CLAIR - GitHub Repository`](https://github.com/ContextualAI/CLAIR_and_APO)
Examples:
Create contrastive preference pairs:
```python
from distilabel.steps.tasks import CLAIR
from distilabel.models import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
generation_kwargs={
"temperature": 0.7,
"max_new_tokens": 4096,
},
)
clair_task = CLAIR(llm=llm)
clair_task.load()
result = next(
clair_task.process(
[
{
"task": "How many gaps are there between the earth and the moon?",
"student_solution": 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon's orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.'
}
]
)
)
# result
# [{'task': 'How many gaps are there between the earth and the moon?',
# 'student_solution': 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.',
# 'revision': 'There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
# 'rational': 'The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.',
# 'distilabel_metadata': {'raw_output_c_l_a_i_r_0': '{teacher_reasoning}: The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.\n\n{corrected_student_solution}: There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
# 'raw_input_c_l_a_i_r_0': [{'role': 'system',
# 'content': "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."},
# {'role': 'user',
# 'content': '{task}: How many gaps are there between the earth and the moon?\n\n{student_solution}: There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.\n\n-----------------\n\nLet\'s first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.'}]},
# 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
```
Citations:
```
@misc{doosterlinck2024anchoredpreferenceoptimizationcontrastive,
title={Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment},
author={Karel D'Oosterlinck and Winnie Xu and Chris Develder and Thomas Demeester and Amanpreet Singh and Christopher Potts and Douwe Kiela and Shikib Mehri},
year={2024},
eprint={2408.06266},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2408.06266},
}
```
"""
system_prompt: str = SYSTEM_PROMPT
_template: Union[Template, None] = PrivateAttr(...)
def load(self) -> None:
super().load()
_path = str(
importlib_resources.files("distilabel")
/ "steps"
/ "tasks"
/ "templates"
/ "clair.jinja2"
)
with open(_path, "r") as f:
self._template = Template(f.read())
@property
def inputs(self) -> "StepColumns":
return ["task", "student_solution"]
@property
def outputs(self) -> "StepColumns":
return ["revision", "rational", "model_name"]
def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType` assuming that the instruction
is the first interaction from the user within a conversation."""
return [
{"role": "system", "content": self.system_prompt},
{
"role": "user",
"content": self._template.render(
task=input["task"], student_solution=input["student_solution"]
),
},
]
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
"""The output is formatted as a list with the score of each instruction-response pair.
Args:
output: the raw output of the LLM.
input: the input to the task. Used for obtaining the number of responses.
Returns:
A dict with the key `scores` containing the scores for each instruction-response pair.
"""
if output is None:
return self._default_error()
return self._format_output(output)
def _format_output(self, output: Union[str, None]) -> Dict[str, Any]:
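# The raw output is expected to contain the rationale first, followed by the
# revision after one of the known separators, e.g. (hypothetical):
#   "{teacher_reasoning}: <rational>\n\n{corrected_student_solution}: <revision>"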
if "**Corrected Student Solution:**" in output:
splits = output.split("**Corrected Student Solution:**")
elif "{corrected_student_solution}:" in output:
splits = output.split("{corrected_student_solution}:")
elif "{corrected_student_solution}" in output:
splits = output.split("{corrected_student_solution}")
elif "**Worsened Student Solution:**" in output:
splits = output.split("**Worsened Student Solution:**")
elif "{worsened_student_solution}:" in output:
splits = output.split("{worsened_student_solution}:")
elif "{worsened_student_solution}" in output:
splits = output.split("{worsened_student_solution}")
else:
splits = None
# Safety check when the output doesn't follow the expected format
if not splits:
return self._default_error()
if len(splits) >= 2:
revision = splits[1]
revision = revision.strip("\n\n").strip() # noqa: B005
rational = splits[0]
if "{teacher_reasoning}" in rational:
rational = rational.split("{teacher_reasoning}")[1].strip(":").strip()
rational = rational.strip("\n\n").strip() # noqa: B005
else:
return self._default_error()
return {"revision": revision, "rational": rational}
def _default_error(self) -> Dict[str, None]:
return {"revision": None, "rational": None}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import sys
if sys.version_info < (3, 9):
import importlib_resources
else:
import importlib.resources as importlib_resources
from typing import TYPE_CHECKING, Any, Dict, List, Union
import orjson
from jinja2 import Template
from pydantic import PrivateAttr
from typing_extensions import override
from distilabel.steps.tasks.base import Task
if TYPE_CHECKING:
from distilabel.typing import ChatType
_PARSE_SCORE_LINE_REGEX = re.compile(r"\[\d+\] score: (\d+)", re.IGNORECASE)
class ComplexityScorer(Task):
"""Score instructions based on their complexity using an `LLM`.
`ComplexityScorer` is a pre-defined task used to rank a list of instructions based on
their complexity. It's an implementation of the complexity score task from the paper
'What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection
in Instruction Tuning'.
Attributes:
_template: a Jinja2 template used to format the input for the LLM.
Input columns:
- instructions (`List[str]`): The list of instructions to be scored.
Output columns:
- scores (`List[float]`): The score for each instruction.
- model_name (`str`): The model name used to generate the scores.
Categories:
- scorer
- complexity
- instruction
References:
- [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685)
Examples:
Evaluate the complexity of your instructions:
```python
from distilabel.steps.tasks import ComplexityScorer
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
scorer = ComplexityScorer(
llm=InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
)
)
scorer.load()
result = next(
scorer.process(
[{"instructions": ["plain instruction", "highly complex instruction"]}]
)
)
# result
# [{'instructions': ['plain instruction', 'highly complex instruction'], 'model_name': 'test', 'scores': [1, 5], 'distilabel_metadata': {'raw_output_complexity_scorer_0': 'output'}}]
```
Generate structured output with default schema:
```python
from distilabel.steps.tasks import ComplexityScorer
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
scorer = ComplexityScorer(
llm=InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
),
use_default_structured_output=True
)
scorer.load()
result = next(
scorer.process(
[{"instructions": ["plain instruction", "highly complex instruction"]}]
)
)
# result
# [{'instructions': ['plain instruction', 'highly complex instruction'], 'model_name': 'test', 'scores': [1, 2], 'distilabel_metadata': {'raw_output_complexity_scorer_0': '{ \\n "scores": [\\n 1, \\n 2\\n ]\\n}'}}]
```
Citations:
```
@misc{liu2024makesgooddataalignment,
title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
author={Wei Liu and Weihao Zeng and Keqing He and Yong Jiang and Junxian He},
year={2024},
eprint={2312.15685},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2312.15685},
}
```
"""
_template: Union[Template, None] = PrivateAttr(...)
_can_be_used_with_offline_batch_generation = True
def load(self) -> None:
"""Loads the Jinja2 template."""
super().load()
_path = str(
importlib_resources.files("distilabel")
/ "steps"
/ "tasks"
/ "templates"
/ "complexity-scorer.jinja2"
)
with open(_path, "r") as f:
self._template = Template(f.read())
@property
def inputs(self) -> List[str]:
"""The inputs for the task are the `instructions`."""
return ["instructions"]
def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType` assuming that the instruction
is the first interaction from the user within a conversation."""
return [
{
"role": "user",
"content": self._template.render(instructions=input["instructions"]), # type: ignore
}
]
@property
def outputs(self) -> List[str]:
"""The output for the task are: a list of `scores` containing the complexity score for each
instruction in `instructions`, and the `model_name`."""
return ["scores", "model_name"]
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
"""The output is formatted as a list with the score of each instruction.
Args:
output: the raw output of the LLM.
input: the input to the task. Used for obtaining the number of responses.
Returns:
A dict with the key `scores` containing the scores for each instruction.
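Examples:
```python
# A sketch of the raw output the parsing regex expects (hypothetical scores),
# one line per instruction:
#
# [1] score: 1
# [2] score: 5
#
# which is parsed into: {"scores": [1.0, 5.0]}
```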
"""
if output is None:
return {"scores": [None] * len(input["instructions"])}
if self.use_default_structured_output:
return self._format_structured_output(output, input)
scores = []
score_lines = output.split("\n")
for i, line in enumerate(score_lines):
match = _PARSE_SCORE_LINE_REGEX.search(line)
score = float(match.group(1)) if match else None
scores.append(score)
if i == len(input["instructions"]) - 1:
break
return {"scores": scores}
@override
def get_structured_output(self) -> Dict[str, Any]:
"""Creates the json schema to be passed to the LLM, to enforce generating
a dictionary with the output which can be directly parsed as a python dictionary.
The schema corresponds to the following:
```python
from pydantic import BaseModel
from typing import List
class SchemaComplexityScorer(BaseModel):
scores: List[int]
```
Returns:
JSON Schema of the response to enforce.
"""
return {
"properties": {
"scores": {
"items": {"type": "integer"},
"title": "Scores",
"type": "array",
}
},
"required": ["scores"],
"title": "SchemaComplexityScorer",
"type": "object",
}
def _format_structured_output(
self, output: str, input: Dict[str, Any]
) -> Dict[str, str]:
"""Parses the structured response, which should correspond to a dictionary
with either `positive`, or `positive` and `negative` keys.
Args:
output: The output from the `LLM`.
Returns:
Formatted output.
"""
try:
return orjson.loads(output)
except orjson.JSONDecodeError:
return {"scores": [None] * len(input["instructions"])}
@override
def _sample_input(self) -> "ChatType":
"""Returns a sample input to be used in the `print` method.
Tasks that don't adhere to a format input that returns a map of the type
str -> str should override this method to return a sample input.
"""
return self.format_input(
{
"instructions": [
f"<PLACEHOLDER_{f'GENERATION_{i}'.upper()}>" for i in range(2)
],
}
)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Tuple, Type, Union
import yaml
from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks.base import Task
from distilabel.typing import FormattedInput
if TYPE_CHECKING:
from distilabel.typing import StepColumns
TaskFormattingOutputFunc = Callable[..., Dict[str, Any]]
def task(
inputs: Union["StepColumns", None] = None,
outputs: Union["StepColumns", None] = None,
) -> Callable[..., Type["Task"]]:
"""Creates a `Task` from a formatting output function.
Args:
inputs: a list containing the name of the inputs columns/keys or a dictionary
where the keys are the columns and the values are booleans indicating whether
the column is required or not, that are required by the step. If not provided
the default will be an empty list `[]` and it will be assumed that the step
doesn't need any specific columns. Defaults to `None`.
outputs: a list containing the name of the outputs columns/keys or a dictionary
where the keys are the columns and the values are booleans indicating whether
the column will be generated or not. If not provided the default will be an
empty list `[]` and it will be assumed that the step doesn't need any specific
columns. Defaults to `None`.
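Examples:
A minimal sketch of a formatting function (hypothetical names) following the
expected docstring format, where the YAML block between the `---` markers
provides the `system_prompt` and `user_message_template`:

```python
@task(inputs=["instruction"], outputs=["response"])
def MyTask(output: Union[str, None], input: Union[Dict[str, Any], None] = None) -> Dict[str, Any]:
    '''
    ---
    system_prompt: |
        You are a helpful assistant.

    user_message_template: |
        {instruction}
    ---
    '''
    return {"response": output}
```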
"""
inputs = inputs or []
outputs = outputs or []
def decorator(func: TaskFormattingOutputFunc) -> Type["Task"]:
doc = inspect.getdoc(func)
if doc is None:
raise DistilabelUserError(
"When using the `task` decorator, including a docstring in the formatting"
" function is mandatory. The docstring must follow the format described"
" in the documentation.",
page="",
)
system_prompt, user_message_template = _parse_docstring(doc)
_validate_templates(inputs, system_prompt, user_message_template)
def inputs_property(self) -> "StepColumns":
return inputs
def outputs_property(self) -> "StepColumns":
return outputs
def format_input(self, input: Dict[str, Any]) -> "FormattedInput":
return [
{"role": "system", "content": system_prompt.format(**input)},
{"role": "user", "content": user_message_template.format(**input)},
]
def format_output(
self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
return func(output, input)
return type(
func.__name__,
(Task,),
{
"inputs": property(inputs_property),
"outputs": property(outputs_property),
"__module__": func.__module__,
"format_input": format_input,
"format_output": format_output,
},
)
return decorator
_SYSTEM_PROMPT_YAML_KEY: Final[str] = "system_prompt"
_USER_MESSAGE_TEMPLATE_YAML_KEY: Final[str] = "user_message_template"
_DOCSTRING_FORMATTING_FUNCTION_ERROR: Final[str] = (
"Formatting function decorated with `task` doesn't follow the expected format. Please,"
" check the documentation and update the function to include a docstring with the expected"
" format."
)
def _parse_docstring(docstring: str) -> Tuple[str, str]:
"""Parses the docstring of the formatting function that was built using the `task`
decorator.
Args:
docstring: the docstring of the formatting function.
Returns:
A tuple containing the system prompt and the user message template.
Raises:
DistilabelUserError: if the docstring doesn't follow the expected format or if
the expected keys are missing.
"""
parts = docstring.split("---")
if len(parts) != 3:
raise DistilabelUserError(
_DOCSTRING_FORMATTING_FUNCTION_ERROR,
page="",
)
yaml_content = parts[1]
try:
parsed_yaml = yaml.safe_load(yaml_content)
if not isinstance(parsed_yaml, dict):
raise DistilabelUserError(
_DOCSTRING_FORMATTING_FUNCTION_ERROR,
page="",
)
system_prompt = parsed_yaml.get(_SYSTEM_PROMPT_YAML_KEY)
user_template = parsed_yaml.get(_USER_MESSAGE_TEMPLATE_YAML_KEY)
if system_prompt is None or user_template is None:
raise DistilabelUserError(
"The formatting function decorated with `task` must include both the `system_prompt`"
" and `user_message_template` keys in the docstring. Please, check the documentation"
" and update the docstring of the formatting function to include the expected"
" keys.",
page="",
)
return system_prompt.strip(), user_template.strip()
except yaml.YAMLError as e:
raise DistilabelUserError(_DOCSTRING_FORMATTING_FUNCTION_ERROR, page="") from e
TEMPLATE_PLACEHOLDERS_REGEX = re.compile(r"\{(\w+)\}")
def _validate_templates(
inputs: "StepColumns", system_prompt: str, user_message_template: str
) -> None:
"""Validates the system prompt and user message template to ensure that they only
contain the allowed placeholders i.e. the columns/keys that are provided as inputs.
Args:
inputs: the list of inputs columns/keys.
system_prompt: the system prompt.
user_message_template: the user message template.
Raises:
DistilabelUserError: if the system prompt or the user message template contain
invalid placeholders.
"""
list_inputs = list(inputs.keys()) if isinstance(inputs, dict) else inputs
valid_system_prompt, invalid_system_prompt_placeholders = _validate_template(
system_prompt, list_inputs
)
if not valid_system_prompt:
raise DistilabelUserError(
f"The formatting function decorated with `task` includes invalid placeholders"
f" in the extracted `system_prompt` from the function docstring. Valid placeholders"
f" are: {list_inputs}, but the following placeholders were found: {invalid_system_prompt_placeholders}."
f" Please, update the `system_prompt` to only include the valid placeholders.",
page="",
)
valid_user_message_template, invalid_user_message_template_placeholders = (
_validate_template(user_message_template, list_inputs)
)
if not valid_user_message_template:
raise DistilabelUserError(
f"The formatting function decorated with `task` includes invalid placeholders"
f" in the extracted `user_message_template` from the function docstring. Valid"
f" placeholders are: {list_inputs}, but the following placeholders were found:"
f" {invalid_user_message_template_placeholders}. Please, update the `system_prompt`"
" to only include the valid placeholders.",
page="",
)
def _validate_template(
template: str, allowed_placeholders: List[str]
) -> Tuple[bool, set[str]]:
"""Validates that the template only contains the allowed placeholders.
Args:
template: the template to validate.
allowed_placeholders: the list of allowed placeholders.
Returns:
A tuple containing a boolean indicating if the template is valid and a set
with the invalid placeholders.
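Examples:
```python
# A sketch with hypothetical values:
# _validate_template("Answer {question} in {language}", ["question"])
# -> (False, {"language"})
```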
"""
placeholders = set(TEMPLATE_PLACEHOLDERS_REGEX.findall(template))
allowed_placeholders_set = set(allowed_placeholders)
are_valid = placeholders.issubset(allowed_placeholders_set)
invalid_placeholders = placeholders - allowed_placeholders_set
return are_valid, invalid_placeholders
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
from pydantic import Field
from typing_extensions import override
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.evol_instruct.utils import MUTATION_TEMPLATES
from distilabel.typing import ChatType
from distilabel.utils.lists import flatten_responses
if TYPE_CHECKING:
from distilabel.typing import LLMStatistics, StepOutput
class EvolInstruct(Task):
"""Evolve instructions using an `LLM`.
WizardLM: Empowering Large Language Models to Follow Complex Instructions
Attributes:
num_evolutions: The number of evolutions to be performed.
store_evolutions: Whether to store all the evolutions or just the last one. Defaults
to `False`.
generate_answers: Whether to generate answers for the evolved instructions. Defaults
to `False`.
include_original_instruction: Whether to include the original instruction in the
`evolved_instructions` output column. Defaults to `False`.
mutation_templates: The mutation templates to be used for evolving the instructions.
Defaults to the ones provided in the `utils.py` file.
seed: The seed to be set for `numpy` in order to randomly pick a mutation method.
Defaults to `42`.
Runtime parameters:
- `seed`: The seed to be set for `numpy` in order to randomly pick a mutation method.
Input columns:
- instruction (`str`): The instruction to evolve.
Output columns:
- evolved_instruction (`str`): The evolved instruction if `store_evolutions=False`.
- evolved_instructions (`List[str]`): The evolved instructions if `store_evolutions=True`.
- model_name (`str`): The name of the LLM used to evolve the instructions.
- answer (`str`): The answer to the evolved instruction if `generate_answers=True`
and `store_evolutions=False`.
- answers (`List[str]`): The answers to the evolved instructions if `generate_answers=True`
and `store_evolutions=True`.
Categories:
- evol
- instruction
References:
- [WizardLM: Empowering Large Language Models to Follow Complex Instructions](https://arxiv.org/abs/2304.12244)
- [GitHub: h2oai/h2o-wizardlm](https://github.com/h2oai/h2o-wizardlm)
Examples:
Evolve an instruction using an LLM:
```python
from distilabel.steps.tasks import EvolInstruct
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct = EvolInstruct(
llm=InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
),
num_evolutions=2,
)
evol_instruct.load()
result = next(evol_instruct.process([{"instruction": "common instruction"}]))
# result
# [{'instruction': 'common instruction', 'evolved_instruction': 'evolved instruction', 'model_name': 'model_name'}]
```
Keep the iterations of the evolutions:
```python
from distilabel.steps.tasks import EvolInstruct
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct = EvolInstruct(
llm=InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
),
num_evolutions=2,
store_evolutions=True,
)
evol_instruct.load()
result = next(evol_instruct.process([{"instruction": "common instruction"}]))
# result
# [
# {
# 'instruction': 'common instruction',
# 'evolved_instructions': ['initial evolution', 'final evolution'],
# 'model_name': 'model_name'
# }
# ]
```
Generate answers for the instructions in a single step:
```python
from distilabel.steps.tasks import EvolInstruct
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct = EvolInstruct(
llm=InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
),
num_evolutions=2,
generate_answers=True,
)
evol_instruct.load()
result = next(evol_instruct.process([{"instruction": "common instruction"}]))
# result
# [
# {
# 'instruction': 'common instruction',
# 'evolved_instruction': 'evolved instruction',
# 'answer': 'answer to the instruction',
# 'model_name': 'model_name'
# }
# ]
```
Citations:
```
@misc{xu2023wizardlmempoweringlargelanguage,
title={WizardLM: Empowering Large Language Models to Follow Complex Instructions},
author={Can Xu and Qingfeng Sun and Kai Zheng and Xiubo Geng and Pu Zhao and Jiazhan Feng and Chongyang Tao and Daxin Jiang},
year={2023},
eprint={2304.12244},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2304.12244},
}
```
"""
num_evolutions: int
store_evolutions: bool = False
generate_answers: bool = False
include_original_instruction: bool = False
mutation_templates: Dict[str, str] = MUTATION_TEMPLATES
seed: RuntimeParameter[int] = Field(
default=42,
description="As `numpy` is being used in order to randomly pick a mutation method, then is nice to seed a random seed.",
)
@property
def inputs(self) -> List[str]:
"""The input for the task is the `instruction`."""
return ["instruction"]
def format_input(self, input: str) -> ChatType: # type: ignore
"""The input is formatted as a `ChatType` assuming that the instruction
is the first interaction from the user within a conversation. And the
`system_prompt` is added as the first message if it exists."""
return [{"role": "user", "content": input}]
@property
def outputs(self) -> List[str]:
"""The output for the task are the `evolved_instruction/s`, the `answer` if `generate_answers=True`
and the `model_name`."""
# TODO: having to define a `model_name` column every time as the `Task.outputs` is not ideal,
# this could be handled always and the value could be included within the DAG validation when
# a `Task` is used, since all the `Task` subclasses will have an `llm` with a `model_name` attr.
_outputs = [
(
"evolved_instruction"
if not self.store_evolutions
else "evolved_instructions"
),
"model_name",
]
if self.generate_answers:
_outputs.append("answer" if not self.store_evolutions else "answers")
return _outputs
@override
def format_output( # type: ignore
self, instructions: Union[str, List[str]], answers: Optional[List[str]] = None
) -> Dict[str, Any]: # type: ignore
"""The output for the task is a dict with: `evolved_instruction` or `evolved_instructions`,
depending whether the value is either `False` or `True` for `store_evolutions`, respectively;
`answer` if `generate_answers=True`; and, finally, the `model_name`.
Args:
instructions: The instructions to be included within the output.
answers: The answers to be included within the output if `generate_answers=True`.
Returns:
If `store_evolutions=False` and `generate_answers=True` return {"evolved_instruction": ..., "model_name": ..., "answer": ...};
if `store_evolutions=True` and `generate_answers=True` return {"evolved_instructions": ..., "model_name": ..., "answers": ...};
if `store_evolutions=False` and `generate_answers=False` return {"evolved_instruction": ..., "model_name": ...};
if `store_evolutions=True` and `generate_answers=False` return {"evolved_instructions": ..., "model_name": ...}.
"""
_output = {}
if not self.store_evolutions:
_output["evolved_instruction"] = instructions[-1]
else:
_output["evolved_instructions"] = instructions
if self.generate_answers and answers:
if not self.store_evolutions:
_output["answer"] = answers[-1]
else:
_output["answers"] = answers
_output["model_name"] = self.llm.model_name
return _output
@property
def mutation_templates_names(self) -> List[str]:
"""Returns the names i.e. keys of the provided `mutation_templates`."""
return list(self.mutation_templates.keys())
def _apply_random_mutation(self, instruction: str) -> str:
"""Applies a random mutation from the ones provided as part of the `mutation_templates`
enum, and returns the provided instruction within the mutation prompt.
Args:
instruction: The instruction to be included within the mutation prompt.
Returns:
A random mutation prompt with the provided instruction.
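Examples:
```python
# A sketch with a hypothetical mutation template:
# template = "Rewrite #The Given Prompt# to be more complex:\n<PROMPT>"
# _apply_random_mutation("Write a poem")
# -> "Rewrite #The Given Prompt# to be more complex:\nWrite a poem"
```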
"""
mutation = np.random.choice(self.mutation_templates_names)
return self.mutation_templates[mutation].replace("<PROMPT>", instruction) # type: ignore
def _evolve_instructions(
self, inputs: "StepInput"
) -> Tuple[List[List[str]], "LLMStatistics"]:
"""Evolves the instructions provided as part of the inputs of the task.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Returns:
A tuple containing a list where each item is a list with either the last evolved
instruction if `store_evolutions=False` or all the evolved instructions if
`store_evolutions=True`, and the statistics generated by the `LLM`.
"""
instructions: List[List[str]] = [[input["instruction"]] for input in inputs]
statistics: "LLMStatistics" = defaultdict(list)
for iter_no in range(self.num_evolutions):
formatted_prompts = []
for instruction in instructions:
formatted_prompts.append(self._apply_random_mutation(instruction[-1]))
formatted_prompts = [
self.format_input(prompt) for prompt in formatted_prompts
]
responses = self.llm.generate(
formatted_prompts,
**self.llm.generation_kwargs, # type: ignore
)
generated_prompts = flatten_responses(
[response["generations"] for response in responses]
)
for response in responses:
for k, v in response["statistics"].items():
statistics[k].append(v[0])
evolved_instructions = []
for generated_prompt in generated_prompts:
generated_prompt = generated_prompt.split("Prompt#:")[-1].strip()
evolved_instructions.append(generated_prompt)
if self.store_evolutions:
instructions = [
instruction + [evolved_instruction]
for instruction, evolved_instruction in zip(
instructions, evolved_instructions
)
]
else:
instructions = [
[evolved_instruction]
for evolved_instruction in evolved_instructions
]
self._logger.info(
f"🔄 Ran iteration {iter_no} evolving {len(instructions)} instructions!"
)
return instructions, dict(statistics)
def _generate_answers(
self, evolved_instructions: List[List[str]]
) -> Tuple[List[List[str]], "LLMStatistics"]:
"""Generates the answer for the instructions in `instructions`.
Args:
evolved_instructions: A list of lists where each item is a list with either the last
evolved instruction if `store_evolutions=False` or all the evolved instructions
if `store_evolutions=True`.
Returns:
A tuple containing a list of answers for each instruction and the statistics
generated by the `LLM`.
"""
formatted_instructions = [
self.format_input(instruction)
for instructions in evolved_instructions
for instruction in instructions
]
responses = self.llm.generate(
formatted_instructions,
num_generations=1,
**self.llm.generation_kwargs, # type: ignore
)
generations = [response["generations"] for response in responses]
statistics: Dict[str, Any] = defaultdict(list)
for response in responses:
for k, v in response["statistics"].items():
statistics[k].append(v[0])
step = (
self.num_evolutions
if not self.include_original_instruction
else self.num_evolutions + 1
)
return [
flatten_responses(generations[i : i + step])
for i in range(0, len(responses), step)
], dict(statistics)
@override
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Processes the inputs of the task and generates the outputs using the LLM.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Yields:
A list of Python dictionaries with the outputs of the task.
"""
evolved_instructions, statistics = self._evolve_instructions(inputs)
if self.store_evolutions:
# Remove the input instruction from the `evolved_instructions` list
from_ = 1 if not self.include_original_instruction else 0
evolved_instructions = [
instruction[from_:] for instruction in evolved_instructions
]
if not self.generate_answers:
for input, instruction in zip(inputs, evolved_instructions):
input.update(self.format_output(instruction))
input.update(
{
"distilabel_metadata": {
f"statistics_instruction_{self.name}": statistics
}
}
)
yield inputs
self._logger.info(
f"🎉 Finished evolving {len(evolved_instructions)} instructions!"
)
if self.generate_answers:
self._logger.info(
f"🧠 Generating answers for the {len(evolved_instructions)} evolved instructions!"
)
answers, statistics = self._generate_answers(evolved_instructions)
self._logger.info(
f"🎉 Finished generating answers for the {len(evolved_instructions)} evolved"
" instructions!"
)
for idx, (input, instruction) in enumerate(
zip(inputs, evolved_instructions)
):
input.update(self.format_output(instruction, answers[idx]))
input.update(
{
"distilabel_metadata": {
f"statistics_answer_{self.name}": statistics
}
}
)
yield inputs
@override
def _sample_input(self) -> ChatType:
return self.format_input(
self._apply_random_mutation("<PLACEHOLDER_INSTRUCTION>")
)