Commit 4d4d8f59 authored by chenzk

v1.0
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# - Try to import the function from a given module
# - If it's a function, try to import it and run it
# - If it fails, track the error message and return it
import inspect
import json
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Union
from pydantic import Field, PrivateAttr
from typing_extensions import override
from distilabel.steps.base import Step, StepInput
from distilabel.steps.tasks.apigen.utils import (
execute_from_response,
load_module_from_path,
)
if TYPE_CHECKING:
from types import ModuleType
from distilabel.typing import StepColumns, StepOutput
class APIGenExecutionChecker(Step):
"""Executes the generated function calls.
This step checks whether a given answer from a model, as generated by `APIGenGenerator`,
can be executed against the given library (given by `libpath`, which is a string
pointing to a python .py file with functions).
Attributes:
libpath: The path to the library where we will retrieve the functions.
It can also point to a folder with the functions. In this case, the folder
layout should be a folder with .py files, each containing a single function,
the name of the function being the same as the filename.
check_is_dangerous: Bool to exclude some potentially dangerous functions, based on
heuristics found while testing. These functions can run subprocesses, interact with
the OS, or perform other potentially dangerous operations. Defaults to True.
Input columns:
- answers (`str`): List with arguments to be passed to the function,
dumped as a string from a list of dictionaries. Should be loaded using
`json.loads`.
Output columns:
- keep_row_after_execution_check (`bool`): Whether the function should be kept or not.
- execution_result (`str`): The result from executing the function.
Categories:
- filtering
- execution
References:
- [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
- [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
Examples:
Execute a function from a given library with the answer from an LLM:
```python
from distilabel.steps.tasks import APIGenExecutionChecker
# For the libpath you can use as an example the file at the tests folder:
# ../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py
task = APIGenExecutionChecker(
libpath="../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py",
)
task.load()
res = next(
task.process(
[
{
"answers": [
{
"arguments": {
"initial_velocity": 0.2,
"acceleration": 0.1,
"time": 0.5,
},
"name": "final_velocity",
}
],
}
]
)
)
res
#[{'answers': [{'arguments': {'initial_velocity': 0.2, 'acceleration': 0.1, 'time': 0.5}, 'name': 'final_velocity'}], 'keep_row_after_execution_check': True, 'execution_result': ['0.25']}]
```
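A hypothetical sketch (the folder path and file names are placeholders) pointing
`libpath` to a folder, where each `.py` file defines a single function named after
the file:
```python
from distilabel.steps.tasks import APIGenExecutionChecker

# Assumes a layout like tools/final_velocity.py defining `final_velocity`.
task = APIGenExecutionChecker(libpath="tools")
task.load()
```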
"""
libpath: str = Field(
default=...,
description=(
"The path to the library where we will retrieve the functions, "
"or a folder with python files named the same as the functions they contain."
),
)
check_is_dangerous: bool = Field(
default=True,
description=(
"Bool to exclude some potentially dangerous functions, based on heuristics "
"found while testing. These functions can run subprocesses, interact with "
"the OS, or perform other potentially dangerous operations."
),
)
_toolbox: Union["ModuleType", None] = PrivateAttr(None)
def load(self) -> None:
"""Loads the library where the functions will be extracted from."""
super().load()
if Path(self.libpath).suffix == ".py":
self._toolbox = load_module_from_path(self.libpath)
def unload(self) -> None:
self._toolbox = None
@property
def inputs(self) -> "StepColumns":
"""The inputs for the task are those found in the original dataset."""
return ["answers"]
@property
def outputs(self) -> "StepColumns":
"""The outputs are the columns required by `APIGenGenerator` task."""
return ["keep_row_after_execution_check", "execution_result"]
def _get_function(self, function_name: str) -> Union[Callable, None]:
"""Retrieves the function from the toolbox.
Args:
function_name: The name of the function to retrieve.
Returns:
The function to execute, or `None` if it cannot be found or loaded.
"""
if self._toolbox:
return getattr(self._toolbox, function_name, None)
try:
toolbox = load_module_from_path(
str(Path(self.libpath) / f"{function_name}.py")
)
return getattr(toolbox, function_name, None)
except FileNotFoundError:
return None
except Exception as e:
self._logger.warning(f"Error loading function '{function_name}': {e}")
return None
def _is_dangerous(self, function: Callable) -> bool:
"""Checks if a function is dangerous to remove it.
Contains a list of heuristics to avoid executing possibly dangerous functions.
"""
source_code = inspect.getsource(function)
# We don't want to execute functions that use subprocess
if (
("subprocess." in source_code)
or ("os.system(" in source_code)
or ("input(" in source_code)
# Avoiding threading
or ("threading.Thread(" in source_code)
or ("exec(" in source_code)
# Avoiding argparse (not sure why)
or ("argparse.ArgumentParser(" in source_code)
# Avoiding logging changing the levels to not mess with the logs
or (".setLevel(" in source_code)
# Don't run a test battery
or ("unittest.main(" in source_code)
# Avoid exiting the program
or ("sys.exit(" in source_code)
or ("exit(" in source_code)
or ("raise SystemExit(" in source_code)
or ("multiprocessing.Pool(" in source_code)
):
return True
return False
@override
def process(self, inputs: StepInput) -> "StepOutput":
"""Checks the answer to see if it can be executed.
Captures the possible errors and returns them.
Args:
inputs: A list of dictionaries with the input data.
Yields:
A list of dictionaries with the output data.
"""
for input in inputs:
output = []
if input["answers"]:
answers = json.loads(input["answers"])
else:
input.update(
**{
"keep_row_after_execution_check": False,
"execution_result": ["No answers were provided."],
}
)
continue
for answer in answers:
if answer is None:
output.append(
{
"keep": False,
"execution_result": "Nothing was generated for this answer.",
}
)
continue
function_name = answer.get("name", None)
arguments = answer.get("arguments", None)
self._logger.debug(
f"Executing function '{function_name}' with arguments: {arguments}"
)
function = self._get_function(function_name)
if self.check_is_dangerous:
if function and self._is_dangerous(function):
function = None
if function is None:
output.append(
{
"keep": False,
"execution_result": f"Function '{function_name}' not found.",
}
)
else:
execution = execute_from_response(function, arguments)
output.append(
{
"keep": execution["keep"],
"execution_result": execution["execution_result"],
}
)
# We only consider a good response if all the answers were executed successfully,
# but keep the reasons for further review if needed.
input.update(
**{
"keep_row_after_execution_check": all(
o["keep"] is True for o in output
),
"execution_result": [o["execution_result"] for o in output],
}
)
yield inputs
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.resources as importlib_resources
import json
import random
from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Union
import orjson
from jinja2 import Template
from pydantic import PrivateAttr
from typing_extensions import override
from distilabel.steps.tasks.apigen.utils import remove_fences
from distilabel.steps.tasks.base import Task
if TYPE_CHECKING:
from distilabel.typing import ChatType, StepColumns
SYSTEM_PROMPT_API_GEN: Final[str] = """\
You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.
Construct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.
Ensure the query:
- Is clear and concise
- Demonstrates typical use cases
- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words
- Across a variety level of difficulties, ranging from beginner and advanced use cases
- The corresponding result's parameter types and ranges match with the function's descriptions
Ensure the answer:
- Is a list of function calls in JSON format
- The length of the answer list should be equal to the number of requests in the query
- Can solve all the requests in the query effectively"""
class APIGenGenerator(Task):
"""Generate queries and answers for the given functions in JSON format.
The `APIGenGenerator` is inspired by the APIGen pipeline, which was designed to generate
verifiable and diverse function-calling datasets. The task generates a set of diverse queries
and corresponding answers for the given functions in JSON format.
Attributes:
system_prompt: The system prompt to guide the user in the generation of queries and answers.
use_tools: Whether to use the tools available in the prompt to generate the queries and answers.
In case the tools are given in the input, they will be added to the prompt.
number: The number of queries to generate. It can be a list, from which a number
will be chosen randomly, or a dictionary mapping each number to its probability.
E.g.: `number=1`, `number=[1, 2, 3]`, `number={1: 0.5, 2: 0.3, 3: 0.2}` are all valid inputs.
It corresponds to the number of parallel queries to generate.
use_default_structured_output: Whether to use the default structured output or not.
Input columns:
- examples (`str`): Examples used as few shots to guide the model.
- func_name (`str`): Name for the function to generate.
- func_desc (`str`): Description of what the function should do.
- tools (`str`): JSON formatted string containing the tool representation of the function.
Output columns:
- query (`str`): The list of queries.
- answers (`str`): JSON formatted string with the list of answers, containing the info as
a dictionary to be passed to the functions.
Categories:
- text-generation
References:
- [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
- [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
Examples:
Generate without structured output (original implementation):
```python
from distilabel.steps.tasks import APIGenGenerator
from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
generation_kwargs={
"temperature": 0.7,
"max_new_tokens": 1024,
},
)
apigen = APIGenGenerator(
use_default_structured_output=False,
llm=llm
)
apigen.load()
res = next(
apigen.process(
[
{
"examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
"func_name": "getrandommovie",
"func_desc": "Returns a list of random movies from a database by calling an external API."
}
]
)
)
res
# [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
# 'number': 1,
# 'func_name': 'getrandommovie',
# 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
# 'queries': ['I want to watch a movie tonight, can you recommend a random one from your database?',
# 'Give me 5 random movie suggestions from your database to plan my weekend.'],
# 'answers': [[{'name': 'getrandommovie', 'arguments': {}}],
# [{'name': 'getrandommovie', 'arguments': {}},
# {'name': 'getrandommovie', 'arguments': {}},
# {'name': 'getrandommovie', 'arguments': {}},
# {'name': 'getrandommovie', 'arguments': {}},
# {'name': 'getrandommovie', 'arguments': {}}]],
# 'raw_input_api_gen_generator_0': [{'role': 'system',
# 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
# {'role': 'user',
# 'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n```json\n[\n {\n "query": "The generated query.",\n "answers": [\n {\n "name": "api_name",\n "arguments": {\n "arg_name": "value"\n ... (more arguments as required)\n }\n },\n ... (more API calls as required)\n ]\n }\n]\n```\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
# 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
```
Generate with structured output:
```python
from distilabel.steps.tasks import APIGenGenerator
from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct",
generation_kwargs={
"temperature": 0.7,
"max_new_tokens": 1024,
},
)
apigen = APIGenGenerator(
use_default_structured_output=True,
llm=llm
)
apigen.load()
res_struct = next(
apigen.process(
[
{
"examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
"func_name": "getrandommovie",
"func_desc": "Returns a list of random movies from a database by calling an external API."
}
]
)
)
res_struct
# [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
# 'number': 1,
# 'func_name': 'getrandommovie',
# 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
# 'queries': ["I'm bored and want to watch a movie. Can you suggest some movies?",
# "My family and I are planning a movie night. We can't decide on what to watch. Can you suggest some random movie titles?"],
# 'answers': [[{'arguments': {}, 'name': 'getrandommovie'}],
# [{'arguments': {}, 'name': 'getrandommovie'}]],
# 'raw_input_api_gen_generator_0': [{'role': 'system',
# 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
# {'role': 'user',
# 'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
# 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
```
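Sample the number of parallel queries per call (an illustrative sketch, not part of
the original docs; the probability mapping below is arbitrary):
```python
from distilabel.steps.tasks import APIGenGenerator
from distilabel.models import InferenceEndpointsLLM

# Half of the calls generate 1 query, the rest 2 or 3 parallel queries.
apigen = APIGenGenerator(
    llm=InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-70B-Instruct"),
    number={1: 0.5, 2: 0.3, 3: 0.2},
)
```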
"""
system_prompt: str = SYSTEM_PROMPT_API_GEN
use_default_structured_output: bool = False
number: Union[int, List[int], Dict[int, float]] = 1
use_tools: bool = True
_number: Union[int, None] = PrivateAttr(None)
_fn_parallel_queries: Union[Callable[[], str], None] = PrivateAttr(None)
_format_inst: Union[str, None] = PrivateAttr(None)
def load(self) -> None:
"""Loads the template for the generator prompt."""
super().load()
_path = str(
importlib_resources.files("distilabel")
/ "steps"
/ "tasks"
/ "templates"
/ "apigen"
/ "generator.jinja2"
)
self._template = Template(open(_path).read())
self._format_inst = self._set_format_inst()
def _parallel_queries(self, number: int) -> str:
"""Prepares the parallel queries guide to add to the prompt.
Args:
number: The number of queries to generate in a single call.
Returns:
The guide for parallel queries if `number > 1`, otherwise an empty string.
"""
if number > 1:
return (
"It can contain multiple parallel queries in natural language for the given functions. "
"They could use either the same function with different arguments or different functions.\n"
)
return ""
def _get_number(self) -> int:
"""Generates the number of queries to generate in a single call.
The number must be set to `_number` to avoid changing the original value
when calling `_default_error`.
"""
if isinstance(self.number, list):
self._number = random.choice(self.number)
elif isinstance(self.number, dict):
self._number = random.choices(
list(self.number.keys()), list(self.number.values())
)[0]
else:
self._number = self.number
return self._number
def _set_format_inst(self) -> str:
"""Prepares the function to generate the formatted instructions for the prompt.
If the default structured output is used, returns an empty string because nothing
else is needed, otherwise, returns the original addition to the prompt to guide the model
to generate a formatted JSON.
"""
return (
"\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n"
"```\n"
"[\n"
" {\n"
' "query": "The generated query.",\n'
' "answers": [\n'
" {\n"
' "name": "api_name",\n'
' "arguments": {\n'
' "arg_name": "value"\n'
" ... (more arguments as required)\n"
" }\n"
" },\n"
" ... (more API calls as required)\n"
" ]\n"
" }\n"
"]\n"
"```\n"
)
def _get_func_desc(self, input: Dict[str, Any]) -> str:
"""If available and required, will use the info from the tools in the
prompt for extra information. Otherwise will use jut the function description.
"""
if not self.use_tools:
return input["func_desc"]
extra = "" # Extra information from the tools (if available will be added)
if "tools" in input:
extra = f"\n\nThis is the available tool to guide you (respect the order of the parameters):\n{input['tools']}"
return input["func_desc"] + extra
@property
def inputs(self) -> "StepColumns":
"""The inputs for the task."""
return {
"examples": True,
"func_name": True,
"func_desc": True,
"tools": False,
}
def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType`."""
number = self._get_number()
parallel_queries = self._parallel_queries(number)
return [
{"role": "system", "content": self.system_prompt},
{
"role": "user",
"content": self._template.render(
examples=input["examples"],
parallel_queries=parallel_queries,
number=number,
func_name=input["func_name"],
func_desc=self._get_func_desc(input),
format_inst=self._format_inst,
),
},
]
@property
def outputs(self) -> "StepColumns":
"""The output for the task are the queries and corresponding answers."""
return ["query", "answers", "model_name"]
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
"""The output is formatted as a list with the score of each instruction.
Args:
output: the raw output of the LLM.
input: the input to the task. Used for obtaining the number of responses.
Returns:
A dict with the queries and answers pairs.
The answers are an array of answers corresponding to the query.
Each answer is represented as an object with the following properties:
- name (string): The name of the tool used to generate the answer.
- arguments (object): An object representing the arguments passed to the tool to generate the answer.
Each argument is represented as a key-value pair, where the key is the parameter name and the
value is the corresponding value.
"""
if output is None:
return self._default_error(input)
if not self.use_default_structured_output:
output = remove_fences(output)
try:
pairs = orjson.loads(output)
except orjson.JSONDecodeError:
return self._default_error(input)
pairs = pairs["pairs"] if self.use_default_structured_output else pairs
return self._format_output(pairs, input)
def _format_output(
self, pairs: Dict[str, Any], input: Dict[str, Any]
) -> Dict[str, Any]:
"""Parses the response, returning a dictionary with queries and answers.
Args:
pairs: The parsed dictionary from the LLM's output.
input: The input from the `LLM`.
Returns:
Formatted output, where the `queries` are a list of strings, and the `answers`
are a list of objects.
"""
try:
input.update(
**{
"query": pairs[0]["query"],
"answers": json.dumps(pairs[0]["answers"]),
}
)
return input
except Exception as e:
self._logger.error(f"Error formatting output: {e}, pairs: '{pairs}'")
return self._default_error(input)
def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Returns a default error output, to fill the responses in case of failure."""
input.update(
**{
"query": None,
"answers": json.dumps([None] * self._number),
}
)
return input
@override
def get_structured_output(self) -> Dict[str, Any]:
"""Creates the json schema to be passed to the LLM, to enforce generating
a dictionary with the output which can be directly parsed as a python dictionary.
The schema corresponds to the following:
```python
import json
from typing import Dict, List
from pydantic import BaseModel
class Answer(BaseModel):
name: str
arguments: Dict[str, str]
class QueryAnswer(BaseModel):
query: str
answers: List[Answer]
class QueryAnswerPairs(BaseModel):
pairs: List[QueryAnswer]
json.dumps(QueryAnswerPairs.model_json_schema(), indent=4)
```
Returns:
JSON Schema of the response to enforce.
"""
return {
"$defs": {
"Answer": {
"properties": {
"name": {"title": "Name", "type": "string"},
"arguments": {
"additionalProperties": {"type": "string"},
"title": "Arguments",
"type": "object",
},
},
"required": ["name", "arguments"],
"title": "Answer",
"type": "object",
},
"QueryAnswer": {
"properties": {
"query": {"title": "Query", "type": "string"},
"answers": {
"items": {"$ref": "#/$defs/Answer"},
"title": "Answers",
"type": "array",
},
},
"required": ["query", "answers"],
"title": "QueryAnswer",
"type": "object",
},
},
"properties": {
"pairs": {
"items": {"$ref": "#/$defs/QueryAnswer"},
"title": "Pairs",
"type": "array",
}
},
"required": ["pairs"],
"title": "QueryAnswerPairs",
"type": "object",
}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.resources as importlib_resources
from typing import TYPE_CHECKING, Any, Dict, Final, Union
import orjson
from jinja2 import Template
from pydantic import PrivateAttr
from typing_extensions import override
from distilabel.steps.tasks.apigen.utils import remove_fences
from distilabel.steps.tasks.base import Task
if TYPE_CHECKING:
from distilabel.typing import ChatType, StepColumns
SYSTEM_PROMPT_SEMANTIC_CHECKER: Final[str] = """\
As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.
These function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.
Do not pass if:
1. The function call does not align with the query’s objective, or the input arguments appear incorrect.
2. The function call and arguments are not properly chosen from the available functions.
3. The number of function calls does not correspond to the user’s intentions.
4. The execution results are irrelevant and do not match the function’s purpose.
5. The execution results contain errors or reflect that the function calls were not executed successfully.
""".rstrip()
class APIGenSemanticChecker(Task):
r"""Generate queries and answers for the given functions in JSON format.
The `APIGenGenerator` is inspired by the APIGen pipeline, which was designed to generate
verifiable and diverse function-calling datasets. The task generates a set of diverse queries
and corresponding answers for the given functions in JSON format.
Attributes:
system_prompt: System prompt for the task. Has a default one.
exclude_failed_execution: Whether to exclude failed executions (won't run on those
rows that have a False in `keep_row_after_execution_check` column, which
comes from running `APIGenExecutionChecker`). Defaults to True.
Input columns:
- func_desc (`str`): Description of what the function should do.
- query (`str`): Instruction from the user.
- answers (`str`): JSON encoded list with arguments to be passed to the function/API.
Should be loaded using `json.loads`.
- execution_result (`str`): Result of the function/API executed.
Output columns:
- thought (`str`): Reasoning for the output on whether to keep this output or not.
- keep_row_after_semantic_check (`bool`): True or False, can be used to filter
afterwards.
Categories:
- filtering
- text-generation
References:
- [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
- [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)
Examples:
Semantic checker for generated function calls (original implementation):
```python
import json
from distilabel.steps.tasks import APIGenSemanticChecker
from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
generation_kwargs={
"temperature": 0.7,
"max_new_tokens": 1024,
},
)
semantic_checker = APIGenSemanticChecker(
use_default_structured_output=False,
llm=llm
)
semantic_checker.load()
res = next(
semantic_checker.process(
[
{
"func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
"query": "What information can be obtained about the Maine Coon cat breed?",
"answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
"execution_result": "The Maine Coon is a big and hairy breed of cat",
}
]
)
)
res
# [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
# 'query': 'What information can be obtained about the Maine Coon cat breed?',
# 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
# 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
# 'thought': '',
# 'keep_row_after_semantic_check': True,
# 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
# 'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
# {'role': 'user',
# 'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n```\n{\n "thought": "Concisely describe your reasoning here",\n "pass": "yes" or "no"\n}\n```\n'}]},
# 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
```
Semantic checker for generated function calls (structured output):
```python
import json
from distilabel.steps.tasks import APIGenSemanticChecker
from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
generation_kwargs={
"temperature": 0.7,
"max_new_tokens": 1024,
},
)
semantic_checker = APIGenSemanticChecker(
use_default_structured_output=True,
llm=llm
)
semantic_checker.load()
res = next(
semantic_checker.process(
[
{
"func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
"query": "What information can be obtained about the Maine Coon cat breed?",
"answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
"execution_result": "The Maine Coon is a big and hairy breed of cat",
}
]
)
)
res
# [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
# 'query': 'What information can be obtained about the Maine Coon cat breed?',
# 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
# 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
# 'keep_row_after_semantic_check': True,
# 'thought': '',
# 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
# 'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
# {'role': 'user',
# 'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n'}]},
# 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
```
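Chain the execution and semantic checkers (an illustrative sketch; the step wiring
and `libpath` are assumptions, and `llm` is an LLM defined as in the examples above):
```python
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import APIGenExecutionChecker, APIGenSemanticChecker

with Pipeline(name="apigen-checks") as pipeline:
    execution_checker = APIGenExecutionChecker(libpath="lib.py")  # placeholder path
    semantic_checker = APIGenSemanticChecker(llm=llm)
    # The semantic checker consumes the `keep_row_after_execution_check`
    # column produced by the execution checker.
    execution_checker >> semantic_checker
```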
"""
system_prompt: str = SYSTEM_PROMPT_SEMANTIC_CHECKER
use_default_structured_output: bool = False
_format_inst: Union[str, None] = PrivateAttr(None)
def load(self) -> None:
"""Loads the template for the generator prompt."""
super().load()
_path = str(
importlib_resources.files("distilabel")
/ "steps"
/ "tasks"
/ "templates"
/ "apigen"
/ "semantic_checker.jinja2"
)
self._template = Template(open(_path).read())
self._format_inst = self._set_format_inst()
def _set_format_inst(self) -> str:
"""Prepares the function to generate the formatted instructions for the prompt.
If the default structured output is used, returns an empty string because nothing
else is needed, otherwise, returns the original addition to the prompt to guide the model
to generate a formatted JSON.
"""
return (
"\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n"
"```\n"
"{\n"
' "thought": "Concisely describe your reasoning here",\n'
' "passes": "yes" or "no"\n'
"}\n"
"```\n"
)
@property
def inputs(self) -> "StepColumns":
"""The inputs for the task."""
return {
"func_desc": True,
"query": True,
"answers": True,
"execution_result": True,
"keep_row_after_execution_check": True,
}
def format_input(self, input: Dict[str, Any]) -> "ChatType":
"""The input is formatted as a `ChatType`."""
return [
{"role": "system", "content": self.system_prompt},
{
"role": "user",
"content": self._template.render(
func_desc=input["func_desc"],
query=input["query"] or "",
func_call=input["answers"] or "",
execution_result=input["execution_result"],
format_inst=self._format_inst,
),
},
]
@property
def outputs(self) -> "StepColumns":
"""The output for the task are the queries and corresponding answers."""
return ["keep_row_after_semantic_check", "thought"]
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
"""The output is formatted as a list with the score of each instruction.
Args:
output: the raw output of the LLM.
input: the input to the task. Used for obtaining the number of responses.
Returns:
A dict with the queries and answers pairs.
The answers are an array of answers corresponding to the query.
Each answer is represented as an object with the following properties:
- name (string): The name of the tool used to generate the answer.
- arguments (object): An object representing the arguments passed to the tool to generate the answer.
Each argument is represented as a key-value pair, where the key is the parameter name and the
value is the corresponding value.
"""
if output is None:
return self._default_error(input)
output = remove_fences(output)
try:
result = orjson.loads(output)
# Update the column name and change to bool
result["keep_row_after_semantic_check"] = (
result.pop("passes").lower() == "yes"
)
input.update(**result)
return input
except orjson.JSONDecodeError:
return self._default_error(input)
def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""Default error message for the task."""
input.update({"thought": None, "keep_row_after_semantic_check": None})
return input
@override
def get_structured_output(self) -> Dict[str, Any]:
"""Creates the json schema to be passed to the LLM, to enforce generating
a dictionary with the output which can be directly parsed as a python dictionary.
The schema corresponds to the following:
```python
from typing import Literal
from pydantic import BaseModel
import json
class Checker(BaseModel):
thought: str
passes: Literal["yes", "no"]
json.dumps(Checker.model_json_schema(), indent=4)
```
Returns:
JSON Schema of the response to enforce.
"""
return {
"properties": {
"thought": {"title": "Thought", "type": "string"},
"passes": {"enum": ["yes", "no"], "title": "Passes", "type": "string"},
},
"required": ["thought", "passes"],
"title": "Checker",
"type": "object",
}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import re
import signal
from typing import TYPE_CHECKING, Any, Callable, Dict, TypedDict, Union
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
from types import ModuleType
from distilabel.typing import StepColumns, StepOutput
class PrepareExamples(Step):
r"""Helper step to create examples from `query` and `answers` pairs used as Few Shots in APIGen.
Attributes:
template (str): The template to format the examples.
Input columns:
- query (`str`): The query to generate examples from.
- answers (`str`): The answers to the query.
Output columns:
- examples (`str`): The formatted examples.
Categories:
- format
Examples:
Generate examples for APIGen:
```python
from distilabel.steps.tasks.apigen.utils import PrepareExamples
prepare_examples = PrepareExamples()
result = next(prepare_examples.process(
[
{
"query": ['I need the area of circles with radius 2.5, 5, and 7.5 inches, please.', 'Can you provide the current locations of buses and trolleys on route 12?'],
"answers": ['[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]', '[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]']
}
]
))
# result
# [{'examples': '## Query:\nI need the area of circles with radius 2.5, 5, and 7.5 inches, please.\n## Answers:\n[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]\n\n## Query:\nCan you provide the current locations of buses and trolleys on route 12?\n## Answers:\n[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]'}, {'examples': '## Query:\nI need the area of circles with radius 2.5, 5, and 7.5 inches, please.\n## Answers:\n[{"name": "circle_area", "arguments": {"radius": 2.5}}, {"name": "circle_area", "arguments": {"radius": 5}}, {"name": "circle_area", "arguments": {"radius": 7.5}}]\n\n## Query:\nCan you provide the current locations of buses and trolleys on route 12?\n## Answers:\n[{"name": "bus_trolley_locations", "arguments": {"route": "12"}}]'}]
```
"""
template: str = "## Query:\n{query}\n## Answers:\n{answers}"
@property
def inputs(self) -> "StepColumns":
return ["query", "answers"]
@property
def outputs(self) -> "StepColumns":
return ["examples"]
def process(self, inputs: StepInput) -> "StepOutput":
"""The process prepares the data for the `APIGenGenerator` task.
Args:
inputs: A list of dictionaries with the input data.
Yields:
A list of dictionaries with the output data.
"""
outputs = []
for input in inputs:
example_list = []
for query, answers in zip(input["query"], input["answers"]):
example_list.append(self.template.format(query=query, answers=answers))
outputs.append({"examples": "\n\n".join(example_list)})
yield outputs
def load_module_from_path(path: str) -> "ModuleType":
"""Loads a python module from a given path.
Args:
path: Path pointing to the module.
Returns:
ModuleType
Example:
```python
path = "/path/to/module.py"
module = load_module_from_path(path)
# And you can load functions from the module like this:
function = getattr(module, "function_name")
function(*args, **kwargs)
```
"""
spec = importlib.util.spec_from_file_location("module.name", path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
class FunctionResult(TypedDict):
keep: bool
execution_result: str
def execute_from_response(
function: Callable, call_answer: Union[Dict[str, Any], None]
) -> FunctionResult:
"""Executes a function with the given arguments as generated by `APIGenGenerator`.
Given that we cannot cast all the arguments arbitrarily, we try to evaluate them,
which ensures the strings can be converted to the correct type if possible (say
a list of lists of ints will be passed as such instead of its string representation).
Args:
function: A callable object.
call_answer: The arguments to call the function, as generated by the model.
Returns:
A container with the result of the execution and if the row should be kept.
"""
if not function:
return FunctionResult(keep=False, execution_result="Function not found")
if call_answer:
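# Best-effort coercion: string arguments that look like Python literals,
# e.g. "[[1, 2], [3, 4]]", are eval'ed into real objects; values that fail
# evaluation (plain words, etc.) are passed through unchanged.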
for key, value in call_answer.items():
if isinstance(value, str):
try:
call_answer[key] = eval(value)
except Exception:
# Leave as is and expect the function to handle it
pass
try:
if call_answer:
result = run_function_with_timeout(function, 5, *call_answer.values())
else:
# There can be functions that do not require arguments
result = run_function_with_timeout(function, 5)
return FunctionResult(keep=True, execution_result=str(result))
except Exception as e:
return FunctionResult(keep=False, execution_result=str(e))
def remove_json_fences(text: str) -> str:
pattern = r"^```json\n([\s\S]*)\n```$"
match = re.match(pattern, text, re.MULTILINE)
if match:
return match.group(1)
return text
def remove_fences(text: str) -> str:
pattern = r"^```\n([\s\S]*)\n```$"
match = re.match(pattern, text, re.MULTILINE)
if match:
return match.group(1)
return text
def timeout_handler(signum, frame):
raise TimeoutError("Function execution timed out")
def run_function_with_timeout(function: Callable, timeout: int = 5, *args: Any) -> Any:
"""Run a function with a timeout, to limit the total time waiting for a result."""
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout)
try:
result = function(*args)
finally:
# Cancel the alarm
signal.alarm(0)
return result
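# Illustrative usage of the timeout guard (assumes a POSIX system, since
# signal.SIGALRM is not available on Windows):
#
#   import time
#   def slow_echo(x):
#       time.sleep(10)
#       return x
#
#   run_function_with_timeout(slow_echo, 1, "hi")  # raises TimeoutError after ~1s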
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import orjson as json
from jinja2 import Template
from pydantic import BaseModel, Field, PrivateAttr
from typing_extensions import override
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import Task
if sys.version_info < (3, 9):
import importlib_resources
else:
import importlib.resources as importlib_resources
if TYPE_CHECKING:
from argilla import (
LabelQuestion,
MultiLabelQuestion,
RatingQuestion,
Record,
TextField,
TextQuestion,
)
from distilabel.typing import ChatType, StepOutput
class ArgillaLabeller(Task):
"""
Annotate Argilla records based on input fields, example records and question settings.
This task is designed to facilitate the annotation of Argilla records by leveraging a pre-trained LLM.
It uses a system prompt that guides the LLM to understand the input fields, the question type,
and the question settings. The task then formats the input data and generates a response based on the question.
The response is validated against the question's value model, and the final suggestion is prepared for annotation.
Attributes:
_template: a Jinja2 template used to format the input for the LLM.
Input columns:
- record (`argilla.Record`): The record to be annotated.
- fields (`Optional[List[Dict[str, Any]]]`): The list of field settings for the input fields.
- question (`Optional[Dict[str, Any]]`): The question settings for the question to be answered.
- example_records (`Optional[List[Dict[str, Any]]]`): The few shot example records with responses to be used to answer the question.
- guidelines (`Optional[str]`): The guidelines for the annotation task.
Output columns:
- suggestion (`Dict[str, Any]`): The final suggestion for annotation.
Categories:
- text-classification
- scorer
- text-generation
References:
- [`Argilla: Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets`](https://github.com/argilla-io/argilla/)
Examples:
Annotate a record with the same dataset and question:
```python
import argilla as rg
from argilla import Suggestion
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.models import InferenceEndpointsLLM
# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
pending_records_filter = rg.Filter(("status", "==", "pending"))
completed_records_filter = rg.Filter(("status", "==", "completed"))
pending_records = list(
dataset.records(
query=rg.Query(filter=pending_records_filter),
limit=5,
)
)
example_records = list(
dataset.records(
query=rg.Query(filter=completed_records_filter),
limit=5,
)
)
field = dataset.settings.fields["text"]
question = dataset.settings.questions["label"]
# Initialize the labeller with the model and fields
labeller = ArgillaLabeller(
llm=InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
),
fields=[field],
question=question,
example_records=example_records,
guidelines=dataset.guidelines
)
labeller.load()
# Process the pending records
result = next(
labeller.process(
[
{
"record": record
} for record in pending_records
]
)
)
# Add the suggestions to the records
for record, suggestion in zip(pending_records, result):
record.suggestions.add(Suggestion(**suggestion["suggestion"]))
# Log the updated records
dataset.records.log(pending_records)
```
Annotate a record with alternating datasets and questions:
```python
import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.models import InferenceEndpointsLLM
# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
field = dataset.settings.fields["text"]
question = dataset.settings.questions["label"]
question2 = dataset.settings.questions["label2"]
# Initialize the labeller with the model and fields
labeller = ArgillaLabeller(
llm=InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
)
)
labeller.load()
# Process the record
record = next(dataset.records())
result = next(
labeller.process(
[
{
"record": record,
"fields": [field],
"question": question,
},
{
"record": record,
"fields": [field],
"question": question2,
}
]
)
)
# Add the suggestions to the record
for suggestion in result:
record.suggestions.add(rg.Suggestion(**suggestion["suggestion"]))
# Log the updated record
dataset.records.log([record])
```
Overwrite default prompts and instructions:
```python
import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.models import InferenceEndpointsLLM
# Overwrite default prompts and instructions
labeller = ArgillaLabeller(
llm=InferenceEndpointsLLM(
model_id="mistralai/Mistral-7B-Instruct-v0.2",
),
system_prompt="You are an expert annotator and labelling assistant that understands complex domains and natural language processing.",
question_to_label_instruction={
"label_selection": "Select the appropriate label from the list of provided labels.",
"multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
"text": "Provide a text response to the question.",
"rating": "Provide a rating for the question.",
},
)
labeller.load()
```
"""
system_prompt: str = (
"You are an expert annotator and labelling assistant that understands complex domains and natural language processing. "
"You are given input fields and a question. "
"You should create a valid JSON object as an response to the question based on the input fields. "
)
question_to_label_instruction: Dict[str, str] = {
"label_selection": "Select the appropriate label for the fields from the list of optional labels.",
"multi_label_selection": "Select none, one or multiple labels for the fields from the list of optional labels.",
"text": "Provide a response to the question based on the fields.",
"rating": "Provide a rating for the question based on the fields.",
}
example_records: Optional[
RuntimeParameter[Union[List[Union[Dict[str, Any], BaseModel]], None]]
] = Field(
default=None,
description="The few shot serialized example records or `BaseModel`s with responses to be used to answer the question.",
)
fields: Optional[
RuntimeParameter[Union[List[Union[BaseModel, Dict[str, Any]]], None]]
] = Field(
default=None,
description="The field serialized field settings or `BaseModel` for the fields to be used to answer the question.",
)
question: Optional[
RuntimeParameter[
Union[
Dict[str, Any],
BaseModel,
None,
]
]
] = Field(
default=None,
description="The question serialized question settings or `BaseModel` for the question to be answered.",
)
guidelines: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The guidelines for the annotation task.",
)
_template: Union[Template, None] = PrivateAttr(...)
_client: Optional[Any] = PrivateAttr(None)
def load(self) -> None:
"""Loads the Jinja2 template."""
super().load()
_path = str(
importlib_resources.files("distilabel")
/ "steps"
/ "tasks"
/ "templates"
/ "argillalabeller.jinja2"
)
self._template = Template(open(_path).read())
@property
def inputs(self) -> Dict[str, bool]:
return {
"record": True,
"fields": False,
"question": False,
"example_records": False,
"guidelines": False,
}
def _format_record(
self, record: Dict[str, Any], fields: List[Dict[str, Any]]
) -> str:
"""Format the record fields into a string.
Args:
record (Dict[str, Any]): The record to format.
fields (List[Dict[str, Any]]): The fields to format.
Returns:
str: The formatted record fields.
"""
output = []
for field in fields:
output.append(record.get("fields", {}).get(field.get("name", "")))
return "fields: " + "\n".join(output)
def _get_label_instruction(self, question: Dict[str, Any]) -> str:
"""Get the label instruction for the question.
Args:
question (Dict[str, Any]): The question to get the label instruction for.
Returns:
str: The label instruction for the question.
"""
question_type = question["settings"]["type"]
return self.question_to_label_instruction[question_type]
def _format_question(self, question: Dict[str, Any]) -> str:
"""Format the question settings into a string.
Args:
question (Dict[str, Any]): The question to format.
Returns:
str: The formatted question.
"""
output = []
output.append(f"question: {self._get_label_instruction(question)}")
if "options" in question.get("settings", {}):
output.append(
f"optional labels: {[option['value'] for option in question.get('settings', {}).get('options', [])]}"
)
return "\n".join(output)
def _format_example_records(
self,
records: List[Dict[str, Any]],
fields: List[Dict[str, Any]],
question: Dict[str, Any],
) -> str:
"""Format the example records into a string.
Args:
records (List[Dict[str, Any]]): The records to format.
fields (List[Dict[str, Any]]): The fields to format.
question (Dict[str, Any]): The question to format.
Returns:
str: The formatted example records.
"""
base = []
for record in records:
responses = record.get("responses", {})
if responses.get(question["name"]):
base.append(self._format_record(record, fields))
value = responses[question["name"]][0]["value"]
formatted_value = self._assign_value_to_question_value_model(
value, question
)
base.append(f"response: {formatted_value}")
base.append("")
else:
warnings.warn(
f"Record {record} has no response for question {question['name']}. Skipping example record.",
stacklevel=2,
)
return "\n".join(base)
def format_input(
self,
input: Dict[
str,
Union[
Dict[str, Any],
"Record",
"TextField",
"MultiLabelQuestion",
"LabelQuestion",
"RatingQuestion",
"TextQuestion",
],
],
) -> "ChatType":
"""Format the input into a chat message.
Args:
input: The input to format.
Returns:
The formatted chat message.
Raises:
ValueError: If question or fields are not provided.
"""
input_keys = list(self.inputs.keys())
record = input[input_keys[0]]
fields = input.get(input_keys[1], self.fields)
question = input.get(input_keys[2], self.question)
examples = input.get(input_keys[3], self.example_records)
guidelines = input.get(input_keys[4], self.guidelines)
if question is None:
raise ValueError("Question must be provided.")
if fields is None or any(field is None for field in fields):
raise ValueError("Fields must be provided.")
record = record.to_dict() if not isinstance(record, dict) else record
question = question.serialize() if not isinstance(question, dict) else question
fields = [
field.serialize() if not isinstance(field, dict) else field
for field in fields
]
examples = (
[
example.to_dict() if not isinstance(example, dict) else example
for example in examples
]
if examples
else None
)
formatted_fields = self._format_record(record, fields)
formatted_question = self._format_question(question)
formatted_examples = (
self._format_example_records(examples, fields, question)
if examples
else False
)
prompt = self._template.render(
fields=formatted_fields,
question=formatted_question,
examples=formatted_examples,
guidelines=guidelines,
)
messages = []
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
messages.append({"role": "user", "content": prompt})
return messages
@property
def outputs(self) -> List[str]:
return ["suggestion"]
def format_output(
self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
"""Format the output into a dictionary.
Args:
output (Union[str, None]): The output to format.
input (Dict[str, Any]): The input to format.
Returns:
Dict[str, Any]: The formatted output.
"""
from argilla import Suggestion
question: Union[
Any,
Dict[str, Any],
LabelQuestion,
MultiLabelQuestion,
RatingQuestion,
TextQuestion,
None,
] = input.get(list(self.inputs.keys())[2], self.question) or self.question
question = question.serialize() if not isinstance(question, dict) else question
model = self._get_pydantic_model_of_structured_output(question)
validated_output = model(**json.loads(output))
value = self._get_value_from_question_value_model(validated_output)
suggestion = Suggestion(
value=value,
question_name=question["name"],
type="model",
agent=self.llm.model_name,
).serialize()
return {
self.outputs[0]: {
k: v
for k, v in suggestion.items()
if k in ["value", "question_name", "type", "agent"]
}
}
def _set_llm_structured_output_for_question(self, question: Dict[str, Any]) -> None:
runtime_parameters = self.llm._runtime_parameters
runtime_parameters.update(
{
"structured_output": {
"format": "json",
"schema": self._get_pydantic_model_of_structured_output(question),
},
}
)
self.llm.set_runtime_parameters(runtime_parameters)
@override
def process(self, inputs: StepInput) -> "StepOutput":
"""Process the input through the task.
Args:
inputs (StepInput): The input to process.
Returns:
StepOutput: The output of the task.
"""
question_list = [input.get("question", self.question) for input in inputs]
fields_list = [input.get("fields", self.fields) for input in inputs]
# check if any field for the field in fields is None
for fields in fields_list:
if any(field is None for field in fields):
raise ValueError(
"Fields must be provided during init or through `process` method."
)
# check if any question is None
if any(question is None for question in question_list):
raise ValueError(
"Question must be provided during init or through `process` method."
)
question_list = [
question.serialize() if not isinstance(question, dict) else question
for question in question_list
]
if not all(question == question_list[0] for question in question_list):
warnings.warn(
"Not all questions are the same. Processing each question separately by setting the structured output for each question. This may impact performance.",
stacklevel=2,
)
for input, question in zip(inputs, question_list):
self._set_llm_structured_output_for_question(question)
yield from super().process([input])
else:
question = question_list[0]
self._set_llm_structured_output_for_question(question)
yield from super().process(inputs)
def _get_value_from_question_value_model(
self, question_value_model: BaseModel
) -> Any:
"""Get the value from the question value model.
Args:
question_value_model (BaseModel): The question value model to get the value from.
Returns:
Any: The value from the question value model.
"""
for attr in ["label", "labels", "rating", "text"]:
if hasattr(question_value_model, attr):
return getattr(question_value_model, attr)
raise ValueError(f"Unsupported question type: {question_value_model}")
def _assign_value_to_question_value_model(
self, value: Any, question: Dict[str, Any]
) -> BaseModel:
"""Assign the value to the question value model.
Args:
value (Any): The value to assign.
question (Dict[str, Any]): The question to assign the value to.
Returns:
BaseModel: The question value model with the assigned value.
"""
question_value_model = self._get_pydantic_model_of_structured_output(question)
for attr in ["label", "labels", "rating", "text"]:
try:
model_dict = {attr: value}
question_value_model = question_value_model(**model_dict)
return question_value_model.model_dump_json()
except Exception:
# A failed validation means this attribute does not belong to the model; try the next one.
pass
return value
def _get_pydantic_model_of_structured_output(
self,
question: Dict[str, Any],
) -> BaseModel:
"""Get the Pydantic model of the structured output.
Args:
question (Dict[str, Any]): The question to get the Pydantic model of the structured output for.
Returns:
BaseModel: The Pydantic model of the structured output.
"""
question_type = question["settings"]["type"]
if question_type == "multi_label_selection":
class QuestionValueModel(BaseModel):
labels: Optional[List[str]] = Field(default_factory=list)
elif question_type == "label_selection":
class QuestionValueModel(BaseModel):
label: str
elif question_type == "text":
class QuestionValueModel(BaseModel):
text: str
elif question_type == "rating":
class QuestionValueModel(BaseModel):
rating: int
else:
raise ValueError(f"Unsupported question type: {question}")
return QuestionValueModel
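# Illustrative note (hypothetical usage of this private helper, not a public API):
# for a question whose settings are {"type": "rating"}, the returned class
# validates a plain dict, e.g.
#   Model = labeller._get_pydantic_model_of_structured_output(question)
#   Model(rating=4).rating  # -> 4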