Commit 4d4d8f59 authored by chenzk

v1.0
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import inspect
import json
import logging
import os
import sys
import time
from abc import ABC, abstractmethod
from functools import cached_property
from itertools import islice
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from distilabel.constants import SIGINT_HANDLER_CALLED_ENV_NAME
from distilabel.errors import DistilabelNotImplementedError, DistilabelUserError
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from distilabel.mixins.runtime_parameters import (
RuntimeParameter,
RuntimeParametersModelMixin,
)
from distilabel.utils.docstring import parse_google_docstring
from distilabel.utils.notebook import in_notebook
from distilabel.utils.serialization import _Serializable
if TYPE_CHECKING:
from logging import Logger
from distilabel.typing import (
FormattedInput,
GenerateOutput,
HiddenState,
InstructorStructuredOutputType,
StandardInput,
StructuredOutputType,
)
from distilabel.utils.docstring import Docstring
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
class LLM(RuntimeParametersModelMixin, BaseModel, _Serializable, ABC):
"""Base class for `LLM`s to be used in `distilabel` framework.
To implement an `LLM` subclass, you need to subclass this class and implement:
- `load` method to load the `LLM` if needed. Don't forget to call `super().load()`,
so the `_logger` attribute is initialized.
- `model_name` property to return the model name used for the LLM.
- `generate` method to generate `num_generations` per input in `inputs`.
Attributes:
generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate`
methods within each `LLM`.
use_offline_batch_generation: whether to use the `offline_batch_generate` method to
generate the responses.
offline_batch_generation_block_until_done: if provided, then polling will be done until
the `offline_batch_generate` method is able to retrieve the results. The value indicates
the time to wait between each polling.
jobs_ids: the job ids generated by the `offline_batch_generate` method. This attribute
is used to store the job ids generated by the `offline_batch_generate` method
so later they can be used to retrieve the results. It is not meant to be set by
the user.
_logger: the logger to be used for the `LLM`. It will be initialized when the `load`
method is called.
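Examples:
A minimal sketch of a subclass (the `EchoLLM` class below is purely illustrative and
not part of `distilabel`; it assumes `GenerateOutput` dicts with `generations` and
`statistics` keys):
```python
from typing import Any, List

class EchoLLM(LLM):
    @property
    def model_name(self) -> str:
        return "echo"

    def generate(
        self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any
    ) -> List["GenerateOutput"]:
        # Echo the last message of each conversation `num_generations` times.
        return [
            {
                "generations": [input[-1]["content"]] * num_generations,
                "statistics": {"input_tokens": [0], "output_tokens": [0]},
            }
            for input in inputs
        ]
```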
"""
model_config = ConfigDict(
arbitrary_types_allowed=True,
protected_namespaces=(),
validate_default=True,
validate_assignment=True,
extra="forbid",
)
generation_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
default_factory=dict,
description="The kwargs to be propagated to either `generate` or `agenerate`"
" methods within each `LLM`.",
)
use_offline_batch_generation: Optional[RuntimeParameter[bool]] = Field(
default=False,
description="Whether to use the `offline_batch_generate` method to generate"
" the responses.",
)
offline_batch_generation_block_until_done: Optional[RuntimeParameter[int]] = Field(
default=None,
description="If provided, then polling will be done until the `ofline_batch_generate`"
" method is able to retrieve the results. The value indicate the time to wait between"
" each polling.",
)
jobs_ids: Union[Tuple[str, ...], None] = Field(default=None)
_logger: "Logger" = PrivateAttr(None)
def load(self) -> None:
"""Method to be called to initialize the `LLM`, its logger and optionally the
structured output generator."""
self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}")
def unload(self) -> None:
"""Method to be called to unload the `LLM` and release any resources."""
pass
@property
@abstractmethod
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
pass
def get_generation_kwargs(self) -> Dict[str, Any]:
"""Returns the generation kwargs to be used for the generation. This method can
be overridden to provide a more complex logic for the generation kwargs.
Returns:
The kwargs to be used for the generation.
"""
return self.generation_kwargs # type: ignore
@abstractmethod
def generate(
self,
inputs: List["FormattedInput"],
num_generations: int = 1,
**kwargs: Any,
) -> List["GenerateOutput"]:
"""Abstract method to be implemented by each LLM to generate `num_generations`
per input in `inputs`.
Args:
inputs: the list of inputs to generate responses for, which follow OpenAI's
API format:
```python
[
{"role": "system", "content": "You're a helpful assistant..."},
{"role": "user", "content": "Give a template email for B2B communications..."},
{"role": "assistant", "content": "Sure, here's a template you can use..."},
{"role": "user", "content": "Modify the second paragraph..."}
]
```
num_generations: the number of generations to generate per input.
**kwargs: the additional kwargs to be used for the generation.
"""
pass
def generate_outputs(
self,
inputs: List["FormattedInput"],
num_generations: int = 1,
**kwargs: Any,
) -> List["GenerateOutput"]:
"""Generates outputs for the given inputs using either `generate` method or the
`offine_batch_generate` method if `use_offline_
"""
if self.use_offline_batch_generation:
if self.offline_batch_generation_block_until_done is not None:
return self._offline_batch_generate_polling(
inputs=inputs,
num_generations=num_generations,
**kwargs,
)
# This will raise `DistilabelOfflineBatchGenerationNotFinishedException` right away
# if the batch generation is not finished.
return self.offline_batch_generate(
inputs=inputs,
num_generations=num_generations,
**kwargs,
)
return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)
def _offline_batch_generate_polling(
self,
inputs: List["FormattedInput"],
num_generations: int = 1,
**kwargs: Any,
) -> List["GenerateOutput"]:
"""Method to poll the `offline_batch_generate` method until the batch generation
is finished.
Args:
inputs: the list of inputs to generate responses for.
num_generations: the number of generations to generate per input.
**kwargs: the additional kwargs to be used for the generation.
Returns:
A list containing the generations for each input.
"""
while True:
try:
return self.offline_batch_generate(
inputs=inputs,
num_generations=num_generations,
**kwargs,
)
except DistilabelOfflineBatchGenerationNotFinishedException as e:
self._logger.info(
f"Waiting for the offline batch generation to finish: {e}. Sleeping"
f" for {self.offline_batch_generation_block_until_done} seconds before"
" trying to get the results again."
)
# When running a `Step` in a child process, SIGINT is overridden so the child
# process doesn't stop when the parent process receives a SIGINT signal.
# The new handler sets an environment variable that is checked here to stop
# the polling.
if os.getenv(SIGINT_HANDLER_CALLED_ENV_NAME) is not None:
self._logger.info(
"Received a KeyboardInterrupt. Stopping polling for checking if the"
" offline batch generation is finished..."
)
raise e
time.sleep(self.offline_batch_generation_block_until_done) # type: ignore
except KeyboardInterrupt as e:
# This is for the case the `LLM` is being executed outside a pipeline
self._logger.info(
"Received a KeyboardInterrupt. Stopping polling for checking if the"
" offline batch generation is finished..."
)
raise DistilabelOfflineBatchGenerationNotFinishedException(
jobs_ids=self.jobs_ids # type: ignore
) from e
def get_last_hidden_states(
self, inputs: List["StandardInput"]
) -> List["HiddenState"]:
"""Method to get the last hidden states of the model for a list of inputs.
Args:
inputs: the list of inputs to get the last hidden states from.
Returns:
A list containing the last hidden state for each sequence using a NumPy array
with shape [num_tokens, hidden_size].
"""
# TODO: update to use `DistilabelNotImplementedError`
raise NotImplementedError(
f"Method `get_last_hidden_states` is not implemented for `{self.__class__.__name__}`"
)
def _prepare_structured_output(
self, structured_output: "StructuredOutputType"
) -> Union[Any, None]:
"""Method in charge of preparing the structured output generator.
By default, it will raise a `NotImplementedError`; subclasses that support structured
output must override this method with their implementation.
Args:
structured_output: the config to prepare the guided generation.
Returns:
The structure to be used for the guided generation.
"""
# TODO: update to use `DistilabelNotImplementedError`
raise NotImplementedError(
f"Guided generation is not implemented for `{type(self).__name__}`"
)
def offline_batch_generate(
self,
inputs: Union[List["FormattedInput"], None] = None,
num_generations: int = 1,
**kwargs: Any,
) -> List["GenerateOutput"]:
"""Method to generate a list of outputs for the given inputs using an offline batch
generation method to be implemented by each `LLM`.
This method should create the jobs the first time it is called and store the job ids, so
the second and subsequent calls can retrieve the results of the batch generation.
If subsequent calls are made before the batch generation is finished, then the method
should raise a `DistilabelOfflineBatchGenerationNotFinishedException`. This exception
will be handled automatically by the `Pipeline` which will store all the required
information for recovering the pipeline execution when the batch generation is finished.
Args:
inputs: the list of inputs to generate responses for.
num_generations: the number of generations to generate per input.
**kwargs: the additional kwargs to be used for the generation.
Returns:
A list containing the generations for each input.
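Examples:
A rough sketch of the control flow expected from a subclass (the `_create_jobs` and
`_retrieve_results` helpers below are hypothetical and not part of `distilabel`):
```python
def offline_batch_generate(self, inputs=None, num_generations=1, **kwargs):
    if self.jobs_ids is None:
        # First call: submit the batch jobs and store their ids.
        self.jobs_ids = self._create_jobs(inputs, num_generations, **kwargs)
        raise DistilabelOfflineBatchGenerationNotFinishedException(jobs_ids=self.jobs_ids)
    results = self._retrieve_results(self.jobs_ids)
    if results is None:
        # Jobs still running: let the `Pipeline` or the polling loop retry later.
        raise DistilabelOfflineBatchGenerationNotFinishedException(jobs_ids=self.jobs_ids)
    return results
```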
"""
raise DistilabelNotImplementedError(
f"`offline_batch_generate` is not implemented for `{self.__class__.__name__}`",
page="sections/how_to_guides/advanced/offline-batch-generation/",
)
class AsyncLLM(LLM):
"""Abstract class for asynchronous LLMs, so as to benefit from the async capabilities
of each LLM implementation. This class is meant to be subclassed by each LLM, and the
method `agenerate` needs to be implemented to provide the asynchronous generation of
responses.
Attributes:
_event_loop: the event loop to be used for the asynchronous generation of responses.
"""
_num_generations_param_supported = True
_event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)
_new_event_loop: bool = PrivateAttr(default=False)
@property
def generate_parameters(self) -> List[inspect.Parameter]:
"""Returns the parameters of the `agenerate` method.
Returns:
A list containing the parameters of the `agenerate` method.
"""
return list(inspect.signature(self.agenerate).parameters.values())
@cached_property
def generate_parsed_docstring(self) -> "Docstring":
"""Returns the parsed docstring of the `agenerate` method.
Returns:
The parsed docstring of the `agenerate` method.
"""
return parse_google_docstring(self.agenerate)
@property
def event_loop(self) -> "asyncio.AbstractEventLoop":
if self._event_loop is None:
try:
self._event_loop = asyncio.get_running_loop()
if self._event_loop.is_closed():
self._event_loop = asyncio.new_event_loop() # type: ignore
self._new_event_loop = True
except RuntimeError:
self._event_loop = asyncio.new_event_loop()
self._new_event_loop = True
asyncio.set_event_loop(self._event_loop)
return self._event_loop
@abstractmethod
async def agenerate(
self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any
) -> "GenerateOutput":
"""Method to generate a `num_generations` responses for a given input asynchronously,
and executed concurrently in `generate` method.
"""
pass
async def _agenerate(
self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any
) -> List["GenerateOutput"]:
"""Internal function to concurrently generate responses for a list of inputs.
Args:
inputs: the list of inputs to generate responses for.
num_generations: the number of generations to generate per input.
**kwargs: the additional kwargs to be used for the generation.
Returns:
A list containing the generations for each input.
"""
if self._num_generations_param_supported:
tasks = [
asyncio.create_task(
self.agenerate(
input=input, num_generations=num_generations, **kwargs
)
)
for input in inputs
]
result = await asyncio.gather(*tasks)
return result
tasks = [
asyncio.create_task(self.agenerate(input=input, **kwargs))
for input in inputs
for _ in range(num_generations)
]
outputs = await asyncio.gather(*tasks)
return merge_responses(outputs, n=num_generations)
def generate(
self,
inputs: List["FormattedInput"],
num_generations: int = 1,
**kwargs: Any,
) -> List["GenerateOutput"]:
"""Method to generate a list of responses asynchronously, returning the output
synchronously awaiting for the response of each input sent to `agenerate`.
Args:
inputs: the list of inputs to generate responses for.
num_generations: the number of generations to generate per input.
**kwargs: the additional kwargs to be used for the generation.
Returns:
A list containing the generations for each input.
"""
return self.event_loop.run_until_complete(
self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
)
def __del__(self) -> None:
"""Closes the event loop when the object is deleted."""
if sys.meta_path is None:
return
if self._new_event_loop:
if self._event_loop.is_running():
self._event_loop.stop()
self._event_loop.close()
@staticmethod
def _prepare_structured_output( # type: ignore
structured_output: "InstructorStructuredOutputType",
client: Any = None,
framework: Optional[str] = None,
) -> Dict[str, Union[str, Any]]:
"""Wraps the client and updates the schema to work store it internally as a json schema.
Args:
structured_output: The configuration dict to prepare the structured output.
client: The client to wrap to generate structured output. Implemented to work
with `instructor`.
framework: The name of the framework.
Returns:
A dictionary containing the wrapped client and the schema to update the structured_output
variable in case it is a pydantic model.
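Examples:
An illustrative sketch mirroring how `CohereLLM.load` uses this helper (the `User`
model and API key are hypothetical):
```python
from cohere import AsyncClient
from pydantic import BaseModel

class User(BaseModel):
    name: str

result = AsyncLLM._prepare_structured_output(
    structured_output={"schema": User},
    client=AsyncClient(api_key="api.key"),
    framework="cohere",
)
# `result["client"]` is the `instructor`-patched client and
# `result["structured_output"]["schema"]` now holds `User.model_json_schema()`.
```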
"""
from distilabel.steps.tasks.structured_outputs.instructor import (
prepare_instructor,
)
result = {}
client = prepare_instructor(
client,
mode=structured_output.get("mode"),
framework=framework, # type: ignore
)
result["client"] = client
schema = structured_output.get("schema")
if not schema:
raise DistilabelUserError(
f"The `structured_output` argument must contain a schema: {structured_output}",
page="sections/how_to_guides/advanced/structured_generation/#instructor",
)
if inspect.isclass(schema) and issubclass(schema, BaseModel):
# We want a json schema for the serialization, but instructor wants a pydantic BaseModel.
structured_output["schema"] = schema.model_json_schema() # type: ignore
result["structured_output"] = structured_output
return result
@staticmethod
def _prepare_kwargs(
arguments: Dict[str, Any], structured_output: Dict[str, Any]
) -> Dict[str, Any]:
"""Helper method to update the kwargs with the structured output configuration,
used in case they are defined.
Args:
arguments: the arguments that would be passed to the LLM as `**kwargs`, to be
updated with the structured output configuration.
structured_output: the structured output configuration to update the arguments with.
Returns:
kwargs updated with the special arguments used by `instructor`.
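Examples:
An illustrative before/after (the `User` model and arguments are hypothetical):
```python
from pydantic import BaseModel

class User(BaseModel):
    name: str

kwargs = {"messages": [{"role": "user", "content": "Create a user"}], "model": "some-model"}
kwargs = AsyncLLM._prepare_kwargs(kwargs, {"schema": User, "max_retries": 2})
# `kwargs` now also contains `response_model=User` and `max_retries=2`, the extra
# arguments expected by the `instructor`-patched client.
```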
"""
# We can deal with json schema or BaseModel, but we need to convert it to a BaseModel
# for the Instructor client.
schema = structured_output.get("schema", {})
# If there's already a pydantic model, we don't need to do anything,
# otherwise, try to obtain one.
if not (inspect.isclass(schema) and issubclass(schema, BaseModel)):
from distilabel.steps.tasks.structured_outputs.utils import (
json_schema_to_model,
)
if isinstance(schema, str):
# In case it was saved in the dataset as a string.
schema = json.loads(schema)
try:
schema = json_schema_to_model(schema)
except Exception as e:
raise ValueError(
f"Failed to convert the schema to a pydantic model, the model is too complex currently: {e}"
) from e
arguments.update(
**{
"response_model": schema,
"max_retries": structured_output.get("max_retries", 1),
},
)
return arguments
def merge_responses(
responses: List["GenerateOutput"], n: int = 1
) -> List["GenerateOutput"]:
"""Helper function to group the responses from `LLM.agenerate` method according
to the number of generations requested.
Args:
responses: the responses from the `LLM.agenerate` method.
n: number of responses to group together. Defaults to 1.
Returns:
List of merged responses, where each merged response contains n generations
and their corresponding statistics.
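Examples:
For instance, grouping two single-generation responses into one merged response with
`n=2` (values are illustrative):
```python
merge_responses(
    [
        {"generations": ["a"], "statistics": {"input_tokens": [1], "output_tokens": [1]}},
        {"generations": ["b"], "statistics": {"input_tokens": [2], "output_tokens": [1]}},
    ],
    n=2,
)
# [{"generations": ["a", "b"], "statistics": {"input_tokens": [1, 2], "output_tokens": [1, 1]}}]
```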
"""
if not responses:
return []
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield list(islice(lst, i, i + n))
extra_keys = [
key for key in responses[0].keys() if key not in ("generations", "statistics")
]
result = []
for group in chunks(responses, n):
merged = {
"generations": [],
"statistics": {"input_tokens": [], "output_tokens": []},
}
for response in group:
merged["generations"].append(response["generations"][0])
# Merge statistics
for key in response["statistics"]:
if key not in merged["statistics"]:
merged["statistics"][key] = []
merged["statistics"][key].append(response["statistics"][key][0])
# Merge extra keys returned by the `LLM`
for extra_key in extra_keys:
if extra_key not in merged:
merged[extra_key] = []
merged[extra_key].append(response[extra_key][0])
result.append(merged)
return result
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import (
TYPE_CHECKING,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import orjson
from pydantic import Field, PrivateAttr, SecretStr, validate_call
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.typing import (
FormattedInput,
GenerateOutput,
InstructorStructuredOutputType,
)
if TYPE_CHECKING:
from cohere import AsyncClient, ChatMessage, Message
from pydantic import BaseModel
from tokenizers import Tokenizer
from distilabel.typing import LLMStatistics
_COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY"
class CohereLLM(AsyncLLM):
"""Cohere API implementation using the async client for concurrent text generation.
Attributes:
model: the name of the model from the Cohere API to use for the generation.
base_url: the base URL to use for the Cohere API requests. Defaults to
`"https://api.cohere.ai/v1"`.
api_key: the API key to authenticate the requests to the Cohere API. Defaults to
the value of the `COHERE_API_KEY` environment variable.
timeout: the maximum time in seconds to wait for a response from the API. Defaults
to `120`.
client_name: the name of the client to use for the API requests. Defaults to
`"distilabel"`.
structured_output: a dictionary containing the structured output configuration
using `instructor`. You can take a look at the dictionary structure in
`InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
_ChatMessage: the `ChatMessage` class from the `cohere` package.
_aclient: the `AsyncClient` client from the `cohere` package.
Runtime parameters:
- `base_url`: the base URL to use for the Cohere API requests. Defaults to
`"https://api.cohere.ai/v1"`.
- `api_key`: the API key to authenticate the requests to the Cohere API. Defaults
to the value of the `COHERE_API_KEY` environment variable.
- `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
to `120`.
- `client_name`: the name of the client to use for the API requests. Defaults to
`"distilabel"`.
Examples:
Generate text:
```python
from distilabel.models.llms import CohereLLM
llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")
llm.load()
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
```python
from pydantic import BaseModel
from distilabel.models.llms import CohereLLM
class User(BaseModel):
name: str
last_name: str
id: int
llm = CohereLLM(
model="CohereForAI/c4ai-command-r-plus",
api_key="api.key",
structured_output={"schema": User}
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
model: str
base_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv(
"COHERE_BASE_URL", "https://api.cohere.ai/v1"
),
description="The base URL to use for the Cohere API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_COHERE_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the Cohere API.",
)
timeout: RuntimeParameter[int] = Field(
default=120,
description="The maximum time in seconds to wait for a response from the API.",
)
client_name: RuntimeParameter[str] = Field(
default="distilabel",
description="The name of the client to use for the API requests.",
)
structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
Field(
default=None,
description="The structured output format to use across all the generations.",
)
)
_num_generations_param_supported = False
_ChatMessage: Type["ChatMessage"] = PrivateAttr(...)
_aclient: "AsyncClient" = PrivateAttr(...)
_tokenizer: "Tokenizer" = PrivateAttr(...)
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model
def load(self) -> None:
"""Loads the `AsyncClient` client from the `cohere` package."""
super().load()
try:
from cohere import AsyncClient, ChatMessage
except ImportError as ie:
raise ImportError(
"The `cohere` package is required to use the `CohereLLM` class."
) from ie
self._ChatMessage = ChatMessage
self._aclient = AsyncClient(
api_key=self.api_key.get_secret_value(), # type: ignore
client_name=self.client_name,
base_url=self.base_url,
timeout=self.timeout,
)
if self.structured_output:
result = self._prepare_structured_output(
structured_output=self.structured_output,
client=self._aclient,
framework="cohere",
)
self._aclient = result.get("client") # type: ignore
if structured_output := result.get("structured_output"):
self.structured_output = structured_output
from cohere.manually_maintained.tokenizers import get_hf_tokenizer
self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model)
def _format_chat_to_cohere(
self, input: "FormattedInput"
) -> Tuple[Union[str, None], List["ChatMessage"], str]:
"""Formats the chat input to the Cohere Chat API conversational format.
Args:
input: The chat input to format.
Returns:
A tuple containing the system, chat history, and message.
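Examples:
An illustrative conversion (the conversation is hypothetical):
```python
system, chat_history, message = llm._format_chat_to_cohere(
    [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi!"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user", "content": "Tell me a joke."},
    ]
)
# system == "You are helpful."
# chat_history == [ChatMessage(role="USER", message="Hi!"),
#                  ChatMessage(role="CHATBOT", message="Hello!")]
# message == "Tell me a joke."
```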
"""
system = None
message = None
chat_history = []
for item in input:
role = item["role"]
content = item["content"]
if role == "system":
system = content
elif role == "user":
message = content
elif role == "assistant":
if message is None:
raise ValueError(
"An assistant message but be preceded by a user message."
)
chat_history.append(self._ChatMessage(role="USER", message=message)) # type: ignore
chat_history.append(self._ChatMessage(role="CHATBOT", message=content)) # type: ignore
message = None
if message is None:
raise ValueError("The chat input must end with a user message.")
return system, chat_history, message
@validate_call
async def agenerate( # type: ignore
self,
input: FormattedInput,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
k: Optional[int] = None,
p: Optional[float] = None,
seed: Optional[float] = None,
stop_sequences: Optional[Sequence[str]] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
raw_prompting: Optional[bool] = None,
) -> GenerateOutput:
"""Generates a response from the LLM given an input.
Args:
input: a single input in chat format to generate responses for.
temperature: the temperature to use for the generation. Defaults to `None`.
max_tokens: the maximum number of new tokens that the model will generate.
Defaults to `None`.
k: the number of highest probability vocabulary tokens to keep for the generation.
Defaults to `None`.
p: the nucleus sampling probability to use for the generation. Defaults to
`None`.
seed: the seed to use for the generation. Defaults to `None`.
stop_sequences: a list of sequences to use as stopping criteria for the generation.
Defaults to `None`.
frequency_penalty: the frequency penalty to use for the generation. Defaults
to `None`.
presence_penalty: the presence penalty to use for the generation. Defaults to
`None`.
raw_prompting: a flag to use raw prompting for the generation. Defaults to
`None`.
Returns:
The generated response from the Cohere API model.
"""
structured_output = None
if isinstance(input, tuple):
input, structured_output = input
result = self._prepare_structured_output(
structured_output=structured_output, # type: ignore
client=self._aclient,
framework="cohere",
)
self._aclient = result.get("client") # type: ignore
if structured_output is None and self.structured_output is not None:
structured_output = self.structured_output
system, chat_history, message = self._format_chat_to_cohere(input)
kwargs = {
"message": message,
"model": self.model,
"preamble": system,
"chat_history": chat_history,
"temperature": temperature,
"max_tokens": max_tokens,
"k": k,
"p": p,
"seed": seed,
"stop_sequences": stop_sequences,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"raw_prompting": raw_prompting,
}
if structured_output:
kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore
response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs) # type: ignore
if structured_output:
return prepare_output(
[response.model_dump_json()],
**self._get_llm_statistics(
input, orjson.dumps(response.model_dump_json()).decode("utf-8")
), # type: ignore
)
if (text := response.text) == "":
self._logger.warning( # type: ignore
f"Received no response using Cohere client (model: '{self.model}')."
f" Finish reason was: {response.finish_reason}"
)
return prepare_output(
[None],
**self._get_llm_statistics(input, ""),
)
return prepare_output(
[text],
**self._get_llm_statistics(input, text),
)
def _get_llm_statistics(
self, input: FormattedInput, output: str
) -> "LLMStatistics":
return {
"input_tokens": [compute_tokens(input, self._tokenizer.encode)],
"output_tokens": [compute_tokens(output, self._tokenizer.encode)],
}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Optional
from pydantic import Field, PrivateAttr, SecretStr, validate_call
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.utils import prepare_output
from distilabel.steps.base import RuntimeParameter
from distilabel.typing import (
FormattedInput,
GenerateOutput,
InstructorStructuredOutputType,
)
if TYPE_CHECKING:
from groq import AsyncGroq
from groq.types.chat.chat_completion import ChatCompletion
from distilabel.typing import LLMStatistics
_GROQ_API_BASE_URL_ENV_VAR_NAME = "GROQ_BASE_URL"
_GROQ_API_KEY_ENV_VAR_NAME = "GROQ_API_KEY"
class GroqLLM(AsyncLLM):
"""Groq API implementation using the async client for concurrent text generation.
Attributes:
model: the name of the model from the Groq API to use for the generation.
base_url: the base URL to use for the Groq API requests. Defaults to
`"https://api.groq.com"`.
api_key: the API key to authenticate the requests to the Groq API. Defaults to
the value of the `GROQ_API_KEY` environment variable.
max_retries: the maximum number of times to retry the request to the API before
failing. Defaults to `2`.
timeout: the maximum time in seconds to wait for a response from the API. Defaults
to `120`.
structured_output: a dictionary containing the structured output configuration
using `instructor`. You can take a look at the dictionary structure in
`InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
_api_key_env_var: the name of the environment variable to use for the API key.
_aclient: the `AsyncGroq` client from the `groq` package.
Runtime parameters:
- `base_url`: the base URL to use for the Groq API requests. Defaults to
`"https://api.groq.com"`.
- `api_key`: the API key to authenticate the requests to the Groq API. Defaults to
the value of the `GROQ_API_KEY` environment variable.
- `max_retries`: the maximum number of times to retry the request to the API before
failing. Defaults to `2`.
- `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
to `120`.
Examples:
Generate text:
```python
from distilabel.models.llms import GroqLLM
llm = GroqLLM(model="llama3-70b-8192")
llm.load()
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
```python
from pydantic import BaseModel
from distilabel.models.llms import GroqLLM
class User(BaseModel):
name: str
last_name: str
id: int
llm = GroqLLM(
model="llama3-70b-8192",
api_key="api.key",
structured_output={"schema": User}
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
model: str
base_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv(
_GROQ_API_BASE_URL_ENV_VAR_NAME, "https://api.groq.com"
),
description="The base URL to use for the Groq API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_GROQ_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the Groq API.",
)
max_retries: RuntimeParameter[int] = Field(
default=2,
description="The maximum number of times to retry the request to the API before"
" failing.",
)
timeout: RuntimeParameter[int] = Field(
default=120,
description="The maximum time in seconds to wait for a response from the API.",
)
structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
Field(
default=None,
description="The structured output format to use across all the generations.",
)
)
_num_generations_param_supported = False
_api_key_env_var: str = PrivateAttr(_GROQ_API_KEY_ENV_VAR_NAME)
_aclient: Optional["AsyncGroq"] = PrivateAttr(...)
def load(self) -> None:
"""Loads the `AsyncGroq` client to benefit from async requests."""
super().load()
try:
from groq import AsyncGroq
except ImportError as ie:
raise ImportError(
"Groq Python client is not installed. Please install it using"
' `pip install "distilabel[groq]"`.'
) from ie
if self.api_key is None:
raise ValueError(
f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
)
self._aclient = AsyncGroq(
base_url=self.base_url,
api_key=self.api_key.get_secret_value(),
max_retries=self.max_retries, # type: ignore
timeout=self.timeout,
)
if self.structured_output:
result = self._prepare_structured_output(
structured_output=self.structured_output,
client=self._aclient,
framework="groq",
)
self._aclient = result.get("client") # type: ignore
if structured_output := result.get("structured_output"):
self.structured_output = structured_output
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model
@validate_call
async def agenerate( # type: ignore
self,
input: FormattedInput,
seed: Optional[int] = None,
max_new_tokens: int = 128,
temperature: float = 1.0,
top_p: float = 1.0,
stop: Optional[str] = None,
) -> "GenerateOutput":
"""Generates `num_generations` responses for the given input using the Groq async
client.
Args:
input: a single input in chat format to generate responses for.
seed: the seed to use for the generation. Defaults to `None`.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
temperature: the temperature to use for the generation. Defaults to `1.0`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
stop: the stop sequence to use for the generation. Defaults to `None`.
Returns:
A `GenerateOutput` containing the generations for the given input along with their statistics.
References:
- https://console.groq.com/docs/text-chat
"""
structured_output = None
if isinstance(input, tuple):
input, structured_output = input
result = self._prepare_structured_output(
structured_output=structured_output,
client=self._aclient,
framework="groq",
)
self._aclient = result.get("client")
if structured_output is None and self.structured_output is not None:
structured_output = self.structured_output
kwargs = {
"messages": input, # type: ignore
"model": self.model,
"seed": seed,
"temperature": temperature,
"max_tokens": max_new_tokens,
"top_p": top_p,
"stream": False,
"stop": stop,
}
if structured_output:
kwargs = self._prepare_kwargs(kwargs, structured_output)
completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
if structured_output:
return prepare_output(
[completion.model_dump_json()],
**self._get_llm_statistics(completion._raw_response),
)
generations = []
for choice in completion.choices:
if (content := choice.message.content) is None:
self._logger.warning( # type: ignore
f"Received no response using the Groq client (model: '{self.model}')."
f" Finish reason was: {choice.finish_reason}"
)
generations.append(content)
return prepare_output(generations, **self._get_llm_statistics(completion))
@staticmethod
def _get_llm_statistics(completion: "ChatCompletion") -> "LLMStatistics":
return {
"input_tokens": [completion.usage.prompt_tokens if completion else 0],
"output_tokens": [completion.usage.completion_tokens if completion else 0],
}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.models.llms.huggingface.inference_endpoints import InferenceEndpointsLLM
from distilabel.models.llms.huggingface.transformers import TransformersLLM
__all__ = ["InferenceEndpointsLLM", "TransformersLLM"]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import sys
import warnings
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
from pydantic import (
Field,
PositiveInt,
ValidationError,
model_validator,
validate_call,
)
from pydantic._internal._model_construction import ModelMetaclass
from typing_extensions import Annotated
from distilabel.models.base_clients.inference_endpoints import (
InferenceEndpointsBaseClient,
)
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.typing import FormattedInput, GenerateOutput, Logprob, StandardInput
if TYPE_CHECKING:
from huggingface_hub.inference._generated.types.chat_completion import (
ChatCompletionOutput,
ChatCompletionOutputComplete,
)
from huggingface_hub.inference._generated.types.text_generation import (
TextGenerationOutput,
)
from distilabel.typing import Logprob
class InferenceEndpointsLLM(
InferenceEndpointsBaseClient, AsyncLLM, MagpieChatTemplateMixin
):
"""InferenceEndpoints LLM implementation running the async API client.
This LLM will internally use `huggingface_hub.AsyncInferenceClient`.
Attributes:
model_id: the model ID to use for the LLM as available in the Hugging Face Hub, which
will be used to resolve the base URL for the serverless Inference Endpoints API requests.
Defaults to `None`.
endpoint_name: the name of the Inference Endpoint to use for the LLM. Defaults to `None`.
endpoint_namespace: the namespace of the Inference Endpoint to use for the LLM. Defaults to `None`.
base_url: the base URL to use for the Inference Endpoints API requests.
api_key: the API key to authenticate the requests to the Inference Endpoints API.
tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub.
Defaults to `None`, but defining one is recommended to properly format the prompt.
model_display_name: the model display name to use for the LLM. Defaults to `None`.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.
structured_output: a dictionary containing the structured output configuration or
if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`.
Defaults to None.
Icon:
`:hugging:`
Examples:
Free serverless Inference API (set the `input_batch_size` of the `Task` that uses this LLM to avoid "Model is overloaded" errors):
```python
from distilabel.models.llms.huggingface import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Dedicated Inference Endpoints:
```python
from distilabel.models.llms.huggingface import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
endpoint_name="<ENDPOINT_NAME>",
api_key="<HF_API_KEY>",
endpoint_namespace="<USER|ORG>",
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Dedicated Inference Endpoints or TGI:
```python
from distilabel.models.llms.huggingface import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
api_key="<HF_API_KEY>",
base_url="<BASE_URL>",
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
```python
from pydantic import BaseModel
from distilabel.models.llms import InferenceEndpointsLLM
class User(BaseModel):
name: str
last_name: str
id: int
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
api_key="api.key",
structured_output={"format": "json", "schema": User.model_json_schema()}
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]])
```
"""
def load(self) -> None:
# Sets the logger and calls the load method of the BaseClient
self._num_generations_param_supported = False
AsyncLLM.load(self)
InferenceEndpointsBaseClient.load(self)
@model_validator(mode="after") # type: ignore
def only_one_of_model_id_endpoint_name_or_base_url_provided(
self,
) -> "InferenceEndpointsLLM":
"""Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
favour of the dynamically calculated one."""
if self.base_url and (self.model_id or self.endpoint_name):
self._logger.warning( # type: ignore
f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
" or `endpoint_name` is also provided, the `base_url` will either be ignored"
" or overwritten with the one generated from either of those args, for serverless"
" or dedicated inference endpoints, respectively."
)
if self.use_magpie_template and self.tokenizer_id is None:
raise ValueError(
"`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
" set a `tokenizer_id` and try again."
)
if (
self.model_id
and self.tokenizer_id is None
and self.structured_output is not None
):
self.tokenizer_id = self.model_id
if self.base_url and not (self.model_id or self.endpoint_name):
return self
if self.model_id and not self.endpoint_name:
return self
if self.endpoint_name and not self.model_id:
return self
raise ValidationError(
f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
)
def prepare_input(self, input: "StandardInput") -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.
Args:
input: the input list containing chat items.
Returns:
The prompt to send to the LLM.
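Examples:
Illustrative usage (the exact prompt string depends on the tokenizer's chat template,
so the comment below is only indicative):
```python
llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
)
llm.load()
prompt = llm.prepare_input([{"role": "user", "content": "Hello world!"}])
# `prompt` is the chat-templated string, ending with the generation prompt for the
# assistant turn.
```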
"""
prompt: str = (
self._tokenizer.apply_chat_template( # type: ignore
conversation=input, # type: ignore
tokenize=False,
add_generation_prompt=True,
)
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)
def _get_structured_output(
self, input: FormattedInput
) -> Tuple["StandardInput", Union[Dict[str, Any], None]]:
"""Gets the structured output (if any) for the given input.
Args:
input: a single input in chat format to generate responses for.
Returns:
The input and the structured output that will be passed as `grammar` to the
inference endpoint or `None` if not required.
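Examples:
An illustrative per-input tuple (the JSON schema is hypothetical):
```python
input, grammar = llm._get_structured_output(
    (
        [{"role": "user", "content": "Create a user"}],
        {"format": "json", "schema": {"type": "object", "properties": {"name": {"type": "string"}}}},
    )
)
# grammar == {"type": "json", "value": {"type": "object", "properties": {"name": {"type": "string"}}}}
```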
"""
structured_output = None
# Specific structured output per input
if isinstance(input, tuple):
input, structured_output = input
structured_output = {
"type": structured_output["format"], # type: ignore
"value": structured_output["schema"], # type: ignore
}
# Same structured output for all the inputs
if structured_output is None and self.structured_output is not None:
try:
structured_output = {
"type": self.structured_output["format"], # type: ignore
"value": self.structured_output["schema"], # type: ignore
}
except KeyError as e:
raise ValueError(
"To use the structured output you have to inform the `format` and `schema` in "
"the `structured_output` attribute."
) from e
if structured_output:
if isinstance(structured_output["value"], ModelMetaclass):
structured_output["value"] = structured_output[
"value"
].model_json_schema()
return input, structured_output
async def _generate_with_text_generation(
self,
input: str,
max_new_tokens: int = 128,
repetition_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
temperature: float = 1.0,
do_sample: bool = False,
top_n_tokens: Optional[int] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
typical_p: Optional[float] = None,
stop_sequences: Union[List[str], None] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
watermark: bool = False,
structured_output: Union[Dict[str, Any], None] = None,
) -> GenerateOutput:
generation: Union["TextGenerationOutput", None] = None
try:
generation = await self._aclient.text_generation( # type: ignore
prompt=input,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
typical_p=typical_p,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
temperature=temperature,
top_n_tokens=top_n_tokens,
top_p=top_p,
top_k=top_k,
stop_sequences=stop_sequences,
return_full_text=return_full_text,
# NOTE: here to ensure that the cache is not used and a different response is
# generated every time
seed=seed or random.randint(0, sys.maxsize),
watermark=watermark,
grammar=structured_output, # type: ignore
details=True,
)
except Exception as e:
self._logger.warning( # type: ignore
f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
f" Finish reason was: {e}"
)
return prepare_output(
generations=[generation.generated_text] if generation else [None],
input_tokens=[
compute_tokens(input, self._tokenizer.encode) if self._tokenizer else -1
],
output_tokens=[
generation.details.generated_tokens
if generation and generation.details
else 0
],
logprobs=self._get_logprobs_from_text_generation(generation)
if generation
else None, # type: ignore
)
def _get_logprobs_from_text_generation(
self, generation: "TextGenerationOutput"
) -> Union[List[List[List["Logprob"]]], None]:
if generation.details is None or generation.details.top_tokens is None:
return None
return [
[
[
{"token": top_logprob["text"], "logprob": top_logprob["logprob"]}
for top_logprob in token_logprobs
]
for token_logprobs in generation.details.top_tokens
]
]
async def _generate_with_chat_completion(
self,
input: "StandardInput",
max_new_tokens: int = 128,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: bool = False,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: float = 1.0,
tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
tool_prompt: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[PositiveInt] = None,
top_p: Optional[float] = None,
) -> GenerateOutput:
message = None
completion: Union["ChatCompletionOutput", None] = None
output_logprobs = None
try:
completion = await self._aclient.chat_completion( # type: ignore
messages=input, # type: ignore
max_tokens=max_new_tokens,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
presence_penalty=presence_penalty,
# NOTE: here to ensure that the cache is not used and a different response is
# generated every time
seed=seed or random.randint(0, sys.maxsize),
stop=stop_sequences,
temperature=temperature,
tool_choice=tool_choice, # type: ignore
tool_prompt=tool_prompt,
tools=tools, # type: ignore
top_logprobs=top_logprobs,
top_p=top_p,
)
choice = completion.choices[0] # type: ignore
if (message := choice.message.content) is None:
self._logger.warning( # type: ignore
f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
f" Finish reason was: {choice.finish_reason}"
)
if choice_logprobs := self._get_logprobs_from_choice(choice):
output_logprobs = [choice_logprobs]
except Exception as e:
self._logger.warning( # type: ignore
f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
f" Finish reason was: {e}"
)
return prepare_output(
generations=[message],
input_tokens=[completion.usage.prompt_tokens] if completion else None,
output_tokens=[completion.usage.completion_tokens] if completion else None,
logprobs=output_logprobs,
)
def _get_logprobs_from_choice(
self, choice: "ChatCompletionOutputComplete"
) -> Union[List[List["Logprob"]], None]:
if choice.logprobs is None:
return None
return [
[
{"token": top_logprob.token, "logprob": top_logprob.logprob}
for top_logprob in token_logprobs.top_logprobs
]
for token_logprobs in choice.logprobs.content
]
def _check_stop_sequences(
self,
stop_sequences: Optional[Union[str, List[str]]] = None,
) -> Union[List[str], None]:
"""Checks that no more than 4 stop sequences are provided.
Args:
stop_sequences: the stop sequences to be checked.
Returns:
The stop sequences.
"""
if stop_sequences is not None:
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
if len(stop_sequences) > 4:
warnings.warn(
"Only up to 4 stop sequences are allowed, so keeping the first 4 items only.",
UserWarning,
stacklevel=2,
)
stop_sequences = stop_sequences[:4]
return stop_sequences
@validate_call
async def agenerate( # type: ignore
self,
input: FormattedInput,
max_new_tokens: int = 128,
frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
logit_bias: Optional[List[float]] = None,
logprobs: bool = False,
presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
temperature: float = 1.0,
tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
tool_prompt: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[PositiveInt] = None,
top_n_tokens: Optional[PositiveInt] = None,
top_p: Optional[float] = None,
do_sample: bool = False,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
top_k: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
num_generations: int = 1,
) -> GenerateOutput:
"""Generates completions for the given input using the async client. This method
uses two methods of the `huggingface_hub.AsyncInferenceClient`: `chat_completion` and `text_generation`.
The `chat_completion` method will be used only if no `tokenizer_id` has been specified.
Some arguments of this function are specific to the `text_generation` method, while
some others are specific to the `chat_completion` method.
Args:
input: a single input in chat format to generate responses for.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize
new tokens based on their existing frequency in the text so far, decreasing
the model's likelihood to repeat the same line verbatim. Defaults to `None`.
logit_bias: modify the likelihood of specified tokens appearing in the completion.
This argument is exclusive to the `chat_completion` method and will be used
only if `tokenizer_id` is `None`.
Defaults to `None`.
logprobs: whether to return the log probabilities or not. This argument is exclusive
to the `chat_completion` method and will be used only if `tokenizer_id`
is `None`. Defaults to `False`.
presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
new tokens based on whether they appear in the text so far, increasing the
model's likelihood to talk about new topics. This argument is exclusive to
the `chat_completion` method and will be used only if `tokenizer_id` is
`None`. Defaults to `None`.
seed: the seed to use for the generation. Defaults to `None`.
stop_sequences: either a single string or a list of strings containing the sequences
to stop the generation at. Defaults to `None`, but will be set to the
`tokenizer.eos_token` if available.
temperature: the temperature to use for the generation. Defaults to `1.0`.
tool_choice: the name of the tool the model should call. It can be a dictionary
like `{"function_name": "my_tool"}` or "auto". If not provided, then the
model won't use any tool. This argument is exclusive to the `chat_completion`
method and will be used only if `tokenizer_id` is `None`. Defaults to `None`.
tool_prompt: A prompt to be appended before the tools. This argument is exclusive
to the `chat_completion` method and will be used only if `tokenizer_id`
is `None`. Defaults to `None`.
tools: a list of tools definitions that the LLM can use.
This argument is exclusive to the `chat_completion` method and will be used
only if `tokenizer_id` is `None`. Defaults to `None`.
top_logprobs: the number of top log probabilities to return per output token
generated. This argument is exclusive to the `chat_completion` method and
will be used only if `tokenizer_id` is `None`. Defaults to `None`.
top_n_tokens: the number of top log probabilities to return per output token
generated. This argument is exclusive of the `text_generation` method and
will be only used if `tokenizer_id` is not `None`. Defaults to `None`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
do_sample: whether to use sampling for the generation. This argument is exclusive
of the `text_generation` method and will be only used if `tokenizer_id` is not
`None`. Defaults to `False`.
repetition_penalty: the repetition penalty to use for the generation. This argument
is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
is not `None`. Defaults to `None`.
return_full_text: whether to return the full text of the completion or just
the generated text. Defaults to `False`, meaning that only the generated
text will be returned. This argument is exclusive of the `text_generation`
method and will be only used if `tokenizer_id` is not `None`.
top_k: the top-k value to use for the generation. This argument is exclusive
of the `text_generation` method and will be only used if `tokenizer_id`
is not `None`. Defaults to `None`.
typical_p: the typical-p value to use for the generation. This argument is exclusive
of the `text_generation` method and will be only used if `tokenizer_id`
is not `None`. Defaults to `None`.
watermark: whether to add the watermark to the generated text. This argument
is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
is not `None`. Defaults to `False`.
num_generations: the number of generations to generate. Defaults to `1`. It's here to ensure
the validation succeeds.
Returns:
A `GenerateOutput` containing the generations for the given input along with their statistics.
"""
stop_sequences = self._check_stop_sequences(stop_sequences)
if isinstance(input, str) or self.tokenizer_id is not None:
structured_output = None
if not isinstance(input, str):
input, structured_output = self._get_structured_output(input)
input = self.prepare_input(input)
return await self._generate_with_text_generation(
input=input,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
typical_p=typical_p,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
temperature=temperature,
top_n_tokens=top_n_tokens,
top_p=top_p,
top_k=top_k,
stop_sequences=stop_sequences,
return_full_text=return_full_text,
seed=seed,
watermark=watermark,
structured_output=structured_output,
)
return await self._generate_with_chat_completion(
input=input, # type: ignore
max_new_tokens=max_new_tokens,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
presence_penalty=presence_penalty,
seed=seed,
stop_sequences=stop_sequences,
temperature=temperature,
tool_choice=tool_choice,
tool_prompt=tool_prompt,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from pydantic import Field, PrivateAttr, SecretStr, validate_call
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.structured_outputs.outlines import (
_is_outlines_version_below_0_1_0,
)
from distilabel.typing import (
GenerateOutput,
OutlinesStructuredOutputType,
StandardInput,
)
from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR
if TYPE_CHECKING:
from transformers import Pipeline
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from distilabel.typing import HiddenState
class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
"""Hugging Face `transformers` library LLM implementation using the text generation
pipeline.
Attributes:
model: the model Hugging Face Hub repo id or a path to a directory containing the
model weights and configuration files.
revision: if `model` refers to a Hugging Face Hub repository, then the revision
(e.g. a branch name or a commit id) to use. Defaults to `"main"`.
torch_dtype: the torch dtype to use for the model e.g. "float16", "float32", etc.
Defaults to `"auto"`.
        trust_remote_code: whether to allow fetching and executing remote code from the
            repository in the Hub. Defaults to `False`.
model_kwargs: additional dictionary of keyword arguments that will be passed to
the `from_pretrained` method of the model.
tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing
the tokenizer config files. If not provided, the one associated to the `model`
will be used. Defaults to `None`.
use_fast: whether to use a fast tokenizer or not. Defaults to `True`.
chat_template: a chat template that will be used to build the prompts before
sending them to the model. If not provided, the chat template defined in the
tokenizer config will be used. If not provided and the tokenizer doesn't have
a chat template, then ChatML template will be used. Defaults to `None`.
device: the name or index of the device where the model will be loaded. Defaults
to `None`.
device_map: a dictionary mapping each layer of the model to a device, or a mode
like `"sequential"` or `"auto"`. Defaults to `None`.
token: the Hugging Face Hub token that will be used to authenticate to the Hugging
Face Hub. If not provided, the `HF_TOKEN` environment or `huggingface_hub` package
local configuration will be used. Defaults to `None`.
structured_output: a dictionary containing the structured output configuration or if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.
Icon:
`:hugging:`
Examples:
Generate text:
```python
from distilabel.models.llms import TransformersLLM
llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")
llm.load()
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
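        Generate structured data (a minimal illustrative sketch; assumes the `outlines` extra is
        installed so the `structured_output` attribute defined below can guide the generation):
        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import TransformersLLM
        class User(BaseModel):
            name: str
            last_name: str
            id: int
        llm = TransformersLLM(
            model="microsoft/Phi-3-mini-4k-instruct",
            structured_output={"format": "json", "schema": User},
        )
        llm.load()
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```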
"""
model: str
revision: str = "main"
torch_dtype: str = "auto"
trust_remote_code: bool = False
model_kwargs: Optional[Dict[str, Any]] = None
tokenizer: Optional[str] = None
use_fast: bool = True
chat_template: Optional[str] = None
device: Optional[Union[str, int]] = None
device_map: Optional[Union[str, Dict[str, Any]]] = None
token: Optional[SecretStr] = Field(
default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR)
)
structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
default=None,
description="The structured output format to use across all the generations.",
)
_pipeline: Optional["Pipeline"] = PrivateAttr(...)
_prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None)
_logits_processor: Union[Callable, None] = PrivateAttr(default=None)
def load(self) -> None:
"""Loads the model and tokenizer and creates the text generation pipeline. In addition,
it will configure the tokenizer chat template."""
if self.device == "cuda":
CudaDevicePlacementMixin.load(self)
try:
from transformers import pipeline
except ImportError as ie:
raise ImportError(
"Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
) from ie
token = self.token.get_secret_value() if self.token is not None else self.token
self._pipeline = pipeline(
"text-generation",
model=self.model,
revision=self.revision,
torch_dtype=self.torch_dtype,
trust_remote_code=self.trust_remote_code,
model_kwargs=self.model_kwargs or {},
tokenizer=self.tokenizer or self.model,
use_fast=self.use_fast,
device=self.device,
device_map=self.device_map,
token=token,
return_full_text=False,
)
if self.chat_template is not None:
self._pipeline.tokenizer.chat_template = self.chat_template # type: ignore
if self._pipeline.tokenizer.pad_token is None: # type: ignore
self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token # type: ignore
if self.structured_output:
processor = self._prepare_structured_output(self.structured_output)
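            # Depending on the installed `outlines` version, the guided-decoding hook is exposed
            # either as a prefix-allowed-tokens function (versions below 0.1.0) or as a logits
            # processor (0.1.0 and above).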
if _is_outlines_version_below_0_1_0():
self._prefix_allowed_tokens_fn = processor
else:
self._logits_processor = [processor]
super().load()
def unload(self) -> None:
"""Unloads the `vLLM` model."""
CudaDevicePlacementMixin.unload(self)
super().unload()
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model
def prepare_input(self, input: "StandardInput") -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.
Args:
input: the input list containing chat items.
Returns:
The prompt to send to the LLM.
"""
if self._pipeline.tokenizer.chat_template is None: # type: ignore
return input[0]["content"]
prompt: str = (
self._pipeline.tokenizer.apply_chat_template( # type: ignore
input, # type: ignore
tokenize=False,
add_generation_prompt=True,
)
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)
@validate_call
def generate( # type: ignore
self,
inputs: List[StandardInput],
num_generations: int = 1,
max_new_tokens: int = 128,
temperature: float = 0.1,
repetition_penalty: float = 1.1,
top_p: float = 1.0,
top_k: int = 0,
do_sample: bool = True,
) -> List[GenerateOutput]:
"""Generates `num_generations` responses for each input using the text generation
pipeline.
Args:
inputs: a list of inputs in chat format to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
temperature: the temperature to use for the generation. Defaults to `0.1`.
repetition_penalty: the repetition penalty to use for the generation. Defaults
to `1.1`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
top_k: the top-k value to use for the generation. Defaults to `0`.
do_sample: whether to use sampling or not. Defaults to `True`.
Returns:
A list of lists of strings containing the generated responses for each input.
"""
prepared_inputs = [self.prepare_input(input=input) for input in inputs]
outputs: List[List[Dict[str, str]]] = self._pipeline( # type: ignore
prepared_inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
repetition_penalty=repetition_penalty,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
num_return_sequences=num_generations,
prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
pad_token_id=self._pipeline.tokenizer.eos_token_id,
logits_processor=self._logits_processor,
)
llm_output = [
[generation["generated_text"] for generation in output]
for output in outputs
]
result = []
for input, output in zip(inputs, llm_output):
result.append(
prepare_output(
output,
input_tokens=[
compute_tokens(input, self._pipeline.tokenizer.encode)
],
output_tokens=[
compute_tokens(row, self._pipeline.tokenizer.encode)
for row in output
],
)
)
return result
def get_last_hidden_states(
self, inputs: List["StandardInput"]
) -> List["HiddenState"]:
"""Gets the last `hidden_states` of the model for the given inputs. It doesn't
execute the task head.
Args:
inputs: a list of inputs in chat format to generate the embeddings for.
Returns:
A list containing the last hidden state for each sequence using a NumPy array
with shape [num_tokens, hidden_size].
"""
model: "PreTrainedModel" = (
self._pipeline.model.model # type: ignore
if hasattr(self._pipeline.model, "model") # type: ignore
else next(self._pipeline.model.children()) # type: ignore
)
tokenizer: "PreTrainedTokenizer" = self._pipeline.tokenizer # type: ignore
input_ids = tokenizer(
[self.prepare_input(input) for input in inputs], # type: ignore
return_tensors="pt",
padding=True,
).to(model.device)
last_hidden_states = model(**input_ids)["last_hidden_state"]
return [
seq_last_hidden_state[attention_mask.bool(), :].detach().cpu().numpy()
for seq_last_hidden_state, attention_mask in zip(
last_hidden_states,
input_ids["attention_mask"], # type: ignore
)
]
def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union[Callable, List[Callable]]:
"""Creates the appropriate function to filter tokens to generate structured outputs.
Args:
structured_output: the configuration dict to prepare the structured output.
Returns:
The callable that will be used to guide the generation of the model.
"""
from distilabel.steps.tasks.structured_outputs.outlines import (
prepare_guided_output,
)
result = prepare_guided_output(
structured_output, "transformers", self._pipeline
)
if schema := result.get("schema"):
self.structured_output["schema"] = schema
return result["processor"]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import orjson
from pydantic import Field, PrivateAttr, validate_call
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.utils import prepare_output
from distilabel.typing import (
FormattedInput,
GenerateOutput,
InstructorStructuredOutputType,
)
if TYPE_CHECKING:
from litellm import Choices
from litellm.types.utils import ModelResponse
from pydantic import BaseModel
class LiteLLM(AsyncLLM):
"""LiteLLM implementation running the async API client.
Attributes:
model: the model name to use for the LLM e.g. "gpt-3.5-turbo" or "mistral/mistral-large",
etc.
verbose: whether to log the LiteLLM client's logs. Defaults to `False`.
        structured_output: a dictionary containing the structured output configuration
using `instructor`. You can take a look at the dictionary structure in
`InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
Runtime parameters:
- `verbose`: whether to log the LiteLLM client's logs. Defaults to `False`.
Examples:
Generate text:
```python
from distilabel.models.llms import LiteLLM
llm = LiteLLM(model="gpt-3.5-turbo")
llm.load()
# Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
        Generate structured data:
        ```python
from pydantic import BaseModel
from distilabel.models.llms import LiteLLM
class User(BaseModel):
name: str
last_name: str
id: int
llm = LiteLLM(
model="gpt-3.5-turbo",
api_key="api.key",
structured_output={"schema": User}
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
model: str
verbose: RuntimeParameter[bool] = Field(
default=False, description="Whether to log the LiteLLM client's logs."
)
structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
Field(
default=None,
description="The structured output format to use across all the generations.",
)
)
_aclient: Optional[Callable] = PrivateAttr(...)
def load(self) -> None:
"""
Loads the `acompletion` LiteLLM client to benefit from async requests.
"""
super().load()
try:
import litellm
litellm.telemetry = False
except ImportError as e:
raise ImportError(
"LiteLLM Python client is not installed. Please install it using"
" `pip install 'distilabel[litellm]'`."
) from e
self._aclient = litellm.acompletion
if not self.verbose:
litellm.suppress_debug_info = True
for key in logging.Logger.manager.loggerDict.keys():
if "litellm" not in key.lower():
continue
logging.getLogger(key).setLevel(logging.CRITICAL)
if self.structured_output:
result = self._prepare_structured_output(
structured_output=self.structured_output,
client=self._aclient,
framework="litellm",
)
self._aclient = result.get("client").messages.create
if structured_output := result.get("structured_output"):
self.structured_output = structured_output
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model
@validate_call
async def agenerate( # type: ignore # noqa: C901
self,
input: FormattedInput,
num_generations: int = 1,
functions: Optional[List] = None,
function_call: Optional[str] = None,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 1.0,
stop: Optional[Union[str, list]] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None,
user: Optional[str] = None,
metadata: Optional[dict] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
model_list: Optional[list] = None,
mock_response: Optional[str] = None,
force_timeout: Optional[int] = 600,
custom_llm_provider: Optional[str] = None,
) -> GenerateOutput:
"""Generates `num_generations` responses for the given input using the [LiteLLM async client](https://github.com/BerriAI/litellm).
Args:
input: a single input in chat format to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
functions: a list of functions to apply to the conversation messages. Defaults to
`None`.
function_call: the name of the function to call within the conversation. Defaults
to `None`.
temperature: the temperature to use for the generation. Defaults to `1.0`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
stop: Up to 4 sequences where the LLM API will stop generating further tokens.
Defaults to `None`.
max_tokens: The maximum number of tokens in the generated completion. Defaults to
`None`.
presence_penalty: It is used to penalize new tokens based on their existence in the
text so far. Defaults to `None`.
frequency_penalty: It is used to penalize new tokens based on their frequency in the
text so far. Defaults to `None`.
logit_bias: Used to modify the probability of specific tokens appearing in the
completion. Defaults to `None`.
user: A unique identifier representing your end-user. This can help the LLM provider
to monitor and detect abuse. Defaults to `None`.
metadata: Pass in additional metadata to tag your completion calls - eg. prompt
version, details, etc. Defaults to `None`.
api_base: Base URL for the API. Defaults to `None`.
api_version: API version. Defaults to `None`.
api_key: API key. Defaults to `None`.
model_list: List of api base, version, keys. Defaults to `None`.
mock_response: If provided, return a mock completion response for testing or debugging
purposes. Defaults to `None`.
force_timeout: The maximum execution time in seconds for the completion request.
Defaults to `600`.
            custom_llm_provider: used for non-OpenAI LLMs. For example, for Bedrock set
                model="amazon.titan-tg1-large" and custom_llm_provider="bedrock". Defaults to
                `None`.
Returns:
A list of lists of strings containing the generated responses for each input.
"""
import litellm
from litellm import token_counter
structured_output = None
if isinstance(input, tuple):
input, structured_output = input
result = self._prepare_structured_output(
structured_output=structured_output,
client=self._aclient,
framework="litellm",
)
self._aclient = result.get("client").messages.create
if structured_output is None and self.structured_output is not None:
structured_output = self.structured_output
kwargs = {
"model": self.model,
"messages": input,
"n": num_generations,
"functions": functions,
"function_call": function_call,
"temperature": temperature,
"top_p": top_p,
"stream": False,
"stop": stop,
"max_tokens": max_tokens,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"user": user,
"metadata": metadata,
"api_base": api_base,
"api_version": api_version,
"api_key": api_key,
"model_list": model_list,
"mock_response": mock_response,
"force_timeout": force_timeout,
"custom_llm_provider": custom_llm_provider,
}
if structured_output:
kwargs = self._prepare_kwargs(kwargs, structured_output)
async def _call_aclient_until_n_choices() -> List["Choices"]:
choices = []
while len(choices) < num_generations:
completion: Union["ModelResponse", "BaseModel"] = await self._aclient(
**kwargs
) # type: ignore
if self.structured_output:
# Prevent pydantic model from being cast to list during list extension
completion = [completion]
else:
completion = completion.choices
choices.extend(completion)
return choices
# litellm.drop_params is used to en/disable sending **kwargs parameters to the API if they cannot be used
try:
litellm.drop_params = False
choices = await _call_aclient_until_n_choices()
except litellm.exceptions.APIError as e:
if "does not support parameters" in str(e):
litellm.drop_params = True
choices = await _call_aclient_until_n_choices()
else:
raise e
generations = []
input_tokens = [
token_counter(model=self.model, messages=input)
] * num_generations
output_tokens = []
if self.structured_output:
for choice in choices:
generations.append(choice.model_dump_json())
output_tokens.append(
token_counter(
model=self.model,
text=orjson.dumps(choice.model_dump_json()).decode("utf-8"),
)
)
return prepare_output(
generations,
input_tokens=input_tokens,
output_tokens=output_tokens,
)
for choice in choices:
if (content := choice.message.content) is None:
self._logger.warning( # type: ignore
f"Received no response using LiteLLM client (model: '{self.model}')."
f" Finish reason was: {choice.finish_reason}"
)
generations.append(content)
output_tokens.append(token_counter(model=self.model, text=content))
return prepare_output(
generations, input_tokens=input_tokens, output_tokens=output_tokens
)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from pydantic import Field, FilePath, PrivateAttr, model_validator, validate_call
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.utils import prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.typing import (
FormattedInput,
GenerateOutput,
OutlinesStructuredOutputType,
)
if TYPE_CHECKING:
from llama_cpp import (
CreateChatCompletionResponse,
Llama,
LogitsProcessor,
LogitsProcessorList,
)
from distilabel.typing import FormattedInput, StandardInput
class LlamaCppLLM(LLM, MagpieChatTemplateMixin):
"""llama.cpp LLM implementation running the Python bindings for the C++ code.
Attributes:
model_path: contains the path to the GGUF quantized model, compatible with the
installed version of the `llama.cpp` Python bindings.
n_gpu_layers: the number of layers to use for the GPU. Defaults to `-1`, meaning that
the available GPU device will be used.
chat_format: the chat format to use for the model. Defaults to `None`, which means the
Llama format will be used.
n_ctx: the context size to use for the model. Defaults to `512`.
n_batch: the prompt processing maximum batch size to use for the model. Defaults to `512`.
seed: random seed to use for the generation. Defaults to `4294967295`.
verbose: whether to print verbose output. Defaults to `False`.
structured_output: a dictionary containing the structured output configuration or if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
extra_kwargs: additional dictionary of keyword arguments that will be passed to the
`Llama` class of `llama_cpp` library. Defaults to `{}`.
tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing
the tokenizer config files. If not provided, the one associated to the `model`
will be used. Defaults to `None`.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.
_model: the Llama model instance. This attribute is meant to be used internally and
should not be accessed directly. It will be set in the `load` method.
Runtime parameters:
- `model_path`: the path to the GGUF quantized model.
- `n_gpu_layers`: the number of layers to use for the GPU. Defaults to `-1`.
- `chat_format`: the chat format to use for the model. Defaults to `None`.
- `verbose`: whether to print verbose output. Defaults to `False`.
- `extra_kwargs`: additional dictionary of keyword arguments that will be passed to the
`Llama` class of `llama_cpp` library. Defaults to `{}`.
References:
- [`llama.cpp`](https://github.com/ggerganov/llama.cpp)
- [`llama-cpp-python`](https://github.com/abetlen/llama-cpp-python)
Examples:
Generate text:
```python
from pathlib import Path
from distilabel.models.llms import LlamaCppLLM
# You can follow along this example downloading the following model running the following
# command in the terminal, that will download the model to the `Downloads` folder:
# curl -L -o ~/Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf
model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"
llm = LlamaCppLLM(
model_path=str(Path.home() / model_path),
n_gpu_layers=-1, # To use the GPU if available
n_ctx=1024, # Set the context size
)
llm.load()
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
```python
        from pathlib import Path
        from pydantic import BaseModel
        from distilabel.models.llms import LlamaCppLLM
model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"
class User(BaseModel):
name: str
last_name: str
id: int
llm = LlamaCppLLM(
model_path=str(Path.home() / model_path), # type: ignore
n_gpu_layers=-1,
n_ctx=1024,
structured_output={"format": "json", "schema": Character},
)
llm.load()
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
model_path: RuntimeParameter[FilePath] = Field(
default=None, description="The path to the GGUF quantized model.", exclude=True
)
n_gpu_layers: RuntimeParameter[int] = Field(
default=-1,
description="The number of layers that will be loaded in the GPU.",
)
chat_format: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The chat format to use for the model. Defaults to `None`, which means the Llama format will be used.",
)
n_ctx: int = 512
n_batch: int = 512
seed: int = 4294967295
verbose: RuntimeParameter[bool] = Field(
default=False,
description="Whether to print verbose output from llama.cpp library.",
)
extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
default_factory=dict,
description="Additional dictionary of keyword arguments that will be passed to the"
" `Llama` class of `llama_cpp` library. See all the supported arguments at: "
"https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__",
)
structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
default=None,
description="The structured output format to use across all the generations.",
)
tokenizer_id: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The Hugging Face Hub repo id or a path to a directory containing"
" the tokenizer config files. If not provided, the one associated to the `model`"
" will be used.",
)
_logits_processor: Optional["LogitsProcessorList"] = PrivateAttr(default=None)
_model: Optional["Llama"] = PrivateAttr(...)
@model_validator(mode="after")
def validate_magpie_usage(
self,
) -> "LlamaCppLLM":
"""Validates that magpie usage is valid."""
if self.use_magpie_template and self.tokenizer_id is None:
raise ValueError(
"`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
" set a `tokenizer_id` and try again."
)
def load(self) -> None:
"""Loads the `Llama` model from the `model_path`."""
try:
from llama_cpp import Llama
except ImportError as ie:
raise ImportError(
"The `llama_cpp` package is required to use the `LlamaCppLLM` class."
) from ie
self._model = Llama(
model_path=self.model_path.as_posix(),
seed=self.seed,
n_ctx=self.n_ctx,
n_batch=self.n_batch,
chat_format=self.chat_format,
n_gpu_layers=self.n_gpu_layers,
verbose=self.verbose,
**self.extra_kwargs,
)
if self.structured_output:
self._logits_processor = self._prepare_structured_output(
self.structured_output
)
if self.use_magpie_template or self.magpie_pre_query_template:
if not self.tokenizer_id:
raise ValueError(
"The Hugging Face Hub repo id or a path to a directory containing"
" the tokenizer config files is required when using the `use_magpie_template`"
" or `magpie_pre_query_template` runtime parameters."
)
if self.tokenizer_id:
try:
from transformers import AutoTokenizer
except ImportError as ie:
raise ImportError(
"Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
) from ie
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
if self._tokenizer.chat_template is None:
raise ValueError(
"The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
)
# NOTE: Here because of the custom `logging` interface used, since it will create the logging name
# out of the model name, which won't be available until the `Llama` instance is created.
super().load()
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self._model.model_path # type: ignore
def _generate_chat_completion(
self,
input: FormattedInput,
max_new_tokens: int = 128,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
extra_generation_kwargs: Optional[Dict[str, Any]] = None,
) -> "CreateChatCompletionResponse":
return self._model.create_chat_completion( # type: ignore
messages=input, # type: ignore
max_tokens=max_new_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
logits_processor=self._logits_processor,
**(extra_generation_kwargs or {}),
)
def prepare_input(self, input: "StandardInput") -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.
Args:
input: the input list containing chat items.
Returns:
The prompt to send to the LLM.
"""
prompt: str = (
self._tokenizer.apply_chat_template( # type: ignore
conversation=input, # type: ignore
tokenize=False,
add_generation_prompt=True,
)
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)
def _generate_with_text_generation(
self,
input: FormattedInput,
max_new_tokens: int = 128,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
extra_generation_kwargs: Optional[Dict[str, Any]] = None,
) -> "CreateChatCompletionResponse":
prompt = self.prepare_input(input)
return self._model.create_completion(
prompt=prompt,
max_tokens=max_new_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
logits_processor=self._logits_processor,
**(extra_generation_kwargs or {}),
)
@validate_call
def generate( # type: ignore
self,
inputs: List[FormattedInput],
num_generations: int = 1,
max_new_tokens: int = 128,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
extra_generation_kwargs: Optional[Dict[str, Any]] = None,
) -> List[GenerateOutput]:
"""Generates `num_generations` responses for the given input using the Llama model.
Args:
inputs: a list of inputs in chat format to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `0.0`.
presence_penalty: the presence penalty to use for the generation. Defaults to
`0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
extra_generation_kwargs: dictionary with additional arguments to be passed to
the `create_chat_completion` method. Reference at
https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion
Returns:
A list of lists of strings containing the generated responses for each input.
"""
structured_output = None
batch_outputs = []
for input in inputs:
if isinstance(input, tuple):
input, structured_output = input
elif self.structured_output:
structured_output = self.structured_output
outputs = []
output_tokens = []
for _ in range(num_generations):
# NOTE(plaguss): There seems to be a bug in how the logits processor
# is used. Basically it consumes the FSM internally, and it isn't reinitialized
# after each generation, so subsequent calls yield nothing. This is a workaround
# until is fixed in the `llama_cpp` or `outlines` libraries.
if structured_output:
self._logits_processor = self._prepare_structured_output(
structured_output
)
if self.tokenizer_id is None:
completion = self._generate_chat_completion(
input,
max_new_tokens,
frequency_penalty,
presence_penalty,
temperature,
top_p,
extra_generation_kwargs,
)
outputs.append(completion["choices"][0]["message"]["content"])
output_tokens.append(completion["usage"]["completion_tokens"])
else:
completion: "CreateChatCompletionResponse" = (
self._generate_with_text_generation( # type: ignore
input,
max_new_tokens,
frequency_penalty,
presence_penalty,
temperature,
top_p,
extra_generation_kwargs,
)
)
outputs.append(completion["choices"][0]["text"])
output_tokens.append(completion["usage"]["completion_tokens"])
batch_outputs.append(
prepare_output(
outputs,
input_tokens=[completion["usage"]["prompt_tokens"]]
* num_generations,
output_tokens=output_tokens,
)
)
return batch_outputs
def _prepare_structured_output(
self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union["LogitsProcessorList", "LogitsProcessor"]:
"""Creates the appropriate function to filter tokens to generate structured outputs.
Args:
structured_output: the configuration dict to prepare the structured output.
Returns:
The callable that will be used to guide the generation of the model.
"""
from distilabel.steps.tasks.structured_outputs.outlines import (
prepare_guided_output,
)
result = prepare_guided_output(structured_output, "llamacpp", self._model)
if (schema := result.get("schema")) and self.structured_output:
self.structured_output["schema"] = schema
return [result["processor"]]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Optional
from pydantic import Field, PrivateAttr, SecretStr, validate_call
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.utils import prepare_output
from distilabel.typing import (
FormattedInput,
GenerateOutput,
InstructorStructuredOutputType,
)
if TYPE_CHECKING:
from mistralai import Mistral
from mistralai.models.chatcompletionresponse import ChatCompletionResponse
from distilabel.typing import LLMStatistics
_MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY"
class MistralLLM(AsyncLLM):
"""Mistral LLM implementation running the async API client.
Attributes:
model: the model name to use for the LLM e.g. "mistral-tiny", "mistral-large", etc.
endpoint: the endpoint to use for the Mistral API. Defaults to "https://api.mistral.ai".
api_key: the API key to authenticate the requests to the Mistral API. Defaults to `None` which
            means that the value set for the environment variable `MISTRAL_API_KEY` will be used, or
`None` if not set.
        max_retries: the maximum number of retries to attempt when a request fails. Defaults to `6`.
timeout: the maximum time in seconds to wait for a response. Defaults to `120`.
max_concurrent_requests: the maximum number of concurrent requests to send. Defaults
to `64`.
        structured_output: a dictionary containing the structured output configuration
using `instructor`. You can take a look at the dictionary structure in
`InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
_api_key_env_var: the name of the environment variable to use for the API key. It is meant to
be used internally.
        _aclient: the `Mistral` client to use for the Mistral API. It is meant to be used internally.
Set in the `load` method.
Runtime parameters:
- `api_key`: the API key to authenticate the requests to the Mistral API.
- `max_retries`: the maximum number of retries to attempt when a request fails.
            Defaults to `6`.
- `timeout`: the maximum time in seconds to wait for a response. Defaults to `120`.
- `max_concurrent_requests`: the maximum number of concurrent requests to send.
Defaults to `64`.
Examples:
Generate text:
```python
from distilabel.models.llms import MistralLLM
llm = MistralLLM(model="open-mixtral-8x22b")
llm.load()
# Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
        Generate structured data:
        ```python
from pydantic import BaseModel
from distilabel.models.llms import MistralLLM
class User(BaseModel):
name: str
last_name: str
id: int
llm = MistralLLM(
model="open-mixtral-8x22b",
api_key="api.key",
structured_output={"schema": User}
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
model: str
endpoint: str = "https://api.mistral.ai"
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_MISTRALAI_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the Mistral API.",
)
max_retries: RuntimeParameter[int] = Field(
default=6,
description="The maximum number of times to retry the request to the API before"
" failing.",
)
timeout: RuntimeParameter[int] = Field(
default=120,
description="The maximum time in seconds to wait for a response from the API.",
)
max_concurrent_requests: RuntimeParameter[int] = Field(
default=64, description="The maximum number of concurrent requests to send."
)
structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
Field(
default=None,
description="The structured output format to use across all the generations.",
)
)
_num_generations_param_supported = False
_api_key_env_var: str = PrivateAttr(_MISTRALAI_API_KEY_ENV_VAR_NAME)
_aclient: Optional["Mistral"] = PrivateAttr(...)
def load(self) -> None:
"""Loads the `Mistral` client to benefit from async requests."""
super().load()
try:
from mistralai import Mistral
except ImportError as ie:
raise ImportError(
"MistralAI Python client is not installed. Please install it using"
" `pip install 'distilabel[mistralai]'`."
) from ie
if self.api_key is None:
raise ValueError(
f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
)
self._aclient = Mistral(
api_key=self.api_key.get_secret_value(),
endpoint=self.endpoint,
max_retries=self.max_retries, # type: ignore
timeout=self.timeout, # type: ignore
max_concurrent_requests=self.max_concurrent_requests, # type: ignore
)
if self.structured_output:
result = self._prepare_structured_output(
structured_output=self.structured_output,
client=self._aclient,
framework="mistral",
)
self._aclient = result.get("client") # type: ignore
if structured_output := result.get("structured_output"):
self.structured_output = structured_output
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model
# TODO: add `num_generations` parameter once Mistral client allows `n` parameter
@validate_call
async def agenerate( # type: ignore
self,
input: FormattedInput,
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> GenerateOutput:
"""Generates `num_generations` responses for the given input using the MistralAI async
client.
Args:
input: a single input in chat format to generate responses for.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `None`.
            temperature: the temperature to use for the generation. Defaults to `None`.
            top_p: the top-p value to use for the generation. Defaults to `None`.
Returns:
A list of lists of strings containing the generated responses for each input.
"""
structured_output = None
if isinstance(input, tuple):
input, structured_output = input
result = self._prepare_structured_output(
structured_output=structured_output,
client=self._aclient,
framework="mistral",
)
self._aclient = result.get("client")
if structured_output is None and self.structured_output is not None:
structured_output = self.structured_output
kwargs = {
"messages": input, # type: ignore
"model": self.model,
"max_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
}
generations = []
if structured_output:
kwargs = self._prepare_kwargs(kwargs, structured_output)
# TODO: This should work just with the _aclient.chat method, but it's not working.
# We need to check instructor and see if we can create a PR.
completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
else:
# completion = await self._aclient.chat(**kwargs) # type: ignore
completion = await self._aclient.chat.complete_async(**kwargs) # type: ignore
if structured_output:
return prepare_output(
[completion.model_dump_json()],
**self._get_llm_statistics(completion._raw_response),
)
for choice in completion.choices:
if (content := choice.message.content) is None:
self._logger.warning( # type: ignore
f"Received no response using MistralAI client (model: '{self.model}')."
f" Finish reason was: {choice.finish_reason}"
)
generations.append(content)
return prepare_output(generations, **self._get_llm_statistics(completion))
@staticmethod
def _get_llm_statistics(completion: "ChatCompletionResponse") -> "LLMStatistics":
return {
"input_tokens": [completion.usage.prompt_tokens],
"output_tokens": [completion.usage.completion_tokens],
}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Union,
)
from pydantic import (
Field,
PrivateAttr,
validate_call,
)
from distilabel.models.llms.base import LLM
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.typing import GenerateOutput, StandardInput
if TYPE_CHECKING:
import mlx.nn as nn
from mlx_lm.tokenizer_utils import TokenizerWrapper
class MlxLLM(LLM, MagpieChatTemplateMixin):
"""Apple MLX LLM implementation.
Attributes:
path_or_hf_repo: the path to the model or the Hugging Face Hub repo id.
tokenizer_config: the tokenizer configuration.
mlx_model_config: the MLX model configuration.
adapter_path: the path to the adapter.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.
Icon:
`:apple:`
Examples:
Generate text:
```python
from distilabel.models.llms import MlxLLM
llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
llm.load()
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
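        Generate text from a raw prompt string (a minimal illustrative sketch; `generate` also
        accepts plain strings, which are passed to the model without applying the chat template):
        ```python
        from distilabel.models.llms import MlxLLM
        llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
        llm.load()
        output = llm.generate_outputs(inputs=["Hello world!"])
        ```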
"""
path_or_hf_repo: str
tokenizer_config: Dict[str, Any] = Field(default_factory=dict)
mlx_model_config: Dict[str, Any] = Field(default_factory=dict)
adapter_path: Optional[str] = None
_model: Optional["nn.Module"] = PrivateAttr(None)
_tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(None)
_mlx_generate: Optional[Callable] = PrivateAttr(None)
_make_sampler: Optional[Callable] = PrivateAttr(None)
def load(self) -> None:
"""Loads the model and tokenizer and creates the text generation pipeline. In addition,
it will configure the tokenizer chat template."""
try:
import mlx # noqa
from mlx_lm.utils import generate, load
from mlx_lm.sample_utils import make_sampler
except ImportError as ie:
raise ImportError(
"MLX is not installed. Please install it using `pip install 'distilabel[mlx]'`."
) from ie
self._model, self._tokenizer = load(
self.path_or_hf_repo,
tokenizer_config=self.tokenizer_config,
model_config=self.mlx_model_config,
adapter_path=self.adapter_path,
)
if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token
self._mlx_generate = generate
self._make_sampler = make_sampler
super().load()
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.path_or_hf_repo
def prepare_input(self, input: Union["StandardInput", str]) -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.
Args:
input: the input list containing chat items.
Returns:
The prompt to send to the LLM.
"""
if isinstance(input, str):
return input
prompt: str = (
self._tokenizer.apply_chat_template( # type: ignore
input,
tokenize=False,
add_generation_prompt=True,
)
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)
@validate_call
def generate( # type: ignore
self,
inputs: List[Union[StandardInput, str]],
num_generations: int = 1,
max_tokens: int = 256,
logits_processors: Optional[List[Callable]] = None,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
temp: float = 0.0,
top_p: float = 0.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
top_k: int = -1,
) -> List[GenerateOutput]:
"""Generates `num_generations` responses for each input using the text generation
pipeline.
Args:
inputs: the inputs to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
max_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `256`.
logits_processors: the logits processors to use for the generation. Defaults to
`None`.
max_kv_size: the maximum size of the key-value cache. Defaults to `None`.
prompt_cache: the prompt cache to use for the generation. Defaults to `None`.
prefill_step_size: the prefill step size. Defaults to `512`.
kv_bits: the number of bits to use for the key-value cache. Defaults to `None`.
kv_group_size: the group size for the key-value cache. Defaults to `64`.
quantized_kv_start: the start of the quantized key-value cache. Defaults to `0`.
prompt_progress_callback: the callback to use for the generation. Defaults to
`None`.
temp: The temperature for text generation. Defaults to `0.0`.
top_p: The top-p value used for the generation. Defaults to `0.0`.
min_p: The min-p value used for the generation. Defaults to `0.0`.
min_tokens_to_keep: Minimum number of tokens to keep for sampling after
filtering. Must be at least 1. Defaults to `1`.
top_k: The top-k value used for the generation. Defaults to `-1`.
Returns:
A list of lists of strings containing the generated responses for each input.
"""
sampler = self._make_sampler( # type: ignore
temp=temp,
top_p=top_p,
min_p=min_p,
min_tokens_to_keep=min_tokens_to_keep,
top_k=top_k,
)
structured_output = None
result = []
for input in inputs:
if isinstance(input, tuple):
input, structured_output = input
output: List[str] = []
for _ in range(num_generations):
if structured_output: # will raise a NotImplementedError
self._prepare_structured_output(structured_output)
prompt = self.prepare_input(input)
generation = self._mlx_generate( # type: ignore
prompt=prompt,
model=self._model,
tokenizer=self._tokenizer,
logits_processors=logits_processors,
max_tokens=max_tokens,
sampler=sampler,
max_kv_size=max_kv_size,
prompt_cache=prompt_cache,
prefill_step_size=prefill_step_size,
kv_bits=kv_bits,
kv_group_size=kv_group_size,
quantized_kv_start=quantized_kv_start,
prompt_progress_callback=prompt_progress_callback,
)
output.append(generation)
result.append(
prepare_output(
generations=output,
input_tokens=[compute_tokens(input, self._tokenizer.encode)], # type: ignore
output_tokens=[
compute_tokens(
text_or_messages=generation,
tokenizer=self._tokenizer.encode, # type: ignore
)
for generation in output
],
)
)
return result
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import itertools
from typing import TYPE_CHECKING, Any, Dict, List, Union, cast
from pydantic import Field
from distilabel.models.llms.base import LLM, AsyncLLM
from distilabel.typing import StandardInput
if TYPE_CHECKING:
from distilabel.mixins.runtime_parameters import RuntimeParametersNames
from distilabel.typing import FormattedInput, GenerateOutput
# Mixture-of-Agents system prompt from the paper with the addition instructing the LLM
# to not mention that it used responses from previous models to avoid having texts like
# "Based on the previous responses..." in the completion.
MOA_SYSTEM_PROMPT = (
"You have been provided with a set of responses from various open-source models to the"
" latest user query. Your task is to synthesize these responses into a single, high-quality"
" response. It is crucial to critically evaluate the information provided in these responses,"
" recognizing that some of it may be biased or incorrect. Your response should not simply"
" replicate the given answers but should offer a refined, accurate, and comprehensive"
" reply to the instruction. Ensure your response is well-structured, coherent, and adheres"
" to the highest standards of accuracy and reliability. Do not mention that you have used"
" the responses from previous models."
"\nResponses from models:"
)
class MixtureOfAgentsLLM(AsyncLLM):
"""`Mixture-of-Agents` implementation.
    An `LLM` class that leverages the collective strengths of several `LLM`s to generate a response,
    as described in the "Mixture-of-Agents Enhances Large Language Model Capabilities"
paper. There is a list of `LLM`s proposing/generating outputs that `LLM`s from the next
round/layer can use as auxiliary information. Finally, there is an `LLM` that aggregates
the outputs to generate the final response.
Attributes:
aggregator_llm: The `LLM` that aggregates the outputs of the proposer `LLM`s.
proposers_llms: The list of `LLM`s that propose outputs to be aggregated.
rounds: The number of layers or rounds that the `proposers_llms` will generate
outputs. Defaults to `1`.
References:
- [Mixture-of-Agents Enhances Large Language Model Capabilities](https://arxiv.org/abs/2406.04692)
Examples:
Generate text:
```python
from distilabel.models.llms import MixtureOfAgentsLLM, InferenceEndpointsLLM
llm = MixtureOfAgentsLLM(
aggregator_llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
),
proposers_llms=[
InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
),
InferenceEndpointsLLM(
model_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
tokenizer_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
),
InferenceEndpointsLLM(
model_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
tokenizer_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
),
],
rounds=2,
)
llm.load()
output = llm.generate_outputs(
inputs=[
[
{
"role": "user",
"content": "My favorite witty review of The Rings of Power series is this: Input:",
}
]
]
)
```
"""
aggregator_llm: LLM
proposers_llms: List[AsyncLLM] = Field(default_factory=list)
rounds: int = 1
@property
def runtime_parameters_names(self) -> "RuntimeParametersNames":
"""Returns the runtime parameters of the `LLM`, which are a combination of the
`RuntimeParameter`s of the `LLM`, the `aggregator_llm` and the `proposers_llms`.
Returns:
The runtime parameters of the `LLM`.
"""
runtime_parameters_names = super().runtime_parameters_names
del runtime_parameters_names["generation_kwargs"]
return runtime_parameters_names
def load(self) -> None:
"""Loads all the `LLM`s in the `MixtureOfAgents`."""
super().load()
for llm in self.proposers_llms:
self._logger.debug(f"Loading proposer LLM in MoA: {llm}") # type: ignore
llm.load()
self._logger.debug(f"Loading aggregator LLM in MoA: {self.aggregator_llm}") # type: ignore
self.aggregator_llm.load()
@property
def model_name(self) -> str:
"""Returns the aggregated model name."""
return f"moa-{self.aggregator_llm.model_name}-{'-'.join([llm.model_name for llm in self.proposers_llms])}"
def get_generation_kwargs(self) -> Dict[str, Any]:
"""Returns the generation kwargs of the `MixtureOfAgents` as a dictionary.
Returns:
The generation kwargs of the `MixtureOfAgents`.
"""
return {
"aggregator_llm": self.aggregator_llm.get_generation_kwargs(),
"proposers_llms": [
llm.get_generation_kwargs() for llm in self.proposers_llms
],
}
# `abstractmethod`, had to be implemented but not used
async def agenerate(
self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any
) -> List[Union[str, None]]:
raise NotImplementedError(
"`agenerate` method is not implemented for `MixtureOfAgents`"
)
def _build_moa_system_prompt(self, prev_outputs: List[str]) -> str:
"""Builds the Mixture-of-Agents system prompt.
Args:
prev_outputs: The list of previous outputs to use as references.
Returns:
The Mixture-of-Agents system prompt.
"""
moa_system_prompt = MOA_SYSTEM_PROMPT
for i, prev_output in enumerate(prev_outputs):
if prev_output is not None:
moa_system_prompt += f"\n{i + 1}. {prev_output}"
return moa_system_prompt
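    # For illustration: `_build_moa_system_prompt(["Answer A", "Answer B"])` returns
    # `MOA_SYSTEM_PROMPT` followed by "\n1. Answer A\n2. Answer B".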
def _inject_moa_system_prompt(
self, input: "StandardInput", prev_outputs: List[str]
) -> "StandardInput":
"""Injects the Mixture-of-Agents system prompt into the input.
Args:
input: The input to inject the system prompt into.
prev_outputs: The list of previous outputs to use as references.
Returns:
The input with the Mixture-of-Agents system prompt injected.
"""
if len(prev_outputs) == 0:
return input
moa_system_prompt = self._build_moa_system_prompt(prev_outputs)
system = next((item for item in input if item["role"] == "system"), None)
if system:
original_system_prompt = system["content"]
system["content"] = f"{moa_system_prompt}\n\n{original_system_prompt}"
else:
input.insert(0, {"role": "system", "content": moa_system_prompt})
return input
async def _agenerate(
self,
inputs: List["FormattedInput"],
num_generations: int = 1,
**kwargs: Any,
) -> List["GenerateOutput"]:
"""Internal function to concurrently generate responses for a list of inputs.
Args:
inputs: the list of inputs to generate responses for.
num_generations: the number of generations to generate per input.
**kwargs: the additional kwargs to be used for the generation.
Returns:
A list containing the generations for each input.
"""
aggregator_llm_kwargs: Dict[str, Any] = kwargs.get("aggregator_llm", {})
proposers_llms_kwargs: List[Dict[str, Any]] = kwargs.get(
"proposers_llms", [{}] * len(self.proposers_llms)
)
prev_outputs = []
for round in range(self.rounds):
self._logger.debug(f"Generating round {round + 1}/{self.rounds} in MoA") # type: ignore
# Generate `num_generations` with each proposer LLM for each input
tasks = [
asyncio.create_task(
llm._agenerate(
inputs=[
self._inject_moa_system_prompt(
cast("StandardInput", input), prev_input_outputs
)
for input, prev_input_outputs in itertools.zip_longest(
inputs, prev_outputs, fillvalue=[]
)
],
num_generations=1,
**generation_kwargs,
)
)
for llm, generation_kwargs in zip(
self.proposers_llms, proposers_llms_kwargs
)
]
# Group generations per input
outputs: List[List["GenerateOutput"]] = await asyncio.gather(*tasks)
prev_outputs = [
list(itertools.chain(*input_outputs)) for input_outputs in zip(*outputs)
]
self._logger.debug("Aggregating outputs in MoA") # type: ignore
if isinstance(self.aggregator_llm, AsyncLLM):
return await self.aggregator_llm._agenerate(
inputs=[
self._inject_moa_system_prompt(
cast("StandardInput", input), prev_input_outputs
)
for input, prev_input_outputs in zip(inputs, prev_outputs)
],
num_generations=num_generations,
**aggregator_llm_kwargs,
)
return self.aggregator_llm.generate(
inputs=[
self._inject_moa_system_prompt(
cast("StandardInput", input), prev_input_outputs
)
for input, prev_input_outputs in zip(inputs, prev_outputs)
],
num_generations=num_generations,
**aggregator_llm_kwargs,
)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Union
from pydantic import Field, PrivateAttr, model_validator, validate_call
from typing_extensions import TypedDict
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.utils import prepare_output
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.typing import (
GenerateOutput,
InstructorStructuredOutputType,
StandardInput,
)
if TYPE_CHECKING:
from ollama import AsyncClient
from ollama._types import ChatResponse, GenerateResponse
from distilabel.typing import LLMStatistics, StandardInput
# Copied from `ollama._types.Options`
class Options(TypedDict, total=False):
# load time options
numa: bool
num_ctx: int
num_batch: int
num_gqa: int
num_gpu: int
main_gpu: int
low_vram: bool
f16_kv: bool
logits_all: bool
vocab_only: bool
use_mmap: bool
use_mlock: bool
embedding_only: bool
rope_frequency_base: float
rope_frequency_scale: float
num_thread: int
# runtime options
num_keep: int
seed: int
num_predict: int
top_k: int
top_p: float
tfs_z: float
typical_p: float
repeat_last_n: int
temperature: float
repeat_penalty: float
presence_penalty: float
frequency_penalty: float
mirostat: int
mirostat_tau: float
mirostat_eta: float
penalize_newline: bool
stop: Sequence[str]
class OllamaLLM(AsyncLLM, MagpieChatTemplateMixin):
"""Ollama LLM implementation running the Async API client.
Attributes:
model: the model name to use for the LLM e.g. "notus".
host: the Ollama server host.
timeout: the timeout for the LLM. Defaults to `120`.
follow_redirects: whether to follow redirects. Defaults to `True`.
structured_output: a dictionary containing the structured output configuration or if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing
the tokenizer config files. If not provided, the one associated to the `model`
will be used. Defaults to `None`.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.
_aclient: the `AsyncClient` to use for the Ollama API. It is meant to be used internally.
Set in the `load` method.
Runtime parameters:
- `host`: the Ollama server host.
- `timeout`: the client timeout for the Ollama API. Defaults to `120`.
Examples:
Generate text:
```python
from distilabel.models.llms import OllamaLLM
llm = OllamaLLM(model="llama3")
llm.load()
# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
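Pass Ollama generation options (an illustrative sketch; the option names come from the
`Options` TypedDict above and the values are assumptions):
```python
from distilabel.models.llms import OllamaLLM
llm = OllamaLLM(model="llama3")
llm.load()
output = llm.generate(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    options={"temperature": 0.7, "num_predict": 256},
)
```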
"""
model: str
host: Optional[RuntimeParameter[str]] = Field(
default=None, description="The host of the Ollama API."
)
timeout: RuntimeParameter[int] = Field(
default=120, description="The timeout for the Ollama API."
)
follow_redirects: bool = True
structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
Field(
default=None,
description="The structured output format to use across all the generations.",
)
)
tokenizer_id: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The Hugging Face Hub repo id or a path to a directory containing"
" the tokenizer config files. If not provided, the one associated to the `model`"
" will be used.",
)
_num_generations_param_supported = False
_aclient: Optional["AsyncClient"] = PrivateAttr(...) # type: ignore
@model_validator(mode="after") # type: ignore
def validate_magpie_usage(
self,
) -> "OllamaLLM":
"""Validates that magpie usage is valid."""
if self.use_magpie_template and self.tokenizer_id is None:
raise ValueError(
"`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
" set a `tokenizer_id` and try again."
)
# Pydantic `mode="after"` model validators should return the model instance.
return self
def load(self) -> None:
"""Loads the `AsyncClient` to use Ollama async API."""
super().load()
try:
from ollama import AsyncClient
self._aclient = AsyncClient(
host=self.host,
timeout=self.timeout,
follow_redirects=self.follow_redirects,
)
except ImportError as e:
raise ImportError(
"Ollama Python client is not installed. Please install it using"
" `pip install 'distilabel[ollama]'`."
) from e
if self.tokenizer_id:
try:
from transformers import AutoTokenizer
except ImportError as ie:
raise ImportError(
"Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
) from ie
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
if self._tokenizer.chat_template is None:
raise ValueError(
"The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
)
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model
async def _generate_chat_completion(
self,
input: "StandardInput",
format: Literal["", "json"] = "",
options: Union[Options, None] = None,
keep_alive: Union[bool, None] = None,
) -> "ChatResponse":
return await self._aclient.chat(
model=self.model,
messages=input,
stream=False,
format=format,
options=options,
keep_alive=keep_alive,
)
def prepare_input(self, input: "StandardInput") -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.
Args:
input: the input list containing chat items.
Returns:
The prompt to send to the LLM.
"""
prompt: str = (
self._tokenizer.apply_chat_template(
conversation=input,
tokenize=False,
add_generation_prompt=True,
)
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)
async def _generate_with_text_generation(
self,
input: "StandardInput",
format: Optional[Literal["", "json"]] = None,
options: Union[Options, None] = None,
keep_alive: Union[bool, None] = None,
) -> "GenerateResponse":
input = self.prepare_input(input)
return await self._aclient.generate(
model=self.model,
prompt=input,
format=format,
options=options,
keep_alive=keep_alive,
raw=True,
)
@validate_call
async def agenerate(
self,
input: StandardInput,
format: Literal["", "json"] = "",
# TODO: include relevant options from `Options` in `agenerate` method.
options: Union[Options, None] = None,
keep_alive: Union[bool, None] = None,
) -> GenerateOutput:
"""
Generates a response asynchronously, using the [Ollama Async API definition](https://github.com/ollama/ollama-python).
Args:
input: the input to use for the generation.
format: the format to use for the generation. Defaults to `""`.
options: the options to use for the generation. Defaults to `None`.
keep_alive: whether to keep the connection alive. Defaults to `None`.
Returns:
A list of strings as completion for the given input.
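Examples:
A brief sketch of requesting JSON output (must run inside an event loop; the prompt is
illustrative):

```python
output = await llm.agenerate(
    input=[{"role": "user", "content": "Return a JSON object with a `name` field."}],
    format="json",
)
```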
"""
text = None
try:
if not format:
format = None
if self.tokenizer_id is None:
completion = await self._generate_chat_completion(
input, format, options, keep_alive
)
text = completion["message"]["content"]
else:
completion = await self._generate_with_text_generation(
input, format, options, keep_alive
)
text = completion.response
except Exception as e:
self._logger.warning( # type: ignore
f"⚠️ Received no response using Ollama client (model: '{self.model_name}')."
f" Finish reason was: {e}"
)
return prepare_output([text], **self._get_llm_statistics(completion))
@staticmethod
def _get_llm_statistics(completion: Dict[str, Any]) -> "LLMStatistics":
return {
"input_tokens": [completion["prompt_eval_count"]],
"output_tokens": [completion["eval_count"]],
}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
import orjson
from pydantic import NonNegativeInt, PositiveInt, validate_call
from distilabel import envs
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from distilabel.models.base_clients.openai import OpenAIBaseClient
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.utils import prepare_output
from distilabel.typing import FormattedInput, GenerateOutput
if TYPE_CHECKING:
from openai.types import Batch as OpenAIBatch
from openai.types import FileObject as OpenAIFileObject
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
from openai.types.chat.chat_completion import Choice as OpenAIChatCompletionChoice
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import (
CompletionChoice as OpenAICompletionChoice,
)
from distilabel.typing.models import (
LLMStatistics,
Logprob,
StandardInput,
StructuredInput,
)
_OPENAI_BATCH_API_MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
class OpenAILLM(OpenAIBaseClient, AsyncLLM):
"""OpenAI LLM implementation running the async API client.
Attributes:
model: the model name to use for the LLM e.g. "gpt-3.5-turbo", "gpt-4", etc.
Supported models can be found [here](https://platform.openai.com/docs/guides/text-generation).
base_url: the base URL to use for the OpenAI API requests. Defaults to `None`, which
means that the value set for the environment variable `OPENAI_BASE_URL` will
be used, or "https://api.openai.com/v1" if not set.
api_key: the API key to authenticate the requests to the OpenAI API. Defaults to
`None` which means that the value set for the environment variable `OPENAI_API_KEY`
will be used, or `None` if not set.
default_headers: the default headers to use for the OpenAI API requests.
max_retries: the maximum number of times to retry the request to the API before
failing. Defaults to `6`.
timeout: the maximum time in seconds to wait for a response from the API. Defaults
to `120`.
structured_output: a dictionary containing the structured output configuration
using `instructor`. You can take a look at the dictionary structure in
`InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
Runtime parameters:
- `base_url`: the base URL to use for the OpenAI API requests. Defaults to `None`.
- `api_key`: the API key to authenticate the requests to the OpenAI API. Defaults
to `None`.
- `max_retries`: the maximum number of times to retry the request to the API before
failing. Defaults to `6`.
- `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
to `120`.
Icon:
`:simple-openai:`
Examples:
Generate text:
```python
from distilabel.models.llms import OpenAILLM
llm = OpenAILLM(model="gpt-4-turbo", api_key="api.key")
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate text from a custom endpoint following the OpenAI API:
```python
from distilabel.models.llms import OpenAILLM
llm = OpenAILLM(
model="prometheus-eval/prometheus-7b-v2.0",
base_url=r"http://localhost:8080/v1"
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
```python
from pydantic import BaseModel
from distilabel.models.llms import OpenAILLM
class User(BaseModel):
name: str
last_name: str
id: int
llm = OpenAILLM(
model="gpt-4-turbo",
api_key="api.key",
structured_output={"schema": User}
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
Generate with Batch API (offline batch generation):
```python
from distilabel.models.llms import OpenAILLM
llm = OpenAILLM(
model="gpt-3.5-turbo",
use_offline_batch_generation=True,
offline_batch_generation_block_until_done=5, # poll for results every 5 seconds
)
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
# [['Hello! How can I assist you today?']]
```
"""
def load(self) -> None:
AsyncLLM.load(self)
OpenAIBaseClient.load(self)
@validate_call
async def agenerate( # type: ignore
self,
input: FormattedInput,
num_generations: int = 1,
max_new_tokens: NonNegativeInt = 128,
logprobs: bool = False,
top_logprobs: Optional[PositiveInt] = None,
echo: bool = False,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
stop: Optional[Union[str, List[str]]] = None,
response_format: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, Any]] = None,
) -> GenerateOutput:
"""Generates `num_generations` responses for the given input using the OpenAI async
client.
Args:
input: a single input in chat format to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
logprobs: whether to return the log probabilities or not. Defaults to `False`.
top_logprobs: the number of top log probabilities to return per output token
generated. Defaults to `None`.
echo: whether to echo the input in the response or not. It's only used if the
`input` argument is a `str`. Defaults to `False`.
frequency_penalty: the repetition penalty to use for the generation. Defaults
to `0.0`.
presence_penalty: the presence penalty to use for the generation. Defaults to
`0.0`.
temperature: the temperature to use for the generation. Defaults to `1.0`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
stop: a string or a list of strings to use as a stop sequence for the generation.
Defaults to `None`.
response_format: the format of the response to return. Must be one of
"text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
for more information on how to use the JSON model from OpenAI. Defaults to None
which returns text. To return JSON, use {"type": "json_object"}.
extra_body: an optional dictionary containing extra body parameters that will
be sent to the OpenAI API endpoint. Defaults to `None`.
Returns:
A list of lists of strings containing the generated responses for each input.
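Examples:
A hedged sketch of requesting JSON mode (must run inside an event loop; the values are
illustrative):

```python
output = await llm.agenerate(
    input=[{"role": "user", "content": "Reply with a JSON object."}],
    response_format={"type": "json_object"},
    max_new_tokens=256,
)
```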
"""
if isinstance(input, str):
return await self._generate_completion(
input=input,
num_generations=num_generations,
max_new_tokens=max_new_tokens,
echo=echo,
top_logprobs=top_logprobs,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
extra_body=extra_body,
)
return await self._generate_chat_completion(
input=input,
num_generations=num_generations,
max_new_tokens=max_new_tokens,
logprobs=logprobs,
top_logprobs=top_logprobs,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
stop=stop,
response_format=response_format,
extra_body=extra_body,
)
async def _generate_completion(
self,
input: str,
num_generations: int = 1,
max_new_tokens: int = 128,
echo: bool = False,
top_logprobs: Optional[PositiveInt] = None,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
extra_body: Optional[Dict[str, Any]] = None,
) -> GenerateOutput:
completion = await self._aclient.completions.create(
prompt=input,
echo=echo,
model=self.model,
n=num_generations,
max_tokens=max_new_tokens,
logprobs=top_logprobs,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
extra_body=extra_body,
)
generations = []
logprobs = []
for choice in completion.choices:
generations.append(choice.text)
if choice_logprobs := self._get_logprobs_from_completion_choice(choice):
logprobs.append(choice_logprobs)
statistics = self._get_llm_statistics(completion)
return prepare_output(
generations=generations,
input_tokens=statistics["input_tokens"],
output_tokens=statistics["output_tokens"],
logprobs=logprobs,
)
def _get_logprobs_from_completion_choice(
self, choice: "OpenAICompletionChoice"
) -> Union[List[Union[List["Logprob"], None]], None]:
if choice.logprobs is None or choice.logprobs.top_logprobs is None:
return None
return [
[
{"token": token, "logprob": token_logprob}
for token, token_logprob in logprobs.items()
]
if logprobs is not None
else None
for logprobs in choice.logprobs.top_logprobs
]
async def _generate_chat_completion(
self,
input: Union["StandardInput", "StructuredInput"],
num_generations: int = 1,
max_new_tokens: int = 128,
logprobs: bool = False,
top_logprobs: Optional[PositiveInt] = None,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
stop: Optional[Union[str, List[str]]] = None,
response_format: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, Any]] = None,
) -> GenerateOutput:
structured_output = None
if isinstance(input, tuple):
input, structured_output = input
result = self._prepare_structured_output(
structured_output=structured_output, # type: ignore
client=self._aclient,
framework="openai",
)
self._aclient = result.get("client") # type: ignore
if structured_output is None and self.structured_output is not None:
structured_output = self.structured_output
kwargs = {
"messages": input, # type: ignore
"model": self.model,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"max_tokens": max_new_tokens,
"n": num_generations,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"temperature": temperature,
"top_p": top_p,
"stop": stop,
"extra_body": extra_body,
}
# Checks if any message contains an image, in that case "stop" cannot be used or
# raises an error in the API.
if isinstance(
[row for row in input if row["role"] == "user"][0]["content"], list
):
kwargs.pop("stop")
if response_format is not None:
kwargs["response_format"] = response_format
if structured_output:
kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore
completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore
if structured_output:
# NOTE: `instructor` doesn't work with `n` parameter, so it will always return
# only 1 choice.
statistics = self._get_llm_statistics(completion._raw_response)
if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
completion._raw_response.choices[0]
):
output_logprobs = [choice_logprobs]
else:
output_logprobs = None
return prepare_output(
generations=[completion.model_dump_json()],
input_tokens=statistics["input_tokens"],
output_tokens=statistics["output_tokens"],
logprobs=output_logprobs,
)
return self._generations_from_openai_completion(completion)
def _generations_from_openai_completion(
self, completion: "OpenAIChatCompletion"
) -> "GenerateOutput":
"""Get the generations from the OpenAI Chat Completion object.
Args:
completion: the completion object to get the generations from.
Returns:
A list of strings containing the generated responses for the input.
"""
generations = []
logprobs = []
for choice in completion.choices:
if (content := choice.message.content) is None:
self._logger.warning( # type: ignore
f"Received no response using OpenAI client (model: '{self.model}')."
f" Finish reason was: {choice.finish_reason}"
)
generations.append(content)
if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
choice
):
logprobs.append(choice_logprobs)
statistics = self._get_llm_statistics(completion)
return prepare_output(
generations=generations,
input_tokens=statistics["input_tokens"],
output_tokens=statistics["output_tokens"],
logprobs=logprobs,
)
def _get_logprobs_from_chat_completion_choice(
self, choice: "OpenAIChatCompletionChoice"
) -> Union[List[List["Logprob"]], None]:
if choice.logprobs is None or choice.logprobs.content is None:
return None
return [
[
{"token": top_logprob.token, "logprob": top_logprob.logprob}
for top_logprob in token_logprobs.top_logprobs
]
for token_logprobs in choice.logprobs.content
]
def offline_batch_generate(
self,
inputs: Union[List["FormattedInput"], None] = None,
num_generations: int = 1,
max_new_tokens: int = 128,
logprobs: bool = False,
top_logprobs: Optional[PositiveInt] = None,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
stop: Optional[Union[str, List[str]]] = None,
response_format: Optional[str] = None,
**kwargs: Any,
) -> List["GenerateOutput"]:
"""Uses the OpenAI batch API to generate `num_generations` responses for the given
inputs.
Args:
inputs: a list of inputs in chat format to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
logprobs: whether to return the log probabilities or not. Defaults to `False`.
top_logprobs: the number of top log probabilities to return per output token
generated. Defaults to `None`.
frequency_penalty: the repetition penalty to use for the generation. Defaults
to `0.0`.
presence_penalty: the presence penalty to use for the generation. Defaults to
`0.0`.
temperature: the temperature to use for the generation. Defaults to `1.0`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
stop: a string or a list of strings to use as a stop sequence for the generation.
Defaults to `None`.
response_format: the format of the response to return. Must be one of
"text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
for more information on how to use the JSON model from OpenAI. Defaults to `text`.
Returns:
A list of lists of strings containing the generated responses for each input
in `inputs`.
Raises:
DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
is not finished yet.
ValueError: if no job IDs were found to retrieve the results from.
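Examples:
One way to poll manually until the batch is done (an illustrative sketch; `llm` and
`inputs` are assumed to be already defined and `llm.load()` called):

```python
import time

from distilabel.exceptions import (
    DistilabelOfflineBatchGenerationNotFinishedException,
)

while True:
    try:
        outputs = llm.offline_batch_generate(inputs=inputs)
        break
    except DistilabelOfflineBatchGenerationNotFinishedException:
        time.sleep(60)  # wait before asking the OpenAI Batch API again
```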
"""
if self.jobs_ids:
return self._check_and_get_batch_results()
if inputs:
self.jobs_ids = self._create_jobs(
inputs=inputs,
**{
"model": self.model,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"max_tokens": max_new_tokens,
"n": num_generations,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"temperature": temperature,
"top_p": top_p,
"stop": stop,
"response_format": response_format,
},
)
raise DistilabelOfflineBatchGenerationNotFinishedException(
jobs_ids=self.jobs_ids
)
raise ValueError("No `inputs` were provided and no `jobs_ids` were found.")
def _check_and_get_batch_results(self) -> List["GenerateOutput"]:
"""Checks the status of the batch jobs and retrieves the results from the OpenAI
Batch API.
Returns:
A list of lists of strings containing the generated responses for each input.
Raises:
ValueError: if no job IDs were found to retrieve the results from.
DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
is not finished yet.
RuntimeError: if the only batch job found failed.
"""
if not self.jobs_ids:
raise ValueError("No job IDs were found to retrieve the results from.")
outputs = []
for batch_id in self.jobs_ids:
batch = self._get_openai_batch(batch_id)
if batch.status in ("validating", "in_progress", "finalizing"):
raise DistilabelOfflineBatchGenerationNotFinishedException(
jobs_ids=self.jobs_ids
)
if batch.status in ("failed", "expired", "cancelled", "cancelling"):
self._logger.error( # type: ignore
f"OpenAI API batch with ID '{batch_id}' failed with status '{batch.status}'."
)
if len(self.jobs_ids) == 1:
self.jobs_ids = None
raise RuntimeError(
f"The only OpenAI API Batch that was created with ID '{batch_id}'"
f" failed with status '{batch.status}'."
)
continue
outputs.extend(self._retrieve_batch_results(batch))
# sort by `custom_id` to return the results in the same order as the inputs
outputs = sorted(outputs, key=lambda x: int(x["custom_id"]))
return [self._parse_output(output) for output in outputs]
def _parse_output(self, output: Dict[str, Any]) -> "GenerateOutput":
"""Parses the output from the OpenAI Batch API into a list of strings.
Args:
output: the output to parse.
Returns:
A list of strings containing the generated responses for the input.
"""
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
if "response" not in output:
return []
if output["response"]["status_code"] != 200:
return []
return self._generations_from_openai_completion(
OpenAIChatCompletion(**output["response"]["body"])
)
def _get_openai_batch(self, batch_id: str) -> "OpenAIBatch":
"""Gets a batch from the OpenAI Batch API.
Args:
batch_id: the ID of the batch to retrieve.
Returns:
The batch retrieved from the OpenAI Batch API.
Raises:
openai.OpenAIError: if there was an error while retrieving the batch from the
OpenAI Batch API.
"""
import openai
try:
return self._client.batches.retrieve(batch_id)
except openai.OpenAIError as e:
self._logger.error( # type: ignore
f"Error while retrieving batch '{batch_id}' from OpenAI: {e}"
)
raise e
def _retrieve_batch_results(self, batch: "OpenAIBatch") -> List[Dict[str, Any]]:
"""Retrieves the results of a batch from its output file, parsing the JSONL content
into a list of dictionaries.
Args:
batch: the batch to retrieve the results from.
Returns:
A list of dictionaries containing the results of the batch.
Raises:
AssertionError: if no output file ID was found in the batch.
"""
import openai
assert batch.output_file_id, "No output file ID was found in the batch."
try:
file_response = self._client.files.content(batch.output_file_id)
return [orjson.loads(line) for line in file_response.text.splitlines()]
except openai.OpenAIError as e:
self._logger.error( # type: ignore
f"Error while retrieving batch results from file '{batch.output_file_id}': {e}"
)
return []
def _create_jobs(
self, inputs: List["FormattedInput"], **kwargs: Any
) -> Tuple[str, ...]:
"""Creates jobs in the OpenAI Batch API to generate responses for the given inputs.
Args:
inputs: a list of inputs in chat format to generate responses for.
kwargs: the keyword arguments to use for the generation.
Returns:
A list of job IDs created in the OpenAI Batch API.
"""
batch_input_files = self._create_batch_files(inputs=inputs, **kwargs)
jobs = []
for batch_input_file in batch_input_files:
if batch := self._create_batch_api_job(batch_input_file):
jobs.append(batch.id)
return tuple(jobs)
def _create_batch_api_job(
self, batch_input_file: "OpenAIFileObject"
) -> Union["OpenAIBatch", None]:
"""Creates a job in the OpenAI Batch API to generate responses for the given input
file.
Args:
batch_input_file: the input file to generate responses for.
Returns:
The batch job created in the OpenAI Batch API.
"""
import openai
metadata = {"description": "distilabel"}
if distilabel_pipeline_name := envs.DISTILABEL_PIPELINE_NAME:
metadata["distilabel_pipeline_name"] = distilabel_pipeline_name
if distilabel_pipeline_cache_id := envs.DISTILABEL_PIPELINE_CACHE_ID:
metadata["distilabel_pipeline_cache_id"] = distilabel_pipeline_cache_id
batch = None
try:
batch = self._client.batches.create(
completion_window="24h",
endpoint="/v1/chat/completions",
input_file_id=batch_input_file.id,
metadata=metadata,
)
except openai.OpenAIError as e:
self._logger.error( # type: ignore
f"Error while creating OpenAI Batch API job for file with ID"
f" '{batch_input_file.id}': {e}."
)
raise e
return batch
def _create_batch_files(
self, inputs: List["FormattedInput"], **kwargs: Any
) -> List["OpenAIFileObject"]:
"""Creates the necessary input files for the batch API to generate responses. The
maximum size of each file that the OpenAI Batch API can process is 100MB, so we
need to split the inputs into multiple files if necessary.
More information: https://platform.openai.com/docs/api-reference/files/create
Args:
inputs: a list of inputs in chat format to generate responses for, optionally
including structured output.
kwargs: the keyword arguments to use for the generation.
Returns:
The list of file objects created for the OpenAI Batch API.
Raises:
openai.OpenAIError: if there was an error while creating the batch input file
in the OpenAI Batch API.
"""
import openai
files = []
for file_no, buffer in enumerate(
self._create_jsonl_buffers(inputs=inputs, **kwargs)
):
try:
# TODO: add distilabel pipeline name and id
batch_input_file = self._client.files.create(
file=(self._name_for_openai_files(file_no), buffer),
purpose="batch",
)
files.append(batch_input_file)
except openai.OpenAIError as e:
self._logger.error( # type: ignore
f"Error while creating OpenAI batch input file: {e}"
)
raise e
return files
def _create_jsonl_buffers(
self, inputs: List["FormattedInput"], **kwargs: Any
) -> Generator[io.BytesIO, None, None]:
"""Creates a generator of buffers containing the JSONL formatted inputs to be
used by the OpenAI Batch API. The buffers created are of size 100MB or less.
Args:
inputs: a list of inputs in chat format to generate responses for, optionally
including structured output.
kwargs: the keyword arguments to use for the generation.
Yields:
A buffer containing the JSONL formatted inputs to be used by the OpenAI Batch
API.
"""
buffer = io.BytesIO()
buffer_current_size = 0
for i, input in enumerate(inputs):
# We create the smallest `custom_id` possible so we don't increase the size of the file
# too much, while still being able to sort the results back into the order of the inputs.
row = self._create_jsonl_row(input=input, custom_id=str(i), **kwargs)
row_size = len(row)
if row_size + buffer_current_size > _OPENAI_BATCH_API_MAX_FILE_SIZE:
buffer.seek(0)
yield buffer
buffer = io.BytesIO()
buffer_current_size = 0
buffer.write(row)
buffer_current_size += row_size
if buffer_current_size > 0:
buffer.seek(0)
yield buffer
def _create_jsonl_row(
self, input: "FormattedInput", custom_id: str, **kwargs: Any
) -> bytes:
"""Creates a JSONL formatted row to be used by the OpenAI Batch API.
Args:
input: a list of inputs in chat format to generate responses for, optionally
including structured output.
custom_id: a custom ID to use for the row.
kwargs: the keyword arguments to use for the generation.
Returns:
A JSONL formatted row to be used by the OpenAI Batch API.
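Example:
An illustrative (trimmed) row, before JSONL serialization; the keys come from the code
below and the values are assumptions:

```python
{
    "custom_id": "0",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "messages": [{"role": "user", "content": "Hello world!"}],
        "model": "gpt-3.5-turbo",
    },
}
```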
"""
# TODO: depending on the format of the input, add `response_format` to the kwargs
row = {
"custom_id": custom_id,
"method": "POST",
"url": "/v1/chat/completions",
"body": {"messages": input, **kwargs},
}
json_row = orjson.dumps(row)
return json_row + b"\n"
def _name_for_openai_files(self, file_no: int) -> str:
if (
envs.DISTILABEL_PIPELINE_NAME is None
or envs.DISTILABEL_PIPELINE_CACHE_ID is None
):
return f"distilabel-pipeline-fileno-{file_no}.jsonl"
return f"distilabel-pipeline-{envs.DISTILABEL_PIPELINE_NAME}-{envs.DISTILABEL_PIPELINE_CACHE_ID}-fileno-{file_no}.jsonl"
@staticmethod
def _get_llm_statistics(
completion: Union["OpenAIChatCompletion", "OpenAICompletion"],
) -> "LLMStatistics":
return {
"output_tokens": [
completion.usage.completion_tokens if completion.usage else 0
],
"input_tokens": [completion.usage.prompt_tokens if completion.usage else 0],
}
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
from pydantic import Field, PrivateAttr, SecretStr
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.openai import OpenAILLM
_TOGETHER_API_KEY_ENV_VAR_NAME = "TOGETHER_API_KEY"
class TogetherLLM(OpenAILLM):
"""TogetherLLM LLM implementation running the async API client of OpenAI.
Attributes:
model: the model name to use for the LLM e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1".
Supported models can be found [here](https://api.together.xyz/models).
base_url: the base URL to use for the Together API can be set with `TOGETHER_BASE_URL`.
Defaults to `None` which means that the value set for the environment variable
`TOGETHER_BASE_URL` will be used, or "https://api.together.xyz/v1" if not set.
api_key: the API key to authenticate the requests to the Together API. Defaults to `None`
which means that the value set for the environment variable `TOGETHER_API_KEY` will be
used, or `None` if not set.
_api_key_env_var: the name of the environment variable to use for the API key. It
is meant to be used internally.
Examples:
Generate text:
```python
from distilabel.models.llms import TogetherLLM
llm = TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key="api.key")
llm.load()
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
"""
base_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv(
"TOGETHER_BASE_URL", "https://api.together.xyz/v1"
),
description="The base URL to use for the Together API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_TOGETHER_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the Together API.",
)
_api_key_env_var: str = PrivateAttr(_TOGETHER_API_KEY_ENV_VAR_NAME)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Callable, List, Optional, Union
from distilabel.typing import ChatType
if TYPE_CHECKING:
from distilabel.typing import GenerateOutput, LLMLogprobs, LLMOutput
def compute_tokens(
text_or_messages: Union[str, ChatType], tokenizer: Callable[..., List[int]]
) -> int:
"""Helper function to count the number of tokens in a text or list of messages.
Args:
text_or_messages: Either a string response or a list of messages.
tokenizer: A callable that takes a string and returns its tokenized version.
Returns:
The number of tokens.
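Example:
A toy sketch using whitespace splitting as a stand-in tokenizer (a real tokenizer
would return token ids; only the length of the returned list is used):

```python
compute_tokens("hello world", lambda text: text.split())  # -> 2
compute_tokens([{"role": "user", "content": "hi there"}], lambda text: text.split())  # -> 2
```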
"""
if isinstance(text_or_messages, list):
return sum([len(tokenizer(message["content"])) for message in text_or_messages])
else:
return len(tokenizer(text_or_messages))
def prepare_output(
generations: "LLMOutput",
input_tokens: Optional[List[int]] = None,
output_tokens: Optional[List[int]] = None,
logprobs: Optional["LLMLogprobs"] = None,
) -> "GenerateOutput":
"""Helper function to prepare the output of the LLM.
Args:
generations: The outputs from an LLM.
input_tokens: The number of tokens of the inputs. Defaults to `None`.
output_tokens: The number of tokens of the LLM response. Defaults to `None`.
logprobs: The logprobs of the LLM response. Defaults to `None`.
Returns:
Output generation from an LLM.
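Example:
An illustrative call and its result (mirrors the code below; the `logprobs` key is
only added when logprobs are provided):

```python
prepare_output(["Hello!"], input_tokens=[5], output_tokens=[2])
# {"generations": ["Hello!"], "statistics": {"input_tokens": [5], "output_tokens": [2]}}
```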
"""
output: "GenerateOutput" = {
"generations": generations,
"statistics": {},
}
if input_tokens:
output["statistics"]["input_tokens"] = input_tokens
if output_tokens:
output["statistics"]["output_tokens"] = output_tokens
if logprobs:
output["logprobs"] = logprobs
return output
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from pydantic import PrivateAttr, validate_call
from typing_extensions import TypedDict
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.utils import prepare_output
from distilabel.typing import GenerateOutput, StandardInput
if TYPE_CHECKING:
from vertexai.generative_models import Content, GenerationResponse, GenerativeModel
from distilabel.typing import LLMStatistics
class VertexChatItem(TypedDict):
role: Literal["user", "model"]
content: str
VertexChatType = List[VertexChatItem]
"""VertexChatType is a type alias for a `list` of `dict`s following the VertexAI conversational format."""
class VertexAILLM(AsyncLLM):
"""VertexAI LLM implementation running the async API clients for Gemini.
- Gemini API: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
To use `VertexAILLM` it is necessary to have Google Cloud authentication configured,
using one of these methods:
- Setting `GOOGLE_CLOUD_CREDENTIALS` environment variable
- Using `gcloud auth application-default login` command
- Using `vertexai.init` function from the `google-cloud-aiplatform` library
Attributes:
model: the model name to use for the LLM e.g. "gemini-1.0-pro". [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models).
_aclient: the `GenerativeModel` to use for the Vertex AI Gemini API. It is meant
to be used internally. Set in the `load` method.
Icon:
`:simple-googlecloud:`
Examples:
Generate text:
```python
from distilabel.models.llms import VertexAILLM
llm = VertexAILLM(model="gemini-1.5-pro")
llm.load()
# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
"""
model: str
_num_generations_param_supported = False
_aclient: Optional["GenerativeModel"] = PrivateAttr(...)
def load(self) -> None:
"""Loads the `GenerativeModel` class which has access to `generate_content_async` to benefit from async requests."""
super().load()
try:
from vertexai.generative_models import GenerationConfig, GenerativeModel
self._generation_config_class = GenerationConfig
except ImportError as e:
raise ImportError(
"vertexai is not installed. Please install it using"
" `pip install 'distilabel[vertexai]'`."
) from e
if _is_gemini_model(self.model):
self._aclient = GenerativeModel(model_name=self.model)
else:
raise NotImplementedError(
"`VertexAILLM` is only implemented for `gemini` models that allow for `ChatType` data."
)
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model
def _chattype_to_content(self, input: "StandardInput") -> List["Content"]:
"""Converts a chat type to a list of content items expected by the API.
Args:
input: the chat type to be converted.
Returns:
List[Content]: a list of content items expected by the API.
"""
from vertexai.generative_models import Content, Part
contents = []
for message in input:
if message["role"] not in ["user", "model"]:
raise ValueError(
"`VertexAILLM only supports the roles 'user' or 'model'."
)
contents.append(
Content(
role=message["role"], parts=[Part.from_text(message["content"])]
)
)
return contents
@validate_call
async def agenerate( # type: ignore
self,
input: VertexChatType,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
max_output_tokens: Optional[int] = None,
stop_sequences: Optional[List[str]] = None,
safety_settings: Optional[Dict[str, Any]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
) -> GenerateOutput:
"""Generates `num_generations` responses for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).
Args:
input: a single input in chat format to generate responses for.
temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`.
top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`.
top_k: If specified, top-k sampling will be used. Defaults to `None`.
max_output_tokens: The maximum number of output tokens to generate per message. Defaults to `None`.
stop_sequences: A list of stop sequences. Defaults to `None`.
safety_settings: Safety configuration for returned content from the API. Defaults to `None`.
tools: A potential list of tools that can be used by the API. Defaults to `None`.
Returns:
A list of lists of strings containing the generated responses for each input.
"""
from vertexai.generative_models import GenerationConfig
content: "GenerationResponse" = await self._aclient.generate_content_async( # type: ignore
contents=self._chattype_to_content(input),
generation_config=GenerationConfig(
candidate_count=1, # only one candidate allowed per call
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_output_tokens=max_output_tokens,
stop_sequences=stop_sequences,
),
safety_settings=safety_settings, # type: ignore
tools=tools, # type: ignore
stream=False,
)
text = None
try:
text = content.candidates[0].text
except ValueError:
self._logger.warning( # type: ignore
f"Received no response using VertexAI client (model: '{self.model}')."
f" Finish reason was: '{content.candidates[0].finish_reason}'."
)
return prepare_output([text], **self._get_llm_statistics(content))
@staticmethod
def _get_llm_statistics(content: "GenerationResponse") -> "LLMStatistics":
return {
"input_tokens": [content.usage_metadata.prompt_token_count],
"output_tokens": [content.usage_metadata.candidates_token_count],
}
def _is_gemini_model(model: str) -> bool:
"""Returns `True` if the model is a model from the Vertex AI Gemini API.
Args:
model (str): the model name to be checked.
Returns:
bool: `True` if the model is a model from the Vertex AI Gemini API.
"""
return "gemini" in model
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import gc
import json
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
)
from pydantic import Field, PositiveInt, PrivateAttr, SecretStr, validate_call
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.models.llms.openai import OpenAILLM
from distilabel.models.llms.utils import compute_tokens, prepare_output
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.typing import (
FormattedInput,
GenerateOutput,
Logprob,
OutlinesStructuredOutputType,
)
if TYPE_CHECKING:
from openai import OpenAI # noqa
from transformers import PreTrainedTokenizer
from vllm import LLM as _vLLM
from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs, PromptLogprobs
from distilabel.typing import (
StandardInput,
StructuredInput,
LLMStatistics,
LLMLogprobs,
LLMOutput,
)
LogitsProcessorFn = Union[
Callable[[List[int], Any], Any],
Callable[[List[int], List[int], Any], Any],
]
LogitsProcessors = List[LogitsProcessorFn]
class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
"""`vLLM` library LLM implementation.
Attributes:
model: the model Hugging Face Hub repo id or a path to a directory containing the
model weights and configuration files.
dtype: the data type to use for the model. Defaults to `auto`.
trust_remote_code: whether to trust the remote code when loading the model. Defaults
to `False`.
quantization: the quantization mode to use for the model. Defaults to `None`.
revision: the revision of the model to load. Defaults to `None`.
tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing
the tokenizer files. If not provided, the tokenizer will be loaded from the
model directory. Defaults to `None`.
tokenizer_mode: the mode to use for the tokenizer. Defaults to `auto`.
tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`.
skip_tokenizer_init: whether to skip the initialization of the tokenizer. Defaults
to `False`.
chat_template: a chat template that will be used to build the prompts before
sending them to the model. If not provided, the chat template defined in the
tokenizer config will be used. If not provided and the tokenizer doesn't have
a chat template, then ChatML template will be used. Defaults to `None`.
structured_output: a dictionary containing the structured output configuration or if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
seed: the seed to use for the random number generator. Defaults to `0`.
extra_kwargs: additional dictionary of keyword arguments that will be passed to the
`LLM` class of `vllm` library. Defaults to `{}`.
_model: the `vLLM` model instance. This attribute is meant to be used internally
and should not be accessed directly. It will be set in the `load` method.
_tokenizer: the tokenizer instance used to format the prompt before passing it to
the `LLM`. This attribute is meant to be used internally and should not be
accessed directly. It will be set in the `load` method.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.
References:
- https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
Runtime parameters:
- `extra_kwargs`: additional dictionary of keyword arguments that will be passed to
the `LLM` class of `vllm` library.
Examples:
Generate text:
```python
from distilabel.models.llms import vLLM
# You can pass a custom chat_template to the model
llm = vLLM(
model="prometheus-eval/prometheus-7b-v2.0",
chat_template="[INST] {{ messages[0]\"content\" }}\\n{{ messages[1]\"content\" }}[/INST]",
)
llm.load()
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
```python
from pydantic import BaseModel
from distilabel.models.llms import vLLM
class User(BaseModel):
name: str
last_name: str
id: int
llm = vLLM(
model="prometheus-eval/prometheus-7b-v2.0"
structured_output={"format": "json", "schema": Character},
)
llm.load()
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
"""
model: str
dtype: str = "auto"
trust_remote_code: bool = False
quantization: Optional[str] = None
revision: Optional[str] = None
tokenizer: Optional[str] = None
tokenizer_mode: Literal["auto", "slow"] = "auto"
tokenizer_revision: Optional[str] = None
skip_tokenizer_init: bool = False
chat_template: Optional[str] = None
seed: int = 0
extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
default_factory=dict,
description="Additional dictionary of keyword arguments that will be passed to the"
" `vLLM` class of `vllm` library. See all the supported arguments at: "
"https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py",
)
structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
default=None,
description="The structured output format to use across all the generations.",
)
_model: "_vLLM" = PrivateAttr(None)
_tokenizer: "PreTrainedTokenizer" = PrivateAttr(None)
_structured_output_logits_processor: Optional[Callable] = PrivateAttr(default=None)
def load(self) -> None:
"""Loads the `vLLM` model using either the path or the Hugging Face Hub repository id.
Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly
parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the
default value is ChatML format, unless explicitly provided.
"""
super().load()
CudaDevicePlacementMixin.load(self)
try:
from vllm import LLM as _vLLM
except ImportError as ie:
raise ImportError(
"vLLM is not installed. Please install it using `pip install 'distilabel[vllm]'`."
) from ie
self._model = _vLLM(
self.model,
dtype=self.dtype,
trust_remote_code=self.trust_remote_code,
quantization=self.quantization,
revision=self.revision,
tokenizer=self.tokenizer,
tokenizer_mode=self.tokenizer_mode,
tokenizer_revision=self.tokenizer_revision,
skip_tokenizer_init=self.skip_tokenizer_init,
seed=self.seed,
**self.extra_kwargs, # type: ignore
)
self._tokenizer = self._model.get_tokenizer() # type: ignore
if self.chat_template is not None:
self._tokenizer.chat_template = self.chat_template # type: ignore
if self.structured_output:
self._structured_output_logits_processor = self._prepare_structured_output(
self.structured_output
)
def unload(self) -> None:
"""Unloads the `vLLM` model."""
self._cleanup_vllm_model()
self._model = None # type: ignore
self._tokenizer = None # type: ignore
CudaDevicePlacementMixin.unload(self)
super().unload()
def _cleanup_vllm_model(self) -> None:
if self._model is None:
return
import torch # noqa
from vllm.distributed.parallel_state import (
destroy_distributed_environment,
destroy_model_parallel,
)
destroy_model_parallel()
destroy_distributed_environment()
del self._model.llm_engine.model_executor
del self._model
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model
def prepare_input(self, input: Union["StandardInput", str]) -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.
Args:
input: the input list containing chat items.
Returns:
The prompt to send to the LLM.
"""
if isinstance(input, str):
return input
prompt: str = (
self._tokenizer.apply_chat_template(
input, # type: ignore
tokenize=False,
add_generation_prompt=True, # type: ignore
)
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)
def _prepare_batches(
self, inputs: List["StructuredInput"]
) -> Tuple[List[Tuple[List[str], "OutlinesStructuredOutputType"]], List[int]]:
"""Prepares the inputs by grouping them by the structured output.
When we generate structured outputs with schemas obtained from a dataset, we need to
prepare the data to try to send batches of inputs instead of single inputs to the model
to take advantage of the engine. So we group the inputs by the structured output to be
passed in the `generate` method.
Args:
inputs: The batch of inputs passed to the generate method. As we expect to be generating
structured outputs, each element will be a tuple containing the instruction and the
structured output.
Returns:
The prepared batches (sub-batches, one per structured output) to be passed to the
`generate` method, where each tuple contains a list of instructions that share the same
structured output, together with the indices needed to restore the original input order.
"""
instruction_order = {}
batches: Dict[str, List[str]] = {}
for i, (instruction, structured_output) in enumerate(inputs):
instruction = self.prepare_input(instruction)
instruction_order[instruction] = i
structured_output = json.dumps(structured_output)
if structured_output not in batches:
batches[structured_output] = [instruction]
else:
batches[structured_output].append(instruction)
# Build a list with the instructions sorted by structured output
flat_instructions = [
instruction for _, group in batches.items() for instruction in group
]
# Generate the list of indices based on the original order
sorted_indices = [
instruction_order[instruction] for instruction in flat_instructions
]
return [
(batch, json.loads(schema)) for schema, batch in batches.items()
], sorted_indices
@validate_call
def generate( # noqa: C901 # type: ignore
self,
inputs: List[FormattedInput],
num_generations: int = 1,
max_new_tokens: int = 128,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
logprobs: Optional[PositiveInt] = None,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
skip_special_tokens: bool = True,
logits_processors: Optional[LogitsProcessors] = None,
extra_sampling_params: Optional[Dict[str, Any]] = None,
echo: bool = False,
) -> List[GenerateOutput]:
"""Generates `num_generations` responses for each input.
Args:
inputs: a list of inputs in chat format to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
presence_penalty: the presence penalty to use for the generation. Defaults to
`0.0`.
frequency_penalty: the repetition penalty to use for the generation. Defaults
to `0.0`.
repetition_penalty: the repetition penalty to use for the generation Defaults to
`1.0`.
temperature: the temperature to use for the generation. Defaults to `1.0`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
top_k: the top-k value to use for the generation. Defaults to `-1`.
min_p: the minimum probability to use for the generation. Defaults to `0.0`.
logprobs: number of log probabilities to return per output token. If `None`,
no log probabilities will be returned. Defaults to `None`.
stop: a list of strings that will be used to stop the generation when found.
Defaults to `None`.
stop_token_ids: a list of token ids that will be used to stop the generation
when found. Defaults to `None`.
include_stop_str_in_output: whether to include the stop string in the output.
Defaults to `False`.
skip_special_tokens: whether to exclude special tokens from the output. Defaults
to `True`.
logits_processors: a list of functions to process the logits before sampling.
Defaults to `None`.
extra_sampling_params: dictionary with additional arguments to be passed to
the `SamplingParams` class from `vllm`.
echo: whether to include the prompt in the response or not. Defaults
to `False`.
Returns:
A list of lists of strings containing the generated responses for each input.
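Examples:
A hedged sketch of passing sampling parameters (assumes extra keyword arguments are
forwarded to `generate`; the values are illustrative):

```python
output = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    temperature=0.7,
    top_p=0.9,
    max_new_tokens=256,
)
```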
"""
from vllm import SamplingParams
if not logits_processors:
logits_processors = []
if extra_sampling_params is None:
extra_sampling_params = {}
structured_output = None
if isinstance(inputs[0], tuple):
# Prepare the batches for structured generation
prepared_batches, sorted_indices = self._prepare_batches(inputs) # type: ignore
else:
# Simulate a batch without the structured output content
prepared_batches = [([self.prepare_input(input) for input in inputs], None)] # type: ignore
sorted_indices = None
# Case in which we have a single structured output for the dataset
if self._structured_output_logits_processor:
logits_processors.append(self._structured_output_logits_processor)
batched_outputs: List["LLMOutput"] = []
generations = []
for prepared_inputs, structured_output in prepared_batches:
if self.structured_output is not None and structured_output is not None:
self._logger.warning(
"An `structured_output` was provided in the model configuration, but"
" one was also provided in the input. The input structured output will"
" be used."
)
if structured_output is not None:
logits_processors.append(
self._prepare_structured_output(structured_output) # type: ignore
)
sampling_params = SamplingParams( # type: ignore
n=num_generations,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
max_tokens=max_new_tokens,
prompt_logprobs=logprobs if echo else None,
logprobs=logprobs,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
skip_special_tokens=skip_special_tokens,
logits_processors=logits_processors,
**extra_sampling_params,
)
batch_outputs: List["RequestOutput"] = self._model.generate(
prompts=prepared_inputs,
sampling_params=sampling_params,
use_tqdm=False,
)
# Remove the structured output logits processor to avoid stacking structured
# output logits processors, which leads to nonsensical generations
if structured_output is not None:
logits_processors.pop(-1)
for input, outputs in zip(prepared_inputs, batch_outputs):
processed_prompt_logprobs = []
if outputs.prompt_logprobs is not None:
processed_prompt_logprobs = self._get_llm_logprobs(
outputs.prompt_logprobs
)
texts, statistics, outputs_logprobs = self._process_outputs(
input=input,
outputs=outputs,
echo=echo,
prompt_logprobs=processed_prompt_logprobs,
)
batched_outputs.append(texts)
generation = prepare_output(
generations=texts,
input_tokens=statistics["input_tokens"],
output_tokens=statistics["output_tokens"],
logprobs=outputs_logprobs,
)
generations.append(generation)
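# If the inputs were regrouped into batches for structured generation,
# restore the original input order before returning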
if sorted_indices is not None:
pairs = list(enumerate(sorted_indices))
pairs.sort(key=lambda x: x[1])
generations = [generations[original_idx] for original_idx, _ in pairs]
return generations
def _process_outputs(
self,
input: str,
outputs: "RequestOutput",
prompt_logprobs: List[List["Logprob"]],
echo: bool = False,
) -> Tuple["LLMOutput", "LLMStatistics", "LLMLogprobs"]:
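"""Processes a vLLM `RequestOutput` for a single prompt: optionally prepends the
prompt to each generation when `echo` is enabled, and collects the generated
texts together with token statistics and logprobs."""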
texts = []
outputs_logprobs = []
statistics = {
"input_tokens": [compute_tokens(input, self._tokenizer.encode)]
* len(outputs.outputs),
"output_tokens": [],
}
for output in outputs.outputs:
text = output.text
if echo:
text = input + text
texts.append(text)
statistics["output_tokens"].append(len(output.token_ids))
if output.logprobs is not None:
processed_output_logprobs = self._get_llm_logprobs(output.logprobs)
outputs_logprobs.append(prompt_logprobs + processed_output_logprobs)
return texts, statistics, outputs_logprobs
def _prepare_structured_output( # type: ignore
self, structured_output: "OutlinesStructuredOutputType"
) -> Union[Callable, None]:
"""Creates the appropriate function to filter tokens to generate structured outputs.
Args:
structured_output: the configuration dict to prepare the structured output.
Returns:
The callable that will be used to guide the generation of the model.
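Examples:
    An illustrative configuration (the Pydantic model below is hypothetical;
    the dict shape follows `OutlinesStructuredOutputType`):

    ```python
    from pydantic import BaseModel

    class User(BaseModel):
        name: str
        age: int

    processor = llm._prepare_structured_output({"format": "json", "schema": User})
    ```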
"""
from distilabel.steps.tasks.structured_outputs.outlines import (
prepare_guided_output,
)
assert structured_output is not None, "`structured_output` cannot be `None`"
result = prepare_guided_output(structured_output, "vllm", self._model)
if (schema := result.get("schema")) and self.structured_output:
self.structured_output["schema"] = schema
return result["processor"]
def _get_llm_logprobs(
self, logprobs: Union["PromptLogprobs", "SampleLogprobs"]
) -> List[List["Logprob"]]:
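"""Converts vLLM prompt/sample logprobs into a list with one entry per token,
where each entry is either `None` or a list of `{"token": ..., "logprob": ...}`
dictionaries."""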
processed_logprobs = []
for token_logprob in logprobs: # type: ignore
token_logprobs = []
if token_logprob is None:
processed_logprobs.append(None)
continue
for logprob in token_logprob.values():
token_logprobs.append(
{"token": logprob.decoded_token, "logprob": logprob.logprob}
)
processed_logprobs.append(token_logprobs)
return processed_logprobs
class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin):
"""A client for the `vLLM` server implementing the OpenAI API specification.
Attributes:
base_url: the base URL of the `vLLM` server. Defaults to `"http://localhost:8000"`.
max_retries: the maximum number of times to retry the request to the API before
failing. Defaults to `6`.
timeout: the maximum time in seconds to wait for a response from the API. Defaults
to `120`.
httpx_client_kwargs: extra kwargs that will be passed to the `httpx.AsyncClient`
created to communicate with the `vLLM` server. Defaults to `None`.
tokenizer: the Hugging Face Hub repo id or path of the tokenizer that will be used
to apply the chat template and tokenize the inputs before sending them to the
server. Defaults to `None`.
tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`.
_aclient: the `httpx.AsyncClient` used to communicate with the `vLLM` server. Defaults
to `None`.
Runtime parameters:
- `base_url`: the base url of the `vLLM` server. Defaults to `"http://localhost:8000"`.
- `max_retries`: the maximum number of times to retry the request to the API before
failing. Defaults to `6`.
- `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
to `120`.
- `httpx_client_kwargs`: extra kwargs that will be passed to the `httpx.AsyncClient`
created to communicate with the `vLLM` server. Defaults to `None`.
Examples:
Generate text:
```python
from distilabel.models.llms import ClientvLLM
llm = ClientvLLM(
base_url="http://localhost:8000/v1",
tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct"
)
llm.load()
results = llm.generate_outputs(
inputs=[[{"role": "user", "content": "Hello, how are you?"}]],
temperature=0.7,
top_p=1.0,
max_new_tokens=256,
)
# [
# [
# "I'm functioning properly, thank you for asking. How can I assist you today?",
# "I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help answer any questions or provide information you might need. How can I assist you today?",
# "I'm just a computer program, so I don't have feelings like humans do, but I'm functioning properly and ready to help you with any questions or tasks you have. What's on your mind?"
# ]
# ]
```
"""
model: str = "" # Default value so it's not needed to `ClientvLLM(model="...")`
tokenizer: Optional[str] = None
tokenizer_revision: Optional[str] = None
# We need the sync client to get the list of models
_client: "OpenAI" = PrivateAttr(None)
_tokenizer: "PreTrainedTokenizer" = PrivateAttr(None)
def load(self) -> None:
"""Creates an `httpx.AsyncClient` to connect to the vLLM server and a tokenizer
optionally."""
self.api_key = SecretStr("EMPTY")
# We need to first create the sync client to get the model name that will be used
# in the `super().load()` when creating the logger.
try:
from openai import OpenAI
except ImportError as ie:
raise ImportError(
"OpenAI Python client is not installed. Please install it using"
" `pip install 'distilabel[openai]'`."
) from ie
self._client = OpenAI(
base_url=self.base_url,
api_key=self.api_key.get_secret_value(), # type: ignore
max_retries=self.max_retries, # type: ignore
timeout=self.timeout,
)
super().load()
try:
from transformers import AutoTokenizer
except ImportError as ie:
raise ImportError(
"To use `ClientvLLM` you need to install `transformers`."
"Please install it using `pip install 'distilabel[hf-transformers]'`."
) from ie
self._tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer, revision=self.tokenizer_revision
)
@cached_property
def model_name(self) -> str: # type: ignore
"""Returns the name of the model served with vLLM server."""
models = self._client.models.list()
return models.data[0].id
def _prepare_input(self, input: "StandardInput") -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.
Args:
input: the input list containing chat items.
Returns:
The prompt to send to the LLM.
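Examples:
    Illustrative only; the rendered prompt depends entirely on the chat
    template of the loaded tokenizer:

    ```python
    llm._prepare_input([{"role": "user", "content": "Hi"}])
    # e.g. "<|user|>\nHi\n<|assistant|>\n" for a hypothetical template
    ```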
"""
prompt: str = (
self._tokenizer.apply_chat_template( # type: ignore
input, # type: ignore
tokenize=False,
add_generation_prompt=True, # type: ignore
)
if input
else ""
)
return super().apply_magpie_pre_query_template(prompt, input)
@validate_call
async def agenerate( # type: ignore
self,
input: FormattedInput,
num_generations: int = 1,
max_new_tokens: int = 128,
frequency_penalty: float = 0.0,
logit_bias: Optional[Dict[str, int]] = None,
presence_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
) -> GenerateOutput:
"""Generates `num_generations` responses for each input.
Args:
input: a single input in chat format to generate responses for.
num_generations: the number of generations to create per input. Defaults to
`1`.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
frequency_penalty: the frequency penalty to use for the generation. Defaults
to `0.0`.
logit_bias: modify the likelihood of specified tokens appearing in the completion.
Defaults to `None`.
presence_penalty: the presence penalty to use for the generation. Defaults to
`0.0`.
temperature: the temperature to use for the generation. Defaults to `1.0`.
top_p: nucleus sampling. The value refers to the top-p tokens that should be
considered for sampling. Defaults to `1.0`.
Returns:
A list of lists of strings containing the generated responses for each input.
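Examples:
    A minimal sketch (assumes `llm.load()` has already been called; the input
    and parameter values are illustrative):

    ```python
    import asyncio

    result = asyncio.run(
        llm.agenerate(
            input=[{"role": "user", "content": "Hello!"}],
            max_new_tokens=32,
        )
    )
    ```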
"""
completion = await self._aclient.completions.create(
model=self.model_name,
prompt=self._prepare_input(input), # type: ignore
n=num_generations,
max_tokens=max_new_tokens,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
presence_penalty=presence_penalty,
temperature=temperature,
top_p=top_p,
)
generations = []
for choice in completion.choices:
text = choice.text
if text == "":
self._logger.warning( # type: ignore
f"Received no response from vLLM server (model: '{self.model_name}')."
f" Finish reason was: {choice.finish_reason}"
)
generations.append(text)
return prepare_output(generations, **self._get_llm_statistics(completion))