# Commit 4d4d8f59 authored by chenzk
#
# v1.0
#
# Pipeline #2741 canceled with stages
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Union
from pydantic import BaseModel, ConfigDict, PrivateAttr
from distilabel.mixins.runtime_parameters import RuntimeParametersMixin
from distilabel.utils.serialization import _Serializable
if TYPE_CHECKING:
from logging import Logger
class Embeddings(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
    """Base class for `Embeddings` models.

    To implement an `Embeddings` subclass, you need to subclass this class and implement:
        - `load` method to load the `Embeddings` model. Don't forget to call `super().load()`,
            so the `_logger` attribute is initialized.
        - `model_name` property to return the model name used for the `Embeddings`.
        - `encode` method to generate the sentence embeddings.

    Attributes:
        _logger: the logger to be used for the `Embeddings` model. It will be initialized
            when the `load` method is called.
    """

    # Pydantic configuration shared by all subclasses:
    # - `arbitrary_types_allowed`: fields may hold non-pydantic types (model handles).
    # - `protected_namespaces=()`: permits field names starting with `model_`.
    # - `validate_default` / `validate_assignment`: validate default values and
    #   re-validate on attribute assignment (relevant for runtime parameters).
    # - `extra="forbid"`: unknown constructor kwargs raise a validation error.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        protected_namespaces=(),
        validate_default=True,
        validate_assignment=True,
        extra="forbid",
    )

    # Private logger instance; set in `load`, not part of the pydantic schema.
    _logger: "Logger" = PrivateAttr(None)

    def load(self) -> None:
        """Method to be called to initialize the `Embeddings`"""
        # Per-model logger name so log records identify the concrete model.
        self._logger = logging.getLogger(
            f"distilabel.models.embeddings.{self.model_name}"
        )

    def unload(self) -> None:
        """Method to be called to unload the `Embeddings` and release any resources."""
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Returns the model name used for the `Embeddings`."""
        pass

    @abstractmethod
    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        """Generates embeddings for the provided inputs.

        Args:
            inputs: a list of texts for which an embedding has to be generated.

        Returns:
            The generated embeddings.
        """
        pass
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from pydantic import Field, PrivateAttr
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.embeddings.base import Embeddings
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
if TYPE_CHECKING:
from llama_cpp import Llama
class LlamaCppEmbeddings(Embeddings, CudaDevicePlacementMixin):
    """`LlamaCpp` library implementation for embedding generation.

    Attributes:
        model: the file name of the GGUF quantized model, compatible with the
            installed version of the `llama.cpp` Python bindings.
        model_path: contains the path to the GGUF quantized model, compatible with the
            installed version of the `llama.cpp` Python bindings.
        repo_id: the Hugging Face Hub repository id.
        verbose: whether to print verbose output. Defaults to `False`.
        n_gpu_layers: number of layers to run on the GPU. Defaults to `-1` (use the GPU if available).
        disable_cuda_device_placement: whether to disable CUDA device placement. Defaults to `True`.
        normalize_embeddings: whether to normalize the embeddings. Defaults to `False`.
        seed: RNG seed, -1 for random
        n_ctx: Text context, 0 = from model
        n_batch: Prompt processing maximum batch size
        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
            `Llama` class of `llama_cpp` library. Defaults to `{}`.

    Runtime parameters:
        - `n_gpu_layers`: the number of layers to use for the GPU. Defaults to `-1`.
        - `verbose`: whether to print verbose output. Defaults to `False`.
        - `normalize_embeddings`: whether to normalize the embeddings. Defaults to `False`.
        - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to the
            `Llama` class of `llama_cpp` library. Defaults to `{}`.

    References:
        - [Offline inference embeddings](https://llama-cpp-python.readthedocs.io/en/stable/#embeddings)

    Examples:
        Generate sentence embeddings using a local model:

        ```python
        from pathlib import Path
        from distilabel.models.embeddings import LlamaCppEmbeddings

        # You can follow along this example downloading the following model running the following
        # command in the terminal, that will download the model to the `Downloads` folder:
        # curl -L -o ~/Downloads/all-MiniLM-L6-v2-Q2_K.gguf https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-Q2_K.gguf

        model_path = "Downloads/"
        model = "all-MiniLM-L6-v2-Q2_K.gguf"
        embeddings = LlamaCppEmbeddings(
            model=model,
            model_path=str(Path.home() / model_path),
        )

        embeddings.load()

        results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
        print(results)
        embeddings.unload()
        ```

        Generate sentence embeddings using a HuggingFace Hub model:

        ```python
        from distilabel.models.embeddings import LlamaCppEmbeddings

        # You need to set environment variable to download private model to the local machine
        repo_id = "second-state/All-MiniLM-L6-v2-Embedding-GGUF"
        model = "all-MiniLM-L6-v2-Q2_K.gguf"
        embeddings = LlamaCppEmbeddings(model=model,repo_id=repo_id)

        embeddings.load()

        results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
        print(results)
        embeddings.unload()
        # [
        #   [-0.05447685346007347, -0.01623094454407692, ...],
        #   [4.4889533455716446e-05, 0.044016145169734955, ...],
        # ]
        ```

        Generate sentence embeddings with cpu:

        ```python
        from pathlib import Path
        from distilabel.models.embeddings import LlamaCppEmbeddings

        # You can follow along this example downloading the following model running the following
        # command in the terminal, that will download the model to the `Downloads` folder:
        # curl -L -o ~/Downloads/all-MiniLM-L6-v2-Q2_K.gguf https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-Q2_K.gguf

        model_path = "Downloads/"
        model = "all-MiniLM-L6-v2-Q2_K.gguf"
        embeddings = LlamaCppEmbeddings(
            model=model,
            model_path=str(Path.home() / model_path),
            n_gpu_layers=0,
            disable_cuda_device_placement=True,
        )

        embeddings.load()

        results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
        print(results)
        embeddings.unload()
        # [
        #   [-0.05447685346007347, -0.01623094454407692, ...],
        #   [4.4889533455716446e-05, 0.044016145169734955, ...],
        # ]
        ```
    """

    model: str = Field(
        description="The name of the model to use for embeddings.",
    )
    model_path: RuntimeParameter[str] = Field(
        default=None,
        description="The path to the GGUF quantized model, compatible with the installed version of the `llama.cpp` Python bindings.",
    )
    repo_id: RuntimeParameter[str] = Field(
        default=None, description="The Hugging Face Hub repository id.", exclude=True
    )
    n_gpu_layers: RuntimeParameter[int] = Field(
        default=-1,
        description="The number of layers that will be loaded in the GPU.",
    )
    n_ctx: int = 512
    n_batch: int = 512
    seed: int = 4294967295
    normalize_embeddings: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to normalize the embeddings.",
    )
    verbose: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to print verbose output from llama.cpp library.",
    )
    extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Additional dictionary of keyword arguments that will be passed to the"
        " `Llama` class of `llama_cpp` library. See all the supported arguments at: "
        "https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__",
    )
    # Llama handle created in `load`. Fix: was `PrivateAttr(...)`, which made the
    # private attribute *required* (no default) so `unload` before `load` raised
    # `AttributeError`; now it defaults to `None`.
    _model: Optional["Llama"] = PrivateAttr(None)

    def load(self) -> None:
        """Loads the `gguf` model using either the path or the Hugging Face Hub repository id.

        Raises:
            ImportError: if the `llama-cpp-python` package is not installed.
            ValueError: if neither `model_path` nor `repo_id` was provided.
        """
        super().load()
        CudaDevicePlacementMixin.load(self)

        try:
            from llama_cpp import Llama
        except ImportError as ie:
            raise ImportError(
                "`llama-cpp-python` package is not installed. Please install it using"
                " `pip install 'distilabel[llama-cpp]'`."
            ) from ie

        if self.repo_id is not None:
            # use repo_id to download the model
            from huggingface_hub.utils import validate_repo_id

            validate_repo_id(self.repo_id)
            self._model = Llama.from_pretrained(
                repo_id=self.repo_id,
                filename=self.model,
                n_gpu_layers=self.n_gpu_layers,
                seed=self.seed,
                n_ctx=self.n_ctx,
                n_batch=self.n_batch,
                verbose=self.verbose,
                embedding=True,
                # Fix: was `kwargs=self.extra_kwargs`, which only added a bogus
                # `kwargs` entry to `Llama`'s `**kwargs` instead of forwarding the
                # extra arguments (matches `vLLMEmbeddings`, which unpacks them).
                **(self.extra_kwargs or {}),
            )
        elif self.model_path is not None:
            self._model = Llama(
                model_path=str(Path(self.model_path) / self.model),
                n_gpu_layers=self.n_gpu_layers,
                seed=self.seed,
                n_ctx=self.n_ctx,
                n_batch=self.n_batch,
                verbose=self.verbose,
                embedding=True,
                **(self.extra_kwargs or {}),
            )
        else:
            raise ValueError("Either 'model_path' or 'repo_id' must be provided")

    def unload(self) -> None:
        """Unloads the `gguf` model and releases its resources (idempotent)."""
        CudaDevicePlacementMixin.unload(self)
        # Guard: `load` may never have run (or may have failed before assignment).
        if self._model is not None:
            self._model.close()
            self._model = None
        super().unload()

    @property
    def model_name(self) -> str:
        """Returns the name of the model."""
        return self.model

    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        """Generates embeddings for the provided inputs.

        Args:
            inputs: a list of texts for which an embedding has to be generated.

        Returns:
            The generated embeddings.
        """
        return self._model.embed(inputs, normalize=self.normalize_embeddings)
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from pydantic import Field, PrivateAttr
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.embeddings.base import Embeddings
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
class SentenceTransformerEmbeddings(Embeddings, CudaDevicePlacementMixin):
    """`sentence-transformers` library implementation for embedding generation.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        device: the name of the device used to load the model e.g. "cuda", "mps", etc.
            Defaults to `None`.
        prompts: a dictionary containing prompts to be used with the model. Defaults to
            `None`.
        default_prompt_name: the default prompt (in `prompts`) that will be applied to the
            inputs. If not provided, then no prompt will be used. Defaults to `None`.
        trust_remote_code: whether to allow fetching and executing remote code fetched
            from the repository in the Hub. Defaults to `False`.
        revision: if `model` refers to a Hugging Face Hub repository, then the revision
            (e.g. a branch name or a commit id) to use. Defaults to `"main"`.
        token: the Hugging Face Hub token that will be used to authenticate to the Hugging
            Face Hub. If not provided, the `HF_TOKEN` environment or `huggingface_hub` package
            local configuration will be used. Defaults to `None`.
        truncate_dim: the dimension to truncate the sentence embeddings. Defaults to `None`.
        model_kwargs: extra kwargs that will be passed to the Hugging Face `transformers`
            model class. Defaults to `None`.
        tokenizer_kwargs: extra kwargs that will be passed to the Hugging Face `transformers`
            tokenizer class. Defaults to `None`.
        config_kwargs: extra kwargs that will be passed to the Hugging Face `transformers`
            configuration class. Defaults to `None`.
        precision: the dtype that will have the resulting embeddings. Defaults to `"float32"`.
        normalize_embeddings: whether to normalize the embeddings so they have a length
            of 1. Defaults to `None`.

    Examples:
        Generating sentence embeddings:

        ```python
        from distilabel.models import SentenceTransformerEmbeddings

        embeddings = SentenceTransformerEmbeddings(model="mixedbread-ai/mxbai-embed-large-v1")
        embeddings.load()
        results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
        # [
        #   [-0.05447685346007347, -0.01623094454407692, ...],
        #   [4.4889533455716446e-05, 0.044016145169734955, ...],
        # ]
        ```
    """

    model: str
    device: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The device to be used to load the model. If `None`, then it"
        " will check if a GPU can be used.",
    )
    prompts: Optional[Dict[str, str]] = None
    default_prompt_name: Optional[str] = None
    trust_remote_code: bool = False
    revision: Optional[str] = None
    token: Optional[str] = None
    truncate_dim: Optional[int] = None
    model_kwargs: Optional[Dict[str, Any]] = None
    tokenizer_kwargs: Optional[Dict[str, Any]] = None
    config_kwargs: Optional[Dict[str, Any]] = None
    precision: Optional[Literal["float32", "int8", "uint8", "binary", "ubinary"]] = (
        "float32"
    )
    normalize_embeddings: RuntimeParameter[bool] = Field(
        default=True,
        description="Whether to normalize the embeddings so the generated vectors"
        " have a length of 1 or not.",
    )

    # The loaded `SentenceTransformer` instance; `None` until `load` is called.
    _model: Union["SentenceTransformer", None] = PrivateAttr(None)

    def load(self) -> None:
        """Loads the Sentence Transformer model"""
        super().load()

        # CUDA device placement is only set up when explicitly loading on "cuda".
        if self.device == "cuda":
            CudaDevicePlacementMixin.load(self)

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as e:
            raise ImportError(
                "`sentence-transformers` package is not installed. Please install it using"
                " `pip install 'distilabel[sentence-transformers]'`."
            ) from e

        self._model = SentenceTransformer(
            model_name_or_path=self.model,
            device=self.device,
            prompts=self.prompts,
            default_prompt_name=self.default_prompt_name,
            trust_remote_code=self.trust_remote_code,
            revision=self.revision,
            token=self.token,
            truncate_dim=self.truncate_dim,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            config_kwargs=self.config_kwargs,
        )

    @property
    def model_name(self) -> str:
        """Returns the name of the model."""
        return self.model

    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        """Generates embeddings for the provided inputs.

        Args:
            inputs: a list of texts for which an embedding has to be generated.

        Returns:
            The generated embeddings.
        """
        # All inputs are encoded in a single batch (`batch_size=len(inputs)`),
        # then converted from a numpy array into plain Python lists.
        return self._model.encode(  # type: ignore
            sentences=inputs,
            batch_size=len(inputs),
            convert_to_numpy=True,
            precision=self.precision,  # type: ignore
            normalize_embeddings=self.normalize_embeddings,  # type: ignore
        ).tolist()  # type: ignore

    def unload(self) -> None:
        """Releases the model and, if loaded on "cuda", the CUDA device placement."""
        del self._model
        if self.device == "cuda":
            CudaDevicePlacementMixin.unload(self)
        super().unload()
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from pydantic import Field, PrivateAttr
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.embeddings.base import Embeddings
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
if TYPE_CHECKING:
from vllm import LLM as _vLLM
class vLLMEmbeddings(Embeddings, CudaDevicePlacementMixin):
    """`vllm` library implementation for embedding generation.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        dtype: the data type to use for the model. Defaults to `auto`.
        trust_remote_code: whether to trust the remote code when loading the model. Defaults
            to `False`.
        quantization: the quantization mode to use for the model. Defaults to `None`.
        revision: the revision of the model to load. Defaults to `None`.
        enforce_eager: whether to enforce eager execution. Defaults to `True`.
        seed: the seed to use for the random number generator. Defaults to `0`.
        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
            `LLM` class of `vllm` library. Defaults to `{}`.
        _model: the `vLLM` model instance. This attribute is meant to be used internally
            and should not be accessed directly. It will be set in the `load` method.

    References:
        - [Offline inference embeddings](https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_embedding.html)

    Examples:
        Generating sentence embeddings:

        ```python
        from distilabel.models import vLLMEmbeddings

        embeddings = vLLMEmbeddings(model="intfloat/e5-mistral-7b-instruct")
        embeddings.load()
        results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
        # [
        #   [-0.05447685346007347, -0.01623094454407692, ...],
        #   [4.4889533455716446e-05, 0.044016145169734955, ...],
        # ]
        ```
    """

    model: str
    dtype: str = "auto"
    trust_remote_code: bool = False
    quantization: Optional[str] = None
    revision: Optional[str] = None
    enforce_eager: bool = True
    seed: int = 0
    extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Additional dictionary of keyword arguments that will be passed to the"
        " `vLLM` class of `vllm` library. See all the supported arguments at: "
        "https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py",
    )

    # The vLLM engine instance; `None` until `load` is called.
    _model: "_vLLM" = PrivateAttr(None)

    def load(self) -> None:
        """Loads the `vLLM` model using either the path or the Hugging Face Hub repository id."""
        super().load()
        CudaDevicePlacementMixin.load(self)

        try:
            from vllm import LLM as _vLLM
        except ImportError as ie:
            raise ImportError(
                "vLLM is not installed. Please install it using `pip install 'distilabel[vllm]'`."
            ) from ie

        self._model = _vLLM(
            self.model,
            dtype=self.dtype,
            trust_remote_code=self.trust_remote_code,
            quantization=self.quantization,
            revision=self.revision,
            enforce_eager=self.enforce_eager,
            seed=self.seed,
            **self.extra_kwargs,  # type: ignore
        )

    def unload(self) -> None:
        """Unloads the `vLLM` model."""
        CudaDevicePlacementMixin.unload(self)
        super().unload()

    @property
    def model_name(self) -> str:
        """Returns the name of the model."""
        return self.model

    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        """Generates embeddings for the provided inputs.

        Args:
            inputs: a list of texts for which an embedding has to be generated.

        Returns:
            The generated embeddings.
        """
        # `LLM.encode` returns one output per input; each carries the embedding vector.
        return [output.outputs.embedding for output in self._model.encode(inputs)]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.models.image_generation.base import (
AsyncImageGenerationModel,
ImageGenerationModel,
)
from distilabel.models.image_generation.huggingface.inference_endpoints import (
InferenceEndpointsImageGeneration,
)
from distilabel.models.image_generation.openai import OpenAIImageGeneration
# Public API of the `distilabel.models.image_generation` package.
__all__ = [
    "AsyncImageGenerationModel",
    "ImageGenerationModel",
    "InferenceEndpointsImageGeneration",
    "OpenAIImageGeneration",
]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import inspect
import logging
import sys
from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from distilabel.mixins.runtime_parameters import (
RuntimeParameter,
RuntimeParametersModelMixin,
)
from distilabel.utils.docstring import parse_google_docstring
from distilabel.utils.itertools import grouper
from distilabel.utils.serialization import _Serializable
if TYPE_CHECKING:
from logging import Logger
from distilabel.utils.docstring import Docstring
class ImageGenerationModel(RuntimeParametersModelMixin, BaseModel, _Serializable, ABC):
    """Base class for `ImageGeneration` models.

    To implement an `ImageGeneration` subclass, you need to subclass this class and implement:
        - `load` method to load the `ImageGeneration` model if needed. Don't forget to call `super().load()`,
            so the `_logger` attribute is initialized.
        - `model_name` property to return the model name used for the LLM.
        - `generate` method to generate `num_generations` per input in `inputs`.

    Attributes:
        generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate`
            methods within each `ImageGenerationModel`.
        _logger: the logger to be used for the `ImageGenerationModel`. It will be initialized
            when the `load` method is called.
    """

    # Pydantic configuration (same conventions as the `Embeddings` base model):
    # arbitrary field types allowed, `model_` names unprotected, defaults and
    # assignments validated, and unknown constructor kwargs forbidden.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        protected_namespaces=(),
        validate_default=True,
        validate_assignment=True,
        extra="forbid",
    )

    generation_kwargs: Optional[RuntimeParameter[dict[str, Any]]] = Field(
        default_factory=dict,
        description="The kwargs to be propagated to either `generate` or `agenerate`"
        " methods within each `ImageGenerationModel`.",
    )

    # Private logger instance; set in `load`.
    _logger: "Logger" = PrivateAttr(None)

    def load(self) -> None:
        """Method to be called to initialize the `ImageGenerationModel`, and its logger."""
        self._logger = logging.getLogger(
            f"distilabel.models.image_generation.{self.model_name}"
        )

    def unload(self) -> None:
        """Method to be called to unload the `ImageGenerationModel` and release any resources."""
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Returns the model name used for the `ImageGenerationModel`."""
        pass

    def get_generation_kwargs(self) -> dict[str, Any]:
        """Returns the generation kwargs to be used for the generation. This method can
        be overridden to provide a more complex logic for the generation kwargs.

        Returns:
            The kwargs to be used for the generation.
        """
        return self.generation_kwargs  # type: ignore

    @abstractmethod
    def generate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        """Generates images from the provided input.

        Args:
            inputs: the prompt text to generate the image from.
            num_generations: the number of images to generate. Defaults to `1`.

        Returns:
            A list with a dictionary with the list of images generated.
        """
        pass

    def generate_outputs(
        self,
        inputs: list[str],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> list[list[dict[str, Any]]]:
        """This method is defined for compatibility with the `LLMs`. It calls the `generate`
        method.
        """
        return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)
class AsyncImageGenerationModel(ImageGenerationModel):
    """Abstract class for asynchronous `ImageGenerationModels`, to benefit from the async capabilities
    of each LLM implementation. This class is meant to be subclassed by each `ImageGenerationModel`, and the
    method `agenerate` needs to be implemented to provide the asynchronous generation of
    responses.

    Attributes:
        _event_loop: the event loop to be used for the asynchronous generation of responses.
    """

    # Whether `agenerate` accepts `num_generations`; subclasses whose backend
    # can't generate N images per call set this to `False` so `_agenerate`
    # falls back to one task per generation.
    _num_generations_param_supported = True
    # Lazily-created event loop (see the `event_loop` property).
    _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)
    # `True` only when this instance created the loop itself, so `__del__`
    # knows whether it owns (and must close) it.
    _new_event_loop: bool = PrivateAttr(default=False)

    @property
    def generate_parameters(self) -> list[inspect.Parameter]:
        """Returns the parameters of the `agenerate` method.

        Returns:
            A list containing the parameters of the `agenerate` method.
        """
        return list(inspect.signature(self.agenerate).parameters.values())

    @cached_property
    def generate_parsed_docstring(self) -> "Docstring":
        """Returns the parsed docstring of the `agenerate` method.

        Returns:
            The parsed docstring of the `agenerate` method.
        """
        return parse_google_docstring(self.agenerate)

    @property
    def event_loop(self) -> "asyncio.AbstractEventLoop":
        # Reuse the running loop when one exists; otherwise (or when it is
        # closed) create a fresh loop, remember ownership, and register it as
        # the current loop for this thread.
        if self._event_loop is None:
            try:
                self._event_loop = asyncio.get_running_loop()
                if self._event_loop.is_closed():
                    self._event_loop = asyncio.new_event_loop()  # type: ignore
                    self._new_event_loop = True
            except RuntimeError:
                self._event_loop = asyncio.new_event_loop()
                self._new_event_loop = True
        asyncio.set_event_loop(self._event_loop)
        return self._event_loop

    @abstractmethod
    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        """Generates images from the provided input.

        Args:
            input: the input text to generate the image from.
            num_generations: the number of images to generate. Defaults to `1`.

        Returns:
            A list with a dictionary with the list of images generated.
        """
        pass

    async def _agenerate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        """Internal function to concurrently generate images for a list of inputs.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        if self._num_generations_param_supported:
            # One task per input; the backend handles `num_generations` itself.
            tasks = [
                asyncio.create_task(
                    self.agenerate(
                        input=input, num_generations=num_generations, **kwargs
                    )
                )
                for input in inputs
            ]
            return await asyncio.gather(*tasks)

        # Fallback: one task per (input, generation) pair, then regroup the flat
        # results back into `num_generations`-sized chunks per input.
        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
        return [
            list(group)
            for group in grouper(outputs, n=num_generations, incomplete="ignore")
        ]

    def generate(
        self,
        inputs: list[str],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> list[list[dict[str, Any]]]:
        """Method to generate a list of images asynchronously, returning the output
        synchronously awaiting for the image of each input sent to `agenerate`.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the images for each input.
        """
        return self.event_loop.run_until_complete(
            self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
        )

    def __del__(self) -> None:
        """Closes the event loop when the object is deleted."""
        # `sys.meta_path is None` signals interpreter shutdown — skip cleanup then.
        if sys.meta_path is None:
            return

        # Only close loops this instance created; never a caller-owned loop.
        if self._new_event_loop:
            if self._event_loop.is_running():
                self._event_loop.stop()
            self._event_loop.close()
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional
from pydantic import validate_call
from distilabel.models.base_clients.inference_endpoints import (
InferenceEndpointsBaseClient,
)
from distilabel.models.image_generation.base import AsyncImageGenerationModel
if TYPE_CHECKING:
from PIL.Image import Image
class InferenceEndpointsImageGeneration(  # type: ignore
    InferenceEndpointsBaseClient, AsyncImageGenerationModel
):
    """Inference Endpoint image generation implementation running the async API client.

    Attributes:
        model_id: the model ID to use for the ImageGenerationModel as available in the Hugging Face Hub, which
            will be used to resolve the base URL for the serverless Inference Endpoints API requests.
            Defaults to `None`.
        endpoint_name: the name of the Inference Endpoint to use for the LLM. Defaults to `None`.
        endpoint_namespace: the namespace of the Inference Endpoint to use for the LLM. Defaults to `None`.
        base_url: the base URL to use for the Inference Endpoints API requests.
        api_key: the API key to authenticate the requests to the Inference Endpoints API.

    Icon:
        `:hugging:`

    Examples:
        Generate images from text prompts:

        ```python
        from distilabel.models.image_generation import InferenceEndpointsImageGeneration

        igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell", api_key="api.key")
        igm.load()

        output = igm.generate_outputs(
            inputs=["a white siamese cat"],
        )
        # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
        ```
    """

    def load(self) -> None:
        """Initializes the logger, the async client and the image-to-string helper."""
        # Imported lazily to avoid a hard dependency at module import time.
        from distilabel.models.image_generation.utils import image_to_str

        # Sets the logger and calls the load method of the BaseClient
        AsyncImageGenerationModel.load(self)
        InferenceEndpointsBaseClient.load(self)

        self._image_to_str = image_to_str

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: str,
        negative_prompt: Optional[str] = None,
        height: Optional[float] = None,
        width: Optional[float] = None,
        num_inference_steps: Optional[float] = None,
        guidance_scale: Optional[float] = None,
        num_generations: int = 1,
    ) -> list[dict[str, Any]]:
        """Generates images from text prompts using `huggingface_hub.AsyncInferenceClient.text_to_image`.

        Args:
            input: Prompt to generate an image from.
            negative_prompt: An optional negative prompt for the image generation. Defaults to None.
            height: The height in pixels of the image to generate.
            width: The width in pixels of the image to generate.
            num_inference_steps: The number of denoising steps. More denoising steps usually lead
                to a higher quality image at the expense of slower inference.
            guidance_scale: Higher guidance scale encourages to generate images that are closely
                linked to the text `prompt`, usually at the expense of lower image quality.
            num_generations: The number of images to generate. Defaults to `1`.
                It's here to ensure the validation succeeds, but it won't have effect.

        Returns:
            A list with a dictionary containing a list with the image as a base64 string.
        """
        image: "Image" = await self._aclient.text_to_image(  # type: ignore
            input,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        # A single image is produced per call; serialize it to a base64 JPEG string.
        img_str = self._image_to_str(image, image_format="JPEG")

        return [{"images": [img_str]}]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from typing import TYPE_CHECKING, Any, Literal, Optional
import requests
from pydantic import validate_call
from distilabel.models.base_clients.openai import OpenAIBaseClient
from distilabel.models.image_generation.base import AsyncImageGenerationModel
if TYPE_CHECKING:
from openai.types import ImagesResponse
class OpenAIImageGeneration(OpenAIBaseClient, AsyncImageGenerationModel):
    """OpenAI image generation implementation running the async API client.

    Attributes:
        model: the model name to use for the ImageGenerationModel e.g. "dall-e-3", etc.
            Supported models can be found [here](https://platform.openai.com/docs/guides/images).
        base_url: the base URL to use for the OpenAI API requests. Defaults to `None`, which
            means that the value set for the environment variable `OPENAI_BASE_URL` will
            be used, or "https://api.openai.com/v1" if not set.
        api_key: the API key to authenticate the requests to the OpenAI API. Defaults to
            `None` which means that the value set for the environment variable `OPENAI_API_KEY`
            will be used, or `None` if not set.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.

    Icon:
        `:simple-openai:`

    Examples:
        Generate images from text prompts:

        ```python
        from distilabel.models.image_generation import OpenAIImageGeneration

        igm = OpenAIImageGeneration(model="dall-e-3", api_key="api.key")

        igm.load()

        output = igm.generate_outputs(
            inputs=["a white siamese cat"],
            size="1024x1024",
            quality="standard",
            style="natural",
        )
        # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
        ```
    """

    def load(self) -> None:
        """Sets the logger and initializes the async OpenAI client."""
        AsyncImageGenerationModel.load(self)
        OpenAIBaseClient.load(self)

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: str,
        num_generations: int = 1,
        quality: Optional[Literal["standard", "hd"]] = "standard",
        response_format: Optional[Literal["url", "b64_json"]] = "url",
        size: Optional[
            Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
        ] = None,
        style: Optional[Literal["vivid", "natural"]] = None,
    ) -> list[dict[str, Any]]:
        """Generates `num_generations` images for the given input using the OpenAI async
        client. The images are base64 string representations.

        Args:
            input: A text description of the desired image(s). The maximum length is 1000
                characters for `dall-e-2` and 4000 characters for `dall-e-3`.
            num_generations: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only
                `n=1` is supported.
            quality: The quality of the image that will be generated. `hd` creates images with finer
                details and greater consistency across the image. This param is only supported
                for `dall-e-3`.
            response_format: The format in which the generated images are returned. Must be one of `url` or
                `b64_json`. URLs are only valid for 60 minutes after the image has been
                generated. `None` is treated as the API default, `url`.
            size: The size of the generated images. Must be one of `256x256`, `512x512`, or
                `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or
                `1024x1792` for `dall-e-3` models.
            style: The style of the generated images. Must be one of `vivid` or `natural`. Vivid
                causes the model to lean towards generating hyper-real and dramatic images.
                Natural causes the model to produce more natural, less hyper-real looking
                images. This param is only supported for `dall-e-3`.

        Returns:
            A list with a dictionary with the list of images generated.

        Raises:
            requests.HTTPError: if downloading an image URL returns an HTTP error status.
        """
        # Normalize `None` to the API default; previously an explicit `None`
        # matched neither branch below and silently returned an empty list.
        if response_format is None:
            response_format = "url"

        images_response: "ImagesResponse" = await self._aclient.images.generate(
            model=self.model_name,
            prompt=input,
            n=num_generations,
            quality=quality,
            response_format=response_format,
            size=size,
            style=style,
        )

        images: list[str] = []
        for image in images_response.data:
            if response_format == "url":
                # TODO: Keep a requests/httpx session instead of one-off calls.
                # A bounded timeout avoids hanging forever on a stalled download.
                response = requests.get(image.url, timeout=120)
                # Fail loudly on HTTP errors instead of base64-encoding an
                # error page as if it were image data.
                response.raise_for_status()
                images.append(base64.b64encode(response.content).decode())
            else:  # response_format == "b64_json"
                images.append(image.b64_json)
        return [{"images": images}]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
from PIL import Image
def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str:
    """Serialize a PIL image into a base64-encoded string of its bytes.

    Args:
        image: The PIL image to serialize.
        image_format: The format to encode the image in (e.g. "JPEG", "PNG").

    Returns:
        The base64-encoded string of the image bytes.
    """
    with io.BytesIO() as buffer:
        # Let PIL write the encoded image bytes into the in-memory buffer.
        image.save(buffer, format=image_format)
        raw_bytes = buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")
def image_from_str(image_str: str) -> "Image.Image":
    """Rebuild a PIL image from its base64-encoded string representation.

    Args:
        image_str: The base64-encoded string of the image bytes.

    Returns:
        The decoded PIL image.
    """
    decoded_bytes = base64.b64decode(image_str)
    # Wrap the raw bytes in a file-like buffer so PIL can parse them.
    return Image.open(io.BytesIO(decoded_bytes))
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.models.llms.anthropic import AnthropicLLM
from distilabel.models.llms.anyscale import AnyscaleLLM
from distilabel.models.llms.azure import AzureOpenAILLM
from distilabel.models.llms.base import LLM, AsyncLLM
from distilabel.models.llms.cohere import CohereLLM
from distilabel.models.llms.groq import GroqLLM
from distilabel.models.llms.huggingface import InferenceEndpointsLLM, TransformersLLM
from distilabel.models.llms.litellm import LiteLLM
from distilabel.models.llms.llamacpp import LlamaCppLLM
from distilabel.models.llms.mistral import MistralLLM
from distilabel.models.llms.mlx import MlxLLM
from distilabel.models.llms.moa import MixtureOfAgentsLLM
from distilabel.models.llms.ollama import OllamaLLM
from distilabel.models.llms.openai import OpenAILLM
from distilabel.models.llms.together import TogetherLLM
from distilabel.models.llms.vertexai import VertexAILLM
from distilabel.models.llms.vllm import ClientvLLM, vLLM
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.typing import GenerateOutput, HiddenState
# Explicit public API of `distilabel.models.llms`: the LLM implementations,
# base classes, mixins and typing helpers re-exported from the submodules above.
__all__ = [
    "LLM",
    "AnthropicLLM",
    "AnyscaleLLM",
    "AsyncLLM",
    "AzureOpenAILLM",
    "ClientvLLM",
    "CohereLLM",
    "CudaDevicePlacementMixin",
    "GenerateOutput",
    "GroqLLM",
    "HiddenState",
    "InferenceEndpointsLLM",
    "LiteLLM",
    "LlamaCppLLM",
    "MistralLLM",
    "MixtureOfAgentsLLM",
    "MlxLLM",
    "OllamaLLM",
    "OpenAILLM",
    "TogetherLLM",
    "TransformersLLM",
    "VertexAILLM",
    "vLLM",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment