Commit 4d4d8f59 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2741 canceled with stages
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Dict, List, Union
from graphviz import Digraph
from pydantic import BaseModel, Field
class Node(BaseModel):
    """A single node of the knowledge graph."""

    # Unique identifier referenced by `Edge.source` / `Edge.target`.
    id: int
    # Text displayed on the node when rendered.
    label: str
    # Graphviz color name used when drawing the node.
    color: str
class Edge(BaseModel):
    """A directed edge between two nodes of the knowledge graph."""

    # `Node.id` of the edge's origin.
    source: int
    # `Node.id` of the edge's destination.
    target: int
    # Text displayed on the edge when rendered.
    label: str
    # Graphviz color name; defaults to black when the LLM omits it.
    color: str = "black"
class KnowledgeGraph(BaseModel):
    """Schema for a knowledge graph: a set of nodes plus the edges linking them."""

    # Fix: `Field(..., default_factory=list)` is invalid — pydantic rejects a
    # field that is both required (`...`) and has a default factory. Keeping
    # only the factory gives each instance its own empty list by default.
    nodes: List[Node] = Field(default_factory=list)
    edges: List[Edge] = Field(default_factory=list)
def visualize_knowledge_graph(kg: KnowledgeGraph):
    """Render `kg` with graphviz and open the resulting image in a viewer."""
    graph = Digraph(comment="Knowledge Graph")
    # Nodes first, so edges can reference them by stringified id.
    for n in kg.nodes:
        graph.node(str(n.id), n.label, color=n.color)
    for e in kg.edges:
        graph.edge(
            str(e.source),
            str(e.target),
            label=e.label,
            color=e.color or "black",
        )
    # Writes `knowledge_graph.gv` plus the rendered output and opens it.
    graph.render("knowledge_graph.gv", view=True)
def create_knowledge_graph(data: str) -> Union[KnowledgeGraph, None]:
    """Parse a JSON string into a `KnowledgeGraph` instance.

    Args:
        data: JSON document with top-level "nodes" and "edges" lists.

    Returns:
        The parsed knowledge graph.
    """
    parsed: Dict[str, Any] = json.loads(data)
    nodes = [Node(**raw_node) for raw_node in parsed["nodes"]]
    edges = []
    for raw_edge in parsed["edges"]:
        # A missing or null color falls back to black before validation.
        if raw_edge.get("color") is None:
            raw_edge["color"] = "black"
        edges.append(Edge(**raw_edge))
    return KnowledgeGraph(nodes=nodes, edges=edges)
if __name__ == "__main__":
    import sys

    # argv[1] selects which generated graph to visualize.
    args = sys.argv[1:]
    from datasets import load_dataset

    # Each row's "generation" column holds a JSON-serialized knowledge graph.
    ds = load_dataset("distilabel-internal-testing/knowledge_graphs", split="train")
    graphs = [create_knowledge_graph(g) for g in ds["generation"]]
    visualize_knowledge_graph(graphs[int(args[0])])
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import wikipedia
from pydantic import BaseModel, Field
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
# Source document the exam questions will be generated from.
page = wikipedia.page(title="Transfer_learning")
class ExamQuestion(BaseModel):
    """A single multiple-choice exam question."""

    question: str = Field(..., description="The question to be answered")
    answer: str = Field(..., description="The correct answer to the question")
    distractors: List[str] = Field(
        ..., description="A list of incorrect but viable answers to the question"
    )
class ExamQuestions(BaseModel):
    """Container schema used to constrain the LLM output to a list of questions."""

    exam: List[ExamQuestion]
SYSTEM_PROMPT = """\
You are an exam writer specialized in writing exams for students.
Your goal is to create questions and answers based on the document provided, and a list of distractors, that are incorrect but viable answers to the question.
Your answer must adhere to the following format:
```
[
{
"question": "Your question",
"answer": "The correct answer to the question",
"distractors": ["wrong answer 1", "wrong answer 2", "wrong answer 3"]
},
... (more questions and answers as required)
]
```
""".strip()
with Pipeline(name="ExamGenerator") as pipeline:
    # Single-row dataset carrying the full Wikipedia page text.
    load_dataset = LoadDataFromDicts(
        name="load_instructions",
        data=[
            {
                "page": page.content,
            }
        ],
    )
    text_generation = TextGeneration(
        name="exam_generation",
        system_prompt=SYSTEM_PROMPT,
        # The "page" column is injected into the Jinja2 template per row.
        template="Generate a list of answers and questions about the document. Document:\n\n{{ page }}",
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
            # Constrain generation to JSON matching the `ExamQuestions` schema.
            structured_output={
                "schema": ExamQuestions.model_json_schema(),
                "format": "json",
            },
        ),
        input_batch_size=8,
        output_mappings={"model_name": "generation_model"},
    )
    load_dataset >> text_generation
if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={
            text_generation.name: {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": 2048,
                    }
                }
            }
        },
        use_cache=False,
    )
    # Replace USERNAME with an actual Hugging Face username before running.
    distiset.push_to_hub("USERNAME/exam_questions")
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal
from datasets import load_dataset
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import FormatTextGenerationSFT, LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
class SocialAI(TextGeneration):
    """`TextGeneration` task that replies to a post impersonating a follower persona.

    The follower behaves as one of three archetypes (supporter, troll or
    alarmist) whose traits are spliced into the system prompt at load time.
    """

    # Archetype selecting the personality traits injected below.
    follower_type: Literal["supporter", "troll", "alarmist"] = "supporter"
    # NOTE: the `{follower_type}`/`{traits}` placeholders are resolved once in
    # `load()` via `str.format`, while `{{ persona }}`/`{{ post }}` in
    # `template` are Jinja2 placeholders rendered per row.
    system_prompt: str = (
        "You are an AI assistant expert at simulating user interactions. "
        "You must answer as if you were a '{follower_type}', be concise answer with no more than 200 characters, nothing else."
        "Here are some traits to use for your personality:\n\n"
        "{traits}"
    )
    # NOTE(review): "folowing" is a typo, but it is part of the runtime prompt
    # string, so it is deliberately left untouched here.
    template: str = "You are the folowing persona:\n\n{{ persona }}\n\nWhat would you say to the following?\n\n {{ post }}"
    # Input columns this task consumes from each row.
    columns: str | list[str] = ["persona", "post"]
    # Trait descriptions per archetype (underscore prefix keeps it a private
    # attribute rather than a pydantic field).
    _follower_traits: dict[str, str] = {
        "supporter": (
            "- Encouraging and positive\n"
            "- Tends to prioritize enjoyment and relaxation\n"
            "- Focuses on the present moment and short-term pleasure\n"
            "- Often uses humor and playful language\n"
            "- Wants to help others feel good and have fun\n"
        ),
        "troll": (
            "- Provocative and confrontational\n"
            "- Enjoys stirring up controversy and conflict\n"
            "- Often uses sarcasm, irony, and mocking language\n"
            "- Tends to belittle or dismiss others' opinions and feelings\n"
            "- Seeks to get a rise out of others and create drama\n"
        ),
        "alarmist": (
            "- Anxious and warning-oriented\n"
            "- Focuses on potential risks and negative consequences\n"
            "- Often uses dramatic or sensational language\n"
            "- Tends to be serious and stern in tone\n"
            "- Seeks to alert others to potential dangers and protect them from harm (even if it's excessive or unwarranted)\n"
        ),
    }

    def load(self) -> None:
        """Resolve the system prompt for the configured `follower_type`."""
        super().load()
        self.system_prompt = self.system_prompt.format(
            follower_type=self.follower_type,
            traits=self._follower_traits[self.follower_type],
        )
# Posts the simulated followers will reply to.
posts = [
    {
        "post": "Hmm, ok now I'm torn: should I go for healthy chicken tacos or unhealthy beef tacos for late night cravings?"
    },
    {
        "post": "I need to develop a training course for my company on communication skills. Need to decide how deliver it remotely."
    },
    {
        "post": "I'm always 10 minutes late to meetups but no one's complained. Could this be annoying to them?"
    },
]
# Sample three random personas from FinePersonas.
personas = (
    load_dataset("argilla/FinePersonas-v0.1-clustering-100k", split="train")
    .shuffle()
    .select(range(3))
    .select_columns("persona")
    .to_list()
)
# Cross product: every persona replies to every post.
data = []
for post in posts:
    for persona in personas:
        data.append({"post": post["post"], "persona": persona["persona"]})
with Pipeline(name="Social AI Personas") as pipeline:
    loader = LoadDataFromDicts(data=data, batch_size=1)
    # A single shared LLM serves all three follower tasks.
    llm = InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 256,
        },
    )
    # One branch per archetype, each producing its own interaction column.
    for follower_type in ["supporter", "troll", "alarmist"]:
        follower = SocialAI(
            llm=llm,
            follower_type=follower_type,
            name=f"{follower_type}_user",
            output_mappings={"generation": f"interaction_{follower_type}"},
        )
        # Format each interaction as instruction/response SFT examples.
        format_sft = FormatTextGenerationSFT(
            name=f"format_sft_{follower_type}",
            input_mappings={
                "instruction": "post",
                "generation": f"interaction_{follower_type}",
            },
        )
        loader >> follower >> format_sft
if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
    distiset.push_to_hub("plaguss/FinePersonas-SocialAI-test", include_script=True)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset
from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.steps.tasks import ImageGeneration
# Small sample of personas used as image-generation prompts.
ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))
with Pipeline(name="image_generation_pipeline") as pipeline:
    igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell")
    img_generation = ImageGeneration(
        name="flux_schnell",
        image_generation_model=igm,
        # Use each row's "persona" text as the image prompt.
        input_mappings={"prompt": "persona"},
    )
    keep_columns = KeepColumns(columns=["persona", "model_name", "image"])
    img_generation >> keep_columns
if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False, dataset=ds)
    # Save the images as `PIL.Image.Image`
    distiset = distiset.transform_columns_to_image("image")
    distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell")
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
def final_velocity(initial_velocity: float, acceleration: float, time: float) -> float:
    """Calculates the final velocity of an object given its initial velocity, acceleration, and time.

    Args:
        initial_velocity: The initial velocity of the object.
        acceleration: The acceleration of the object.
        time: The time elapsed.

    Returns:
        The final velocity.
    """
    # Tool:
    # {"name": "final_velocity", "description": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.", "parameters": {"initial_velocity": {"description": "The initial velocity of the object.", "type": "float"}, "acceleration": {"description": "The acceleration of the object.", "type": "float"}, "time": {"description": "The time elapsed.", "type": "float"}}}
    # Answer:
    # {"name": "final_velocity", "arguments": {"initial_velocity": 5, "acceleration": 1.5, "time": 40}}
    # v = v0 + a*t (uniform acceleration). The result is a float, so the
    # return annotation is corrected from `int` to `float`.
    return initial_velocity + acceleration * time
def permutation_count(n: int, k: int) -> int:
"""Calculates the number of permutations of k elements from a set of n elements.
Args:
n: The total number of elements in the set.
k: The number of elements to choose for the permutation.
Returns:
The number of permutations.
"""
# Tool:
# {"name": "permutation_count", "description": "Calculates the number of permutations of k elements from a set of n elements.", "parameters": {"n": {"description": "The total number of elements in the set.", "type": "int"}, "k": {"description": "The number of elements to choose for the permutation.", "type": "int"}}}
# Answer:
# {"name": "permutation_count", "arguments": {"n": 10, "k": 3}}
import math
return math.factorial(n) / math.factorial(n - k)
def getdivision(dividend: int, divisor: int) -> float:
    """Divides two numbers by making an API call to a division service.

    Args:
        dividend: The dividend in the division operation.
        divisor: The divisor in the division operation.

    Returns:
        Division of the 2 numbers.
    """
    # Tool:
    # {"name": "getdivision", "description": "Divides two numbers by making an API call to a division service.", "parameters": {"divisor": {"description": "The divisor in the division operation.", "type": "int", "default": ""}, "dividend": {"description": "The dividend in the division operation.", "type": "int", "default": ""}}}
    # Answer:
    # {"name": "getdivision", "arguments": {"divisor": 25, "dividend": 100}}
    quotient = dividend / divisor
    return quotient
def binary_addition(a: str, b: str) -> str:
    """Adds two binary numbers and returns the result as a binary string.

    Args:
        a: The first binary number.
        b: The second binary number.

    Raises:
        ValueError: On invalid binary number.

    Returns:
        Binary string of the sum of the two numbers.
    """
    # Tool:
    # {"name": "binary_addition", "description": "Adds two binary numbers and returns the result as a binary string.", "parameters": {"a": {"description": "The first binary number.", "type": "str"}, "b": {"description": "The second binary number.", "type": "str"}}}
    # Answer:
    # {"name": "binary_addition", "arguments": {"a": "1010", "b": "1101"}}
    valid_digits = set("01")
    # Both operands must be made solely of '0'/'1' characters.
    if not (set(a) <= valid_digits and set(b) <= valid_digits):
        raise ValueError("Invalid binary number")
    total = int(a, 2) + int(b, 2)
    return format(total, "b")
def _make_request(url: str, params: Optional[Dict[str, Any]] = None):
    """Perform a GET request against `url` and return the decoded JSON body."""
    # Imported lazily so the module can be loaded without `requests` installed.
    import requests

    req = requests.get(url, params=params)
    return req.json()
def swapi_planet_resource(id: str) -> Dict[str, Any]:
    """get a specific planets resource

    Args:
        id: identifier of the planet

    Returns:
        Information about the planet.
    """
    # url = "https://swapi.dev/api/planets/1"
    # Fix: SWAPI addresses a single resource via the URL path (as the example
    # above shows), not a query parameter. The previous `params={"id": id}`
    # was ignored by the API, which returned the paginated planet list
    # instead of the requested planet.
    return _make_request(f"https://swapi.dev/api/planets/{id}")
def disney_character(name: str) -> Dict[str, Any]:
    """Find a specific character using this endpoint

    Args:
        name: Name of the character to look for.

    Returns:
        Information about the character.
    """
    # Example:
    # url = "https://api.disneyapi.dev/character"
    # params = {"name": "mulan"}
    return _make_request(r"https://api.disneyapi.dev/character", params={"name": name})
def get_lib():
    """Map each library function's name to its callable.

    The insertion order matches the original mapping so downstream
    consumers iterating the dict see the same ordering.
    """
    functions = (
        swapi_planet_resource,
        disney_character,
        final_velocity,
        permutation_count,
        getdivision,
        binary_addition,
    )
    return {func.__name__: func for func in functions}
def get_tools() -> Dict[str, Dict[str, Any]]:
    """Returns the tool representation of the functions in the library."""
    # TODO: Improve the `get_json_schema`, it fails on a lot of examples.
    # Imported lazily so this module stays importable without `transformers`;
    # schemas are derived from each function's signature and docstring.
    from transformers.utils import get_json_schema

    return {name: get_json_schema(func) for name, func in get_lib().items()}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, ExpandColumns
from distilabel.steps.tasks import (
FormatPRM,
MathShepherdCompleter,
MathShepherdGenerator,
)
# GSM8K math word problems; the pipeline expects an "instruction" column.
ds_name = "openai/gsm8k"
ds = (
    load_dataset(ds_name, "main", split="test")
    .rename_column("question", "instruction")
    .select(range(3))
)
with Pipeline(name="Math-Shepherd") as pipe:
    model_id_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    model_id_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # Fix: this LLM was constructed with `model_id_8B`, leaving `model_id_70B`
    # unused — the golden (reference) solutions are meant to come from the
    # stronger 70B model.
    llm_70B = InferenceEndpointsLLM(
        model_id=model_id_70B,
        tokenizer_id=model_id_70B,
        generation_kwargs={"max_new_tokens": 1024, "temperature": 0.5},
    )
    llm_8B = InferenceEndpointsLLM(
        model_id=model_id_8B,
        tokenizer_id=model_id_8B,
        generation_kwargs={"max_new_tokens": 2048, "temperature": 0.7},
    )
    # Generates the reference ("golden") solution for each problem.
    generator_golden = MathShepherdGenerator(
        name="golden_generator",
        llm=llm_70B,
    )
    # Generates M candidate solutions per problem with the weaker model.
    generator = MathShepherdGenerator(
        name="generator",
        llm=llm_8B,
        M=5,
    )
    # Labels each solution step by sampling N completions per step.
    completer = MathShepherdCompleter(name="completer", llm=llm_8B, N=4)
    combine = CombineOutputs()
    # One output row per candidate solution.
    expand = ExpandColumns(
        name="expand_columns",
        columns=["solutions"],
        split_statistics=True,
    )
    formatter = FormatPRM(name="format_prm")
    [generator_golden, generator] >> combine >> completer >> expand >> formatter
if __name__ == "__main__":
    distiset = pipe.run(use_cache=False, dataset=ds)
    distiset.push_to_hub("plaguss/test_math_shepherd_prm")
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from datasets import load_dataset
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, DataSampler, LoadDataFromDicts
from distilabel.steps.tasks import (
APIGenExecutionChecker,
APIGenGenerator,
APIGenSemanticChecker,
)
from distilabel.steps.tasks.apigen.utils import PrepareExamples, load_module_from_path
# Path to the module implementing the functions described in `data` below.
libpath = Path(__file__).parent / "lib_apigen.py"
# Seed rows: one per target function, with its name and description.
data = [
    {
        "func_name": "final_velocity",
        "func_desc": "Calculates the final velocity of an object given its initial velocity, acceleration, and time.",
    },
    {
        "func_name": "permutation_count",
        "func_desc": "Calculates the number of permutations of k elements from a set of n elements.",
    },
    {
        "func_name": "getdivision",
        "func_desc": "Divides two numbers by making an API call to a division service.",
    },
    {
        "func_name": "binary_addition",
        "func_desc": "Adds two binary numbers and returns the result as a binary string.",
    },
    {
        "func_name": "swapi_planet_resource",
        "func_desc": "get a specific planets resource",
    },
    {
        "func_name": "disney_character",
        "func_desc": "Find a specific character using this endpoint",
    },
]
libpath_module = load_module_from_path(libpath)
tools = libpath_module.get_tools()  # call get_tools()
# TODO: Add in the tools between 0 and 2 extra tools to make the task more challenging.
for row in data:
    # The tools should have a mix where both the correct and irrelevant tools are present.
    row.update({"tools": [tools[row["func_name"]]]})
# Reference function-calling rows used as few-shot examples by the sampler.
ds_og = (
    load_dataset("Salesforce/xlam-function-calling-60k", split="train")
    .shuffle(seed=42)
    .select(range(500))
    .to_list()
)
with Pipeline(name="APIGenPipeline") as pipeline:
    loader_seeds = LoadDataFromDicts(data=data)
    # Draw 2 few-shot examples per seed row from the xlam reference data.
    sampler = DataSampler(
        data=ds_og,
        size=2,
        samples=len(data),
        batch_size=8,
    )
    prep_examples = PrepareExamples()
    model_id = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    llm = InferenceEndpointsLLM(
        model_id=model_id,
        tokenizer_id=model_id,
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 2048,
        },
    )
    apigen = APIGenGenerator(
        llm=llm,
        use_default_structured_output=True,
    )
    combine_steps = CombineOutputs()
    # Executes the generated calls against `lib_apigen.py` to validate them.
    execution_checker = APIGenExecutionChecker(libpath=str(libpath))
    # The same LLM judges whether calls match the query's intent.
    semantic_checker = APIGenSemanticChecker(llm=llm)
    sampler >> prep_examples
    (
        [loader_seeds, prep_examples]
        >> combine_steps
        >> apigen
        >> execution_checker
        >> semantic_checker
    )
if __name__ == "__main__":
    distiset = pipeline.run()
    print(distiset["default"]["train"][0])
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from pydantic import BaseModel, Field
from distilabel.models import MistralLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
class Node(BaseModel):
    """A single node of the generated knowledge graph."""

    # Unique identifier referenced by `Edge.source` / `Edge.target`.
    id: int
    # Text displayed on the node.
    label: str
    # Color name used when drawing the node.
    color: str
class Edge(BaseModel):
    """A directed edge between two nodes of the knowledge graph."""

    # `Node.id` of the edge's origin.
    source: int
    # `Node.id` of the edge's destination.
    target: int
    # Text displayed on the edge.
    label: str
    # Color name; defaults to black when the LLM omits it.
    color: str = "black"
class KnowledgeGraph(BaseModel):
    """Schema used to constrain the LLM output to a nodes/edges graph."""

    # Fix: `Field(..., default_factory=list)` is invalid — pydantic rejects a
    # field that is both required (`...`) and has a default factory. Keeping
    # only the factory gives each instance its own empty list by default.
    nodes: List[Node] = Field(default_factory=list)
    edges: List[Edge] = Field(default_factory=list)
with Pipeline(
    name="Knowledge-Graphs",
    description=(
        "Generate knowledge graphs to answer questions, this type of dataset can be used to "
        "steer a model to answer questions with a knowledge graph."
    ),
) as pipeline:
    sample_questions = [
        "Teach me about quantum mechanics",
        "Who is who in The Simpsons family?",
        "Tell me about the evolution of programming languages",
    ]
    # One row per question, all sharing the same system prompt.
    load_dataset = LoadDataFromDicts(
        name="load_instructions",
        data=[
            {
                "system_prompt": "You are a knowledge graph expert generator. Help me understand by describing everything as a detailed knowledge graph.",
                "instruction": f"{question}",
            }
            for question in sample_questions
        ],
    )
    text_generation = TextGeneration(
        name="knowledge_graph_generation",
        llm=MistralLLM(
            # Constrain generation to the `KnowledgeGraph` pydantic schema.
            model="open-mixtral-8x22b", structured_output={"schema": KnowledgeGraph}
        ),
        input_batch_size=8,
        output_mappings={"model_name": "generation_model"},
    )
    load_dataset >> text_generation
if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={
            text_generation.name: {
                "llm": {"generation_kwargs": {"max_new_tokens": 2048}}
            }
        },
        use_cache=False,
    )
    distiset.push_to_hub("distilabel-internal-testing/knowledge_graphs")
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from pathlib import Path
from pydantic import BaseModel, StringConstraints, conint
from typing_extensions import Annotated
from distilabel.models import LlamaCppLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
class Weapon(str, Enum):
    """Closed set of weapons a character may carry."""

    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"
class Armor(str, Enum):
    """Closed set of armor types a character may wear."""

    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"
    mithril = "mithril"
class Character(BaseModel):
    """Schema the LLM's JSON output is constrained to."""

    # Name capped at 30 characters.
    name: Annotated[str, StringConstraints(max_length=30)]
    # Age restricted to the exclusive range (1, 3000).
    age: conint(gt=1, lt=3000)
    armor: Armor
    weapon: Weapon
# Download the model with
# curl -L -o ~/Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf
# Relative to the user's home directory (joined with `Path.home()` below).
model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"
with Pipeline("RPG-characters") as pipeline:
    system_prompt = (
        "You are a leading role play gamer. You have seen thousands of different characters and their attributes."
        " Please return a JSON object with common attributes of an RPG character."
    )
    # One row per fantasy race.
    load_dataset = LoadDataFromDicts(
        name="load_instructions",
        data=[
            {
                "system_prompt": system_prompt,
                "instruction": f"Give me a character description for a {char}",
            }
            for char in ["dwarf", "elf", "human", "ork"]
        ],
    )
    llm = LlamaCppLLM(
        model_path=str(Path.home() / model_path),  # type: ignore
        n_gpu_layers=-1,
        n_ctx=1024,
        # Constrain generation to JSON matching the `Character` schema.
        structured_output={"format": "json", "schema": Character},
    )
    # Change to vLLM as such:
    # llm = vLLM(
    #     model="teknium/OpenHermes-2.5-Mistral-7B",
    #     extra_kwargs={"tensor_parallel_size": 1},
    #     structured_output={"format": "json", "schema": Character},
    # )
    text_generation = TextGeneration(
        name="text_generation_rpg",
        llm=llm,
        input_batch_size=8,
        output_mappings={"model_name": "generation_model"},
    )
    load_dataset >> text_generation
if __name__ == "__main__":
    distiset = pipeline.run(
        parameters={
            text_generation.name: {
                "llm": {"generation_kwargs": {"max_new_tokens": 256}}
            }
        },
        use_cache=False,
    )
    # Print each generated character; sample output shown below.
    for num, character in enumerate(distiset["default"]["train"]["generation"]):
        print(f"Character: {num}")
        print(character)
# Character: 0
# {
#     "name": "Gimli",
#     "age": 42,
#     "armor": "plate",
#     "weapon": "axe" }
# Character: 1
# {"name":"Gaelen","age":600,"armor":"leather","weapon":"bow"}
# Character: 2
# {"name": "John Smith","age": 35,"armor": "leather","weapon": "sword"}
# Character: 3
# { "name": "Grug", "age": 35, "armor": "leather", "weapon": "axe"}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.models.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage
with Pipeline(name="vision_generation_pipeline") as pipeline:
    # Single row pairing an instruction with an image URL.
    loader = LoadDataFromDicts(
        data=[
            {
                "instruction": "What’s in this image?",
                "image": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
            }
        ],
    )
    llm = InferenceEndpointsLLM(
        model_id="meta-llama/Llama-3.2-11B-Vision-Instruct",
    )
    # The "image" column is passed as a URL (`image_type="url"`).
    vision = TextGenerationWithImage(name="vision_gen", llm=llm, image_type="url")
    loader >> vision
if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
    distiset.push_to_hub("plaguss/test-vision-generation-Llama-3.2-11B-Vision-Instruct")
# Project information
site_name: Distilabel Docs
site_url: https://argilla-io.github.io/distilabel
site_author: Argilla, Inc.
site_description: Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs.
# Repository
repo_name: argilla-io/distilabel
repo_url: https://github.com/argilla-io/distilabel
edit_uri: edit/main/docs/
extra:
version:
provider: mike
social:
- icon: fontawesome/brands/linkedin
link: https://www.linkedin.com/company/argilla-io
- icon: fontawesome/brands/x-twitter
link: https://twitter.com/argilla_io
- icon: fontawesome/brands/youtube
link: https://www.youtube.com/channel/UCAIz8TmvQQrLqbD7sd-5S2A
- icon: fontawesome/brands/discord
link: http://hf.co/join/discord
analytics:
provider: plausible
domain: distilabel.argilla.io
feedback:
title: Was this page helpful?
ratings:
- icon: material/thumb-up-outline
name: This page was helpful
data: 1
note: >-
Thanks for your feedback!
- icon: material/thumb-down-outline
name: This page could be improved
data: 0
note: >-
Thanks for your feedback! Help us improve this page by
<a href="https://github.com/argilla-io/distilabel/issues/new/?title=[Feedback]+{title}+-+{url}" target="_blank" rel="noopener">opening a GitHub issue</a>.
extra_css:
- stylesheets/extra.css
extra_javascript:
- javascripts/mathjax.js
- https://polyfill.io/v3/polyfill.min.js?features=es6
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
theme:
name: material
logo: assets/logo.svg
favicon: assets/logo.svg
icon:
repo: fontawesome/brands/github
features:
- navigation.instant
- navigation.sections
- navigation.tabs
- navigation.footer
- navigation.top
- navigation.tracking
- navigation.path
- header.autohide # header disappears as you scroll
- content.code.copy
- content.code.annotate
- content.tabs.link
- content.action.edit
- toc.follow
- search.suggest
- search.highlight
- search.share
palette:
- media: "(prefers-color-scheme)"
primary: white
toggle:
icon: material/brightness-auto
name: Switch to light mode
- media: "(prefers-color-scheme: light)"
scheme: default
primary: custom
toggle:
icon: material/brightness-7
name: Switch to dark mode
# Palette toggle for dark mode
- media: "(prefers-color-scheme: dark)"
scheme: slate
primary: custom
toggle:
icon: material/brightness-4
name: Switch to system preference
watch:
- src/distilabel
strict: true
# Extensions
markdown_extensions:
- attr_list
- md_in_html
- admonition
- pymdownx.superfences
- pymdownx.arithmatex:
generic: true
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.keys
- pymdownx.superfences:
custom_fences:
- name: mermaid
class: mermaid
format: !!python/name:pymdownx.superfences.fence_code_format
- pymdownx.snippets:
check_paths: true
base_path: [examples/, docs/, "."]
- pymdownx.details
- pymdownx.tabbed:
alternate_style: true
- pymdownx.emoji:
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
- footnotes
- toc:
permalink: true
plugins:
- search
- autorefs # Cross-links to headings
- gen-files:
scripts:
- docs/scripts/gen_popular_issues.py
- section-index
- mkdocstrings:
handlers:
python:
setup_commands:
- import sys; sys.path.insert(0, 'src') # API references are built from source
options:
show_inheritance_diagram: false
show_source: true # include source code
# Headings
heading_level: 3
show_root_heading: true # show the python path of the class
show_root_toc_entry: true # show the toc entry for the root class
show_root_full_path: false # display "distilabel.asdf" not just "asdf"
show_object_full_path: false # display "distilabel.asdf" not just "asdf"
show_symbol_type_heading: true
show_symbol_type_toc: true
# Members
inherited_members: false # allow looking up inherited methods
members_order: source # order methods according to their order of definition in the source code, not alphabetical order
show_labels: true
# Docstring
docstring_style: google # more info: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
show_if_no_docstring: false
# Signature
separate_signature: false
show_signature_annotations: false
- social
- mknotebooks
- material-plausible
- glightbox
- distilabel/components-gallery:
add_after_page: How-to guides
nav:
- Distilabel: "index.md"
- Getting started:
- Quickstart: "sections/getting_started/quickstart.md"
- Installation: "sections/getting_started/installation.md"
- FAQ: "sections/getting_started/faq.md"
- How-to guides:
- "sections/how_to_guides/index.md"
- Basic:
- Steps for processing data:
- "sections/how_to_guides/basic/step/index.md"
- GeneratorStep: "sections/how_to_guides/basic/step/generator_step.md"
- GlobalStep: "sections/how_to_guides/basic/step/global_step.md"
- Tasks for generating and judging with LLMs:
- "sections/how_to_guides/basic/task/index.md"
- GeneratorTask: "sections/how_to_guides/basic/task/generator_task.md"
- ImageTask: "sections/how_to_guides/basic/task/image_task.md"
- Executing Tasks with LLMs: "sections/how_to_guides/basic/llm/index.md"
- Execute Steps and Tasks in a Pipeline: "sections/how_to_guides/basic/pipeline/index.md"
- Advanced:
- The Distiset dataset object: "sections/how_to_guides/advanced/distiset.md"
- Pipeline cache: "sections/how_to_guides/advanced/caching.md"
- Exporting data to Argilla: "sections/how_to_guides/advanced/argilla.md"
- Structured data generation: "sections/how_to_guides/advanced/structured_generation.md"
- Offline Batch Generation: "sections/how_to_guides/advanced/offline_batch_generation.md"
- Specifying requirements for pipelines and steps: "sections/how_to_guides/advanced/pipeline_requirements.md"
- Load groups and execution stages: "sections/how_to_guides/advanced/load_groups_and_execution_stages.md"
- Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md"
- Using a file system to pass data of batches between steps: "sections/how_to_guides/advanced/fs_to_pass_data.md"
- Assigning resources to a step: "sections/how_to_guides/advanced/assigning_resources_to_step.md"
- Saving step generated artifacts: "sections/how_to_guides/advanced/saving_step_generated_artifacts.md"
- Serving an LLM for sharing it between several tasks: "sections/how_to_guides/advanced/serving_an_llm_for_reuse.md"
- Scaling and distributing a pipeline with Ray: "sections/how_to_guides/advanced/scaling_with_ray.md"
- Tutorials:
- "sections/pipeline_samples/index.md"
- Tutorials:
- Generate a preference dataset: "sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb"
- Clean an existing preference dataset: "sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb"
- Synthetic data generation for fine-tuning custom retrieval and reranking models: "sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb"
- Generate synthetic text classification data: "sections/pipeline_samples/tutorials/generate_textcat_dataset.ipynb"
- Papers:
- DeepSeek Prover: "sections/pipeline_samples/papers/deepseek_prover.md"
- DEITA: "sections/pipeline_samples/papers/deita.md"
- Instruction Backtranslation: "sections/pipeline_samples/papers/instruction_backtranslation.md"
- Prometheus 2: "sections/pipeline_samples/papers/prometheus.md"
- UltraFeedback: "sections/pipeline_samples/papers/ultrafeedback.md"
- APIGen: "sections/pipeline_samples/papers/apigen.md"
- CLAIR: "sections/pipeline_samples/papers/clair.md"
- Math Shepherd: "sections/pipeline_samples/papers/math_shepherd.md"
- Examples:
- Benchmarking with distilabel: "sections/pipeline_samples/examples/benchmarking_with_distilabel.md"
- Structured generation with outlines: "sections/pipeline_samples/examples/llama_cpp_with_outlines.md"
- Structured generation with instructor: "sections/pipeline_samples/examples/mistralai_with_instructor.md"
- Create a social network with FinePersonas: "sections/pipeline_samples/examples/fine_personas_social_network.md"
- Create questions and answers for an exam: "sections/pipeline_samples/examples/exam_questions.md"
- Image generation with distilabel: "sections/pipeline_samples/examples/image_generation.md"
- Text generation with images in distilabel: "sections/pipeline_samples/examples/text_generation_with_image.md"
- API Reference:
- Step:
- "api/step/index.md"
- GeneratorStep: "api/step/generator_step.md"
- GlobalStep: "api/step/global_step.md"
- "@step": "api/step/decorator.md"
- StepResources: "api/step/resources.md"
- Step Gallery:
- Argilla: "api/step_gallery/argilla.md"
- Hugging Face: "api/step_gallery/hugging_face.md"
- Columns: "api/step_gallery/columns.md"
- Extra: "api/step_gallery/extra.md"
- Task:
- "api/task/index.md"
- GeneratorTask: "api/task/generator_task.md"
- Task Gallery: "api/task/task_gallery.md"
- LLM:
- "api/models/llm/index.md"
- LLM Gallery: "api/models/llm/llm_gallery.md"
- Embedding:
- "api/models/embedding/index.md"
- Embedding Gallery: "api/models/embedding/embedding_gallery.md"
- ImageGenerationModels:
- "api/models/image_generation/index.md"
- Image Generation Gallery: "api/models/image_generation/image_generation_gallery.md"
- Pipeline:
- "api/pipeline/index.md"
- Routing Batch Function: "api/pipeline/routing_batch_function.md"
- Step Wrapper: "api/pipeline/step_wrapper.md"
- Mixins:
- RuntimeParametersMixin: "api/mixins/runtime_parameters.md"
- RequirementsMixin: "api/mixins/requirements.md"
- Exceptions: "api/exceptions.md"
- Errors: "api/errors.md"
- Distiset: "api/distiset.md"
- CLI: "api/cli.md"
- Types: "api/typing.md"
- Community:
- sections/community/index.md
- How to contribute?: sections/community/contributor.md
- Developer Documentation: sections/community/developer_documentation.md
- Issue dashboard: sections/community/popular_issues.md
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "distilabel"
description = "Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs."
readme = "README.md"
requires-python = ">=3.9"
license = "Apache-2.0"
keywords = ["llm", "annotation", "alignment", "synthetic", "data", "rlaif"]
authors = [{ name = "Argilla", email = "admin@argilla.io" }]
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
# Bump `datasets` to support `load_dataset` from cache
# Ref https://github.com/huggingface/datasets/releases/tag/2.16.0
"datasets >= 2.16.0",
"httpx >= 0.25.2",
"importlib-resources >= 6.1.1; python_version < '3.9'",
"Jinja2 >= 3.1.2",
"multiprocess >= 0.70",
"nest-asyncio >= 1.6.0",
"networkx >= 3.0",
"pydantic >= 2.0",
"rich >= 13.5.0",
"scipy >= 1.10.0",
"typer >= 0.9.0",
"tblib >= 3.0.0",
"orjson >= 3.10.0",
"universal_pathlib >= 0.2.2",
"portalocker >= 2.8.2",
"setuptools",
]
dynamic = ["version"]
[project.scripts]
distilabel = "distilabel.cli.app:app"
[project.entry-points."mkdocs.plugins"]
"distilabel/components-gallery" = "distilabel.utils.mkdocs.components_gallery:ComponentsGalleryPlugin"
[project.optional-dependencies]
dev = ["ruff == 0.8.1", "pre-commit >= 3.5.0"]
docs = [
"mkdocs-material >=9.5.17",
"mkdocstrings[python] >= 0.24.0",
"mkdocs-literate-nav >= 0.6.1",
"mkdocs-section-index >= 0.3.8",
"mkdocs-gen-files >= 0.5.0",
"mkdocs-glightbox >= 0.4.0",
"material-plausible-plugin>=0.2.0",
"mike >= 2.0.0",
"Pillow >= 9.5.0",
"CairoSVG >= 2.7.1",
"mknotebooks >= 0.8.0",
"pandas >= 2.0",
"tabulate>=0.9.0",
]
tests = [
"pytest >= 7.4.0",
"pytest-asyncio",
"nest-asyncio",
"pytest-timeout",
"pytest-codspeed",
]
# Optional LLMs, integrations, etc
anthropic = ["anthropic >= 0.20.0"]
argilla = ["argilla >= 2.0.0", "ipython"]
cohere = ["cohere >= 5.2.0"]
groq = ["groq >= 0.4.1"]
hf-inference-endpoints = ["huggingface_hub >= 0.22.0"]
hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
instructor = ["instructor >= 1.2.3"]
litellm = ["litellm >= 1.30.0"]
llama-cpp = ["llama-cpp-python >= 0.2.0"]
mistralai = ["mistralai >= 1.0.0"]
ollama = ["ollama >= 0.1.7"]
openai = ["openai >= 1.0.0"]
outlines = ["outlines >= 0.0.40", "numba >= 0.54.0"]
ray = ["ray[default] >= 2.31.0"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
vllm = [
"vllm >= 0.5.3",
"filelock >= 3.13.4"
]
sentence-transformers = ["sentence-transformers >= 3.0.0"]
faiss-cpu = ["faiss-cpu >= 1.8.0"]
faiss-gpu = ["faiss-gpu >= 1.7.2"]
text-clustering = [
"umap-learn >= 0.5.6",
"scikit-learn >= 1.4.1",
"matplotlib >= 3.8.3", # For the figure (even though it's optional)
]
mlx = ["mlx >= 0.21.0", "mlx-lm >= 0.21.0, < 0.22.0"]
vision = ["Pillow >= 10.3.0"] # To work with images.
# minhash
minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"]
[project.urls]
Documentation = "https://distilabel.argilla.io/"
Issues = "https://github.com/argilla/distilabel/issues"
Source = "https://github.com/argilla/distilabel"
[tool.hatch.version]
path = "src/distilabel/__init__.py"
[tool.ruff]
line-length = 88
exclude = ["docs"]
[tool.ruff.lint]
select = ["E", "W", "F", "I", "C", "B"]
ignore = ["E501", "B905", "B008"]
extend-select = ["RUF022"]
[tool.pytest.ini_options]
testpaths = ["tests"]
#!/bin/bash
# Build and install vLLM from source for CPU-only execution (no CUDA),
# targeting the latest upstream release tag.
set -e
echo "Updating system and installing build dependencies..."
# GCC 12 toolchain plus NUMA and oneDNN headers are needed to compile
# vLLM's native CPU kernels.
sudo apt-get update -y
sudo apt-get install -y gcc-12 g++-12 libnuma-dev cmake libdnnl-dev
# Make gcc-12/g++-12 the default gcc/g++ for the build.
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
echo "Python version:"
python --version
echo "Python executable location:"
which python
echo "Installing Python build dependencies..."
python -m pip install --upgrade pip
python -m pip install wheel packaging ninja "setuptools>=49.4.0" numpy setuptools-scm
echo "Cloning 'vllm-project/vllm' GitHub repository..."
git clone https://github.com/vllm-project/vllm.git
cd vllm || exit
git fetch --tags
# Resolve the most recent release tag so we build a stable version rather
# than the tip of the default branch.
latest_tag=$(git describe --tags "$(git rev-list --tags --max-count=1)")
echo "Checking out to '$latest_tag' tag..."
git checkout "$latest_tag"
echo "Installing vLLM CPU requirements..."
# CPU-only torch wheels are served from the dedicated PyTorch index.
python -m pip install -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
echo "Installing vLLM for CPU..."
# Point CMake at the currently active interpreter so the extension is built
# against the same Python that will import it.
export CMAKE_ARGS="-DPYTHON_EXECUTABLE=$(which python) -DPYTHON_INCLUDE_DIR=$(python -c "from sysconfig import get_path; print(get_path('include'))") -DPYTHON_LIBRARY=$(python -c "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))")"
echo "CMake args: $CMAKE_ARGS"
# VLLM_TARGET_DEVICE=cpu selects vLLM's CPU build path in setup.py.
VLLM_TARGET_DEVICE=cpu python setup.py install
echo "Installation complete!"
#!/bin/bash
# Install distilabel with (almost) all optional extras for running the test
# suite, plus a CPU build of vLLM.
set -e
# Python version as a "(major, minor)" tuple string, e.g. "(3, 12)";
# used below to gate the `ray` extra.
python_version=$(python -c "import sys; print(sys.version_info[:2])")
python -m pip install uv
uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash,text-clustering]"
# The `ray` extra is skipped on Python 3.12 — presumably not yet supported
# there at the time of writing; confirm before changing.
if [ "${python_version}" != "(3, 12)" ]; then
	uv pip install --system -e .[ray]
fi
# Build and install vLLM for CPU (see scripts/install_cpu_vllm.sh).
./scripts/install_cpu_vllm.sh
uv pip install --system -e ".[dev,tests]"
#!/bin/bash
# Install the documentation dependencies (mkdocs-material, mkdocstrings, etc.
# — see the `docs` extra in pyproject.toml) using uv.
set -e

# NOTE: a `python_version=$(...)` probe copied from the test-deps script was
# removed here — this script never branched on it.
python -m pip install uv
uv pip install --system -e ".[docs]"
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from rich import traceback as rich_traceback

# Single source of truth for the package version: hatch reads this attribute
# (see `[tool.hatch.version]` in pyproject.toml) to derive the distribution
# version.
__version__ = "1.5.3"

# Install rich's traceback handler globally so uncaught exceptions are
# rendered with syntax highlighting and local variable values.
rich_traceback.install(show_locals=True)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.cli.app import app

# Entry point for `python -m distilabel`; `prog_name` makes help/usage text
# show "distilabel" instead of "__main__.py".
if __name__ == "__main__":
    app(prog_name="distilabel")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment