Commit 4d4d8f59 authored by chenzk

v1.0

# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: E402
import warnings
deprecation_message = (
"Importing from 'distilabel.llms' is deprecated and will be removed in a version 1.7.0. "
"Import from 'distilabel.models' instead."
)
warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
from distilabel.models.llms.anthropic import AnthropicLLM
from distilabel.models.llms.anyscale import AnyscaleLLM
from distilabel.models.llms.azure import AzureOpenAILLM
from distilabel.models.llms.base import LLM, AsyncLLM
from distilabel.models.llms.cohere import CohereLLM
from distilabel.models.llms.groq import GroqLLM
from distilabel.models.llms.huggingface import InferenceEndpointsLLM, TransformersLLM
from distilabel.models.llms.litellm import LiteLLM
from distilabel.models.llms.llamacpp import LlamaCppLLM
from distilabel.models.llms.mistral import MistralLLM
from distilabel.models.llms.mlx import MlxLLM
from distilabel.models.llms.moa import MixtureOfAgentsLLM
from distilabel.models.llms.ollama import OllamaLLM
from distilabel.models.llms.openai import OpenAILLM
from distilabel.models.llms.together import TogetherLLM
from distilabel.models.llms.vertexai import VertexAILLM
from distilabel.models.llms.vllm import ClientvLLM, vLLM
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.typing import GenerateOutput, HiddenState
__all__ = [
"LLM",
"AnthropicLLM",
"AnyscaleLLM",
"AsyncLLM",
"AzureOpenAILLM",
"ClientvLLM",
"CohereLLM",
"CudaDevicePlacementMixin",
"GenerateOutput",
"GroqLLM",
"HiddenState",
"InferenceEndpointsLLM",
"LiteLLM",
"LlamaCppLLM",
"MistralLLM",
"MixtureOfAgentsLLM",
"MlxLLM",
"OllamaLLM",
"OpenAILLM",
"TogetherLLM",
"TransformersLLM",
"VertexAILLM",
"vLLM",
]
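
# Illustrative usage sketch (not part of the original module, assumes distilabel is
# installed): the shim above keeps the old import path working while steering users
# towards `distilabel.models`. Importing through `distilabel.llms` resolves the same
# classes but emits a `DeprecationWarning` the first time the module is imported:
#
#     from distilabel.llms import OpenAILLM    # still works, but warns
#     from distilabel.models import OpenAILLM  # preferred, non-deprecated path
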
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
from importlib.metadata import version
from typing import List, Union
from packaging.requirements import InvalidRequirement, Requirement
class RequirementsMixin:
"""Mixin for classes that have `requirements` attribute.
Used to add requirements to a `Step` and a `Pipeline`.
"""
_requirements: Union[List[Requirement], None] = []
def _gather_requirements(self) -> List[str]:
"""This method will be overwritten in the `BasePipeline` class to gather the requirements
from each step.
"""
return []
@property
def requirements(self) -> List[str]:
"""Return a list of requirements that must be installed to run the `Pipeline`.
The requirements in a Pipeline will include the requirements from all the steps (if any).
Returns:
List of requirements that must be installed to run the `Pipeline`, sorted alphabetically.
"""
self.requirements = self._gather_requirements()
return [str(r) for r in self._requirements]
@requirements.setter
def requirements(self, _requirements: List[str]) -> None:
requirements = []
if not isinstance(_requirements, list):
_requirements = [_requirements]
for r in _requirements:
try:
requirements.append(Requirement(r))
except InvalidRequirement:
self._logger.warning(f"Invalid requirement: `{r}`")
self._requirements = sorted(
set(self._requirements).union(set(requirements)), key=lambda x: str(x)
)
def requirements_to_install(self) -> List[str]:
"""Check if the requirements are installed in the current environment, and returns the ones that aren't.
Returns:
List of requirements required to run the pipeline that are not installed in the current environment.
"""
to_install = []
for req in self.requirements:
requirement = Requirement(req)
            if importlib.util.find_spec(requirement.name):
                # Installed: only flag it when a version specifier was given and the
                # installed version does not satisfy it.
                specifier = requirement.specifier
                if str(specifier) != "" and not specifier.contains(
                    version(requirement.name)
                ):
                    to_install.append(req)
            else:
                # Not importable at all in the current environment.
                to_install.append(req)
return to_install
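
# Illustrative usage sketch (hypothetical subclass and package names, not part of the
# original module): a class mixes in `RequirementsMixin`, declares its requirements as
# plain strings, and `requirements_to_install()` reports the ones missing from the
# current environment.
class _RequirementsDemo(RequirementsMixin):
    def __init__(self) -> None:
        # `packaging` is importable (it is imported above); the second name is made up.
        self.requirements = ["packaging", "surely-not-installed-pkg>=1.0"]


_requirements_demo = _RequirementsDemo()
_requirements_demo.requirements               # -> ["packaging", "surely-not-installed-pkg>=1.0"]
_requirements_demo.requirements_to_install()  # -> ["surely-not-installed-pkg>=1.0"]
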
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import difflib
import inspect
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, TypeVar, Union
from pydantic import BaseModel, Field, PrivateAttr
from typing_extensions import Annotated, get_args, get_origin
from distilabel.utils.docstring import parse_google_docstring
from distilabel.utils.typing_ import (
extract_annotation_inner_type,
is_type_pydantic_secret_field,
)
if TYPE_CHECKING:
from pydantic.fields import FieldInfo
from distilabel.utils.docstring import Docstring
_T = TypeVar("_T")
_RUNTIME_PARAMETER_ANNOTATION = "distilabel_step_runtime_parameter"
RuntimeParameter = Annotated[
Union[_T, None], Field(default=None), _RUNTIME_PARAMETER_ANNOTATION
]
"""Used to mark the attributes of a `Step` as a runtime parameter."""
RuntimeParametersNames = Dict[str, Union[bool, "RuntimeParametersNames"]]
"""Alias for the names of the runtime parameters of a `Step`."""
RuntimeParameterInfo = Dict[str, Any]
"""Alias for the information of the runtime parameters of a `Step`."""
class RuntimeParametersMixin(BaseModel):
"""Mixin for classes that have `RuntimeParameter`s attributes.
Attributes:
_runtime_parameters: A dictionary containing the values of the runtime parameters
of the class. This attribute is meant to be used internally and should not be
accessed directly.
"""
_runtime_parameters: Dict[str, Any] = PrivateAttr(default_factory=dict)
@property
def runtime_parameters_names(self) -> "RuntimeParametersNames":
"""Returns a dictionary containing the name of the runtime parameters of the class
as keys and whether the parameter is required or not as values.
Returns:
A dictionary containing the name of the runtime parameters of the class as keys
and whether the parameter is required or not as values.
"""
runtime_parameters = {}
for name, field_info in self.model_fields.items(): # type: ignore
# `field: RuntimeParameter[Any]` or `field: Optional[RuntimeParameter[Any]]`
is_runtime_param, is_optional = _is_runtime_parameter(field_info)
if is_runtime_param:
runtime_parameters[name] = is_optional
continue
attr = getattr(self, name)
# `field: RuntimeParametersMixin`
if isinstance(attr, RuntimeParametersMixin):
runtime_parameters[name] = attr.runtime_parameters_names
# `field: List[RuntimeParametersMixin]`
if (
isinstance(attr, list)
and attr
and isinstance(attr[0], RuntimeParametersMixin)
):
runtime_parameters[name] = {
str(i): item.runtime_parameters_names for i, item in enumerate(attr)
}
return runtime_parameters
def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]:
"""Gets the information of the runtime parameters of the class such as the name and
the description. This function is meant to include the information of the runtime
parameters in the serialized data of the class.
Returns:
A list containing the information for each runtime parameter of the class.
"""
runtime_parameters_info = []
for name, field_info in self.model_fields.items(): # type: ignore
if name not in self.runtime_parameters_names:
continue
attr = getattr(self, name)
# Get runtime parameters info for `RuntimeParametersMixin` field
if isinstance(attr, RuntimeParametersMixin):
runtime_parameters_info.append(
{
"name": name,
"runtime_parameters_info": attr.get_runtime_parameters_info(),
}
)
continue
# Get runtime parameters info for `List[RuntimeParametersMixin]` field
if isinstance(attr, list) and isinstance(attr[0], RuntimeParametersMixin):
runtime_parameters_info.append(
{
"name": name,
"runtime_parameters_info": {
str(i): item.get_runtime_parameters_info()
for i, item in enumerate(attr)
},
}
)
continue
info = {"name": name, "optional": self.runtime_parameters_names[name]}
if field_info.description is not None:
info["description"] = field_info.description
runtime_parameters_info.append(info)
return runtime_parameters_info
def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
"""Sets the runtime parameters of the class using the provided values. If the attr
to be set is a `RuntimeParametersMixin`, it will call `set_runtime_parameters` on
the attr.
Args:
runtime_parameters: A dictionary containing the values of the runtime parameters
to set.
"""
runtime_parameters_names = list(self.runtime_parameters_names.keys())
for name, value in runtime_parameters.items():
if name not in self.runtime_parameters_names:
# Check done just to ensure the unit tests for the mixin run
if getattr(self, "pipeline", None):
closest = difflib.get_close_matches(
name, runtime_parameters_names, cutoff=0.5
)
msg = (
f"⚠️ Runtime parameter '{name}' unknown in step '{self.name}'." # type: ignore
)
if closest:
msg += f" Did you mean any of: {closest}"
else:
msg += f" Available runtime parameters for the step: {runtime_parameters_names}."
self.pipeline._logger.warning(msg) # type: ignore
continue
attr = getattr(self, name)
# Set runtime parameters for `RuntimeParametersMixin` field
if isinstance(attr, RuntimeParametersMixin):
attr.set_runtime_parameters(value)
self._runtime_parameters[name] = value
continue
# Set runtime parameters for `List[RuntimeParametersMixin]` field
if isinstance(attr, list) and isinstance(attr[0], RuntimeParametersMixin):
for i, item in enumerate(attr):
item_value = value.get(str(i), {})
item.set_runtime_parameters(item_value)
self._runtime_parameters[name] = value
continue
# Handle settings values for `_SecretField`
field_info = self.model_fields[name]
inner_type = extract_annotation_inner_type(field_info.annotation)
if is_type_pydantic_secret_field(inner_type):
value = inner_type(value)
# Set the value of the runtime parameter
setattr(self, name, value)
self._runtime_parameters[name] = value
def _is_runtime_parameter(field: "FieldInfo") -> Tuple[bool, bool]:
"""Check if a `pydantic.BaseModel` field is a `RuntimeParameter` and if it's optional
i.e. providing a value for the field in `Pipeline.run` is optional.
Args:
field: The info of the field of the `pydantic.BaseModel` to check.
Returns:
A tuple with two booleans. The first one indicates if the field is a
`RuntimeParameter` or not, and the second one indicates if the field is optional
or not.
"""
# Case 1: `runtime_param: RuntimeParameter[int]`
# Mandatory runtime parameter that needs to be provided when running the pipeline
if _RUNTIME_PARAMETER_ANNOTATION in field.metadata:
return True, field.default is not None
# Case 2: `runtime_param: Union[RuntimeParameter[int], None] = None`
# Optional runtime parameter that doesn't need to be provided when running the pipeline
type_args = get_args(field.annotation)
for arg in type_args:
is_runtime_param = (
get_origin(arg) is Annotated
and get_args(arg)[-1] == _RUNTIME_PARAMETER_ANNOTATION
)
if is_runtime_param:
is_optional = (
get_origin(field.annotation) is Union and type(None) in type_args
)
return True, is_optional
return False, False
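
# Illustrative usage sketch (hypothetical step class, not part of the original module):
# attributes annotated with `RuntimeParameter` are picked up by `runtime_parameters_names`
# (the boolean indicates whether the parameter is optional, i.e. has a default), and
# `set_runtime_parameters` assigns their values at run time.
class _RuntimeParamsDemo(RuntimeParametersMixin):
    temperature: RuntimeParameter[float] = Field(
        default=0.8, description="Sampling temperature used at generation time."
    )
    api_key: RuntimeParameter[str] = Field(
        default=None, description="Credential that has to be provided when running."
    )


_runtime_demo = _RuntimeParamsDemo()
_runtime_demo.runtime_parameters_names  # -> {"temperature": True, "api_key": False}
_runtime_demo.set_runtime_parameters({"temperature": 0.3, "api_key": "demo-key"})
_runtime_demo.temperature               # -> 0.3
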
class RuntimeParametersModelMixin(RuntimeParametersMixin):
"""Specific mixin for RuntimeParameters that affect the model classes, LLM,
ImageGenerationModel, etc.
"""
@property
def generate_parameters(self) -> list["inspect.Parameter"]:
"""Returns the parameters of the `generate` method.
Returns:
A list containing the parameters of the `generate` method.
"""
return list(inspect.signature(self.generate).parameters.values())
@property
def runtime_parameters_names(self) -> "RuntimeParametersNames":
"""Returns the runtime parameters of the `ImageGenerationModel`, which are combination of the
attributes of the `ImageGenerationModel` type hinted with `RuntimeParameter` and the parameters
of the `generate` method that are not `input` and `num_generations`.
Returns:
A dictionary with the name of the runtime parameters as keys and a boolean
indicating if the parameter is optional or not.
"""
runtime_parameters = super().runtime_parameters_names
runtime_parameters["generation_kwargs"] = {}
# runtime parameters from the `generate` method
for param in self.generate_parameters:
if param.name in ["input", "inputs", "num_generations"]:
continue
is_optional = param.default != inspect.Parameter.empty
runtime_parameters["generation_kwargs"][param.name] = is_optional
return runtime_parameters
def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]:
"""Gets the information of the runtime parameters of the `LLM` such as the name
and the description. This function is meant to include the information of the runtime
parameters in the serialized data of the `LLM`.
Returns:
A list containing the information for each runtime parameter of the `LLM`.
"""
runtime_parameters_info = super().get_runtime_parameters_info()
generation_kwargs_info = next(
(
runtime_parameter_info
for runtime_parameter_info in runtime_parameters_info
if runtime_parameter_info["name"] == "generation_kwargs"
),
None,
)
# If `generation_kwargs` attribute is present, we need to include the `generate`
# method arguments as the information for this attribute.
if generation_kwargs_info:
generate_docstring_args = self.generate_parsed_docstring["args"]
generation_kwargs_info["keys"] = []
for key, value in generation_kwargs_info["optional"].items():
info = {"name": key, "optional": value}
if description := generate_docstring_args.get(key):
info["description"] = description
generation_kwargs_info["keys"].append(info)
generation_kwargs_info.pop("optional")
return runtime_parameters_info
@cached_property
def generate_parsed_docstring(self) -> "Docstring":
"""Returns the parsed docstring of the `generate` method.
Returns:
The parsed docstring of the `generate` method.
"""
return parse_google_docstring(self.generate)
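
# Illustrative sketch (hypothetical model class, not part of the original module): for
# classes using `RuntimeParametersModelMixin`, every parameter of `generate` other than
# `input`, `inputs` and `num_generations` is exposed under the `generation_kwargs`
# runtime parameter, with the boolean indicating whether it has a default value.
class _GenerationDemo(RuntimeParametersModelMixin):
    def generate(
        self,
        inputs: list,
        num_generations: int = 1,
        temperature: float = 1.0,
        max_new_tokens: int = 128,
    ) -> list:
        """Dummy generation method.

        Args:
            temperature: the sampling temperature to use.
            max_new_tokens: the maximum number of new tokens to generate.
        """
        return []


_generation_demo = _GenerationDemo()
_generation_demo.runtime_parameters_names
# -> {"generation_kwargs": {"temperature": True, "max_new_tokens": True}}
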
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from typing import Any, List, Set
from pydantic import BaseModel, Field
from distilabel.utils.serialization import TYPE_INFO_KEY
# Add here the names of the attributes that shouldn't be used to generate the signature.
# Attributes belonging to a `BaseModel` that is itself an attribute of the root class must
# be prefixed with the name of that attribute followed by an underscore. For example, if
# `jobs_ids` is an attribute of the `llm` attribute of the root class, it should be added
# as `llm_jobs_ids`.
_EXCLUDE_FROM_SIGNATURE_DEFAULTS = {
TYPE_INFO_KEY,
"disable_cuda_device_placement",
"input_batch_size",
"gpu_memory_utilization",
"resources",
"exclude_from_signature",
"llm_jobs_ids",
"llm_offline_batch_generation_block_until_done",
}
class SignatureMixin(BaseModel):
"""Mixin for creating a signature (for cache) of the class.
Attributes:
exclude_from_signature: list of attributes to exclude from the signature.
"""
exclude_from_signature: Set[str] = Field(
default=_EXCLUDE_FROM_SIGNATURE_DEFAULTS, exclude=True
)
@property
def signature(self) -> str:
"""Makes a signature (hash) of the class, using its attributes.
Returns:
signature of the class.
"""
def flatten_dump(d: Any, parent_key: str = "", sep: str = "_") -> List:
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dump(v, new_key, sep=sep))
elif isinstance(v, list):
if len(v) == 0:
items.append((new_key, ""))
elif isinstance(v[0], (str, float, int, bool)):
items.append((new_key, "-".join(map(str, v))))
else:
for i, x in enumerate(v):
items.extend(flatten_dump(x, f"{new_key}-{i}", sep=sep))
elif new_key not in self.exclude_from_signature:
items.append((new_key, v))
return items
info = []
for name, value in flatten_dump(self.dump()):
info.append(f"{name}-{str(value)}")
return hashlib.sha1("-".join(info).encode()).hexdigest()
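
# Illustrative sketch (hypothetical subclass, not part of the original module): the
# signature is a SHA-1 over the flattened `dump()` of the object, skipping the keys in
# `exclude_from_signature`, so instances with the same relevant attributes share the
# same cache signature. The real classes get `dump()` from distilabel's serialization
# machinery; here a stand-in based on `model_dump()` is used.
class _SignatureDemo(SignatureMixin):
    name: str = "demo"
    temperature: float = 0.7

    def dump(self) -> dict:
        return self.model_dump()


assert _SignatureDemo().signature == _SignatureDemo(temperature=0.7).signature
assert _SignatureDemo().signature != _SignatureDemo(temperature=0.2).signature
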
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.models.embeddings.base import Embeddings
from distilabel.models.embeddings.llamacpp import LlamaCppEmbeddings
from distilabel.models.embeddings.sentence_transformers import (
SentenceTransformerEmbeddings,
)
from distilabel.models.embeddings.vllm import vLLMEmbeddings
from distilabel.models.image_generation.base import (
AsyncImageGenerationModel,
ImageGenerationModel,
)
from distilabel.models.image_generation.huggingface.inference_endpoints import (
InferenceEndpointsImageGeneration,
)
from distilabel.models.image_generation.openai import OpenAIImageGeneration
from distilabel.models.llms.anthropic import AnthropicLLM
from distilabel.models.llms.anyscale import AnyscaleLLM
from distilabel.models.llms.azure import AzureOpenAILLM
from distilabel.models.llms.base import LLM, AsyncLLM
from distilabel.models.llms.cohere import CohereLLM
from distilabel.models.llms.groq import GroqLLM
from distilabel.models.llms.huggingface import InferenceEndpointsLLM, TransformersLLM
from distilabel.models.llms.litellm import LiteLLM
from distilabel.models.llms.llamacpp import LlamaCppLLM
from distilabel.models.llms.mistral import MistralLLM
from distilabel.models.llms.mlx import MlxLLM
from distilabel.models.llms.moa import MixtureOfAgentsLLM
from distilabel.models.llms.ollama import OllamaLLM
from distilabel.models.llms.openai import OpenAILLM
from distilabel.models.llms.together import TogetherLLM
from distilabel.models.llms.vertexai import VertexAILLM
from distilabel.models.llms.vllm import ClientvLLM, vLLM
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.typing import GenerateOutput, HiddenState
__all__ = [
"LLM",
"AnthropicLLM",
"AnyscaleLLM",
"AsyncImageGenerationModel",
"AsyncLLM",
"AzureOpenAILLM",
"ClientvLLM",
"CohereLLM",
"CudaDevicePlacementMixin",
"Embeddings",
"GenerateOutput",
"GroqLLM",
"HiddenState",
"ImageGenerationModel",
"InferenceEndpointsImageGeneration",
"InferenceEndpointsLLM",
"LiteLLM",
"LlamaCppEmbeddings",
"LlamaCppLLM",
"MistralLLM",
"MixtureOfAgentsLLM",
"MlxLLM",
"OllamaLLM",
"OpenAIImageGeneration",
"OpenAILLM",
"SentenceTransformerEmbeddings",
"TogetherLLM",
"TransformersLLM",
"VertexAILLM",
"vLLM",
"vLLMEmbeddings",
]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.models.base_clients.inference_endpoints import (
InferenceEndpointsBaseClient,
)
from distilabel.models.base_clients.openai import OpenAIBaseClient
__all__ = ["InferenceEndpointsBaseClient", "OpenAIBaseClient"]
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import (
TYPE_CHECKING,
Optional,
)
from pydantic import (
BaseModel,
Field,
PrivateAttr,
SecretStr,
)
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.typing import StructuredOutputType
from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR, get_hf_token
if TYPE_CHECKING:
from huggingface_hub import AsyncInferenceClient
from transformers import PreTrainedTokenizer
class InferenceEndpointsBaseClient(BaseModel):
model_id: Optional[str] = None
endpoint_name: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The name of the Inference Endpoint to use for the LLM.",
)
endpoint_namespace: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The namespace of the Inference Endpoint to use for the LLM.",
)
base_url: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The base URL to use for the Inference Endpoints API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR),
description="The API key to authenticate the requests to the Inference Endpoints API.",
)
tokenizer_id: Optional[str] = None
model_display_name: Optional[str] = None
structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field(
default=None,
description="The structured output format to use across all the generations.",
)
_num_generations_param_supported = False
_model_name: Optional[str] = PrivateAttr(default=None)
_tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None)
_api_key_env_var: str = PrivateAttr(HF_TOKEN_ENV_VAR)
_aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...)
def load(self) -> None: # noqa: C901
"""Loads the `AsyncInferenceClient` client to connect to the Hugging Face Inference
Endpoint.
Raises:
ImportError: if the `huggingface-hub` Python client is not installed.
ValueError: if the model is not currently deployed or is not running the TGI framework.
ImportError: if the `transformers` Python client is not installed.
"""
try:
from huggingface_hub import (
AsyncInferenceClient,
InferenceClient,
get_inference_endpoint,
)
except ImportError as ie:
raise ImportError(
"Hugging Face Hub Python client is not installed. Please install it using"
" `pip install 'distilabel[hf-inference-endpoints]'`."
) from ie
if self.api_key is None:
self.api_key = SecretStr(get_hf_token(self.__class__.__name__, "api_key"))
if self.model_id is not None:
client = InferenceClient(
model=self.model_id, token=self.api_key.get_secret_value()
)
status = client.get_model_status()
if (
status.state not in {"Loadable", "Loaded"}
and status.framework != "text-generation-inference"
):
raise ValueError(
f"Model {self.model_id} is not currently deployed or is not running the TGI framework"
)
self.base_url = client._resolve_url(
model=self.model_id, task="text-generation"
)
if self.endpoint_name is not None:
client = get_inference_endpoint(
name=self.endpoint_name,
namespace=self.endpoint_namespace,
token=self.api_key.get_secret_value(),
)
if client.status in ["paused", "scaledToZero"]:
client.resume().wait(timeout=300)
elif client.status == "initializing":
client.wait(timeout=300)
self.base_url = client.url
self._model_name = client.repository
self._aclient = AsyncInferenceClient(
base_url=self.base_url,
token=self.api_key.get_secret_value(),
)
if self.tokenizer_id:
try:
from transformers import AutoTokenizer
except ImportError as ie:
raise ImportError(
"Transformers Python client is not installed. Please install it using"
" `pip install 'distilabel[hf-inference-endpoints]'`."
) from ie
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
@property
def model_name(self) -> str:
"""Returns the model name used for the model."""
return ( # type: ignore
self.model_display_name
or self._model_name
or self.model_id
or self.endpoint_name
or self.base_url
)
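
# Illustrative sketch (hypothetical identifiers, not part of the original module): the
# base client can point at a serverless Hub model via `model_id` or at a dedicated
# Inference Endpoint via `endpoint_name`/`endpoint_namespace`. Construction is local;
# it is `load()` that resolves the URL, waits for the endpoint to be running and builds
# the `AsyncInferenceClient`, so it needs a valid HF token and network access. In
# practice this client is used through the LLM / image generation classes built on it.
_ie_client = InferenceEndpointsBaseClient(
    model_id="some-org/some-model",      # hypothetical Hub model id
    tokenizer_id="some-org/some-model",  # optional, enables local tokenization
)
# _ie_client.load()        # requires HF_TOKEN (or `api_key`) and network access
# _ie_client.model_name    # falls back through display name, model id, endpoint, URL
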
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Dict, Optional
from pydantic import BaseModel, Field, PrivateAttr, SecretStr
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.typing import InstructorStructuredOutputType
if TYPE_CHECKING:
from openai import AsyncOpenAI, OpenAI
_OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
class OpenAIBaseClient(BaseModel):
model: str
base_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv(
"OPENAI_BASE_URL", "https://api.openai.com/v1"
),
description="The base URL to use for the OpenAI API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_OPENAI_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the OpenAI API.",
) # type: ignore
default_headers: Optional[RuntimeParameter[Dict[str, str]]] = Field(
default=None,
description="The default headers to use for the OpenAI API requests.",
)
max_retries: RuntimeParameter[int] = Field(
default=6,
description="The maximum number of times to retry the request to the API before"
" failing.",
)
timeout: RuntimeParameter[int] = Field(
default=120,
description="The maximum time in seconds to wait for a response from the API.",
)
structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
Field(
default=None,
description="The structured output format to use across all the generations.",
)
)
_api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME)
_client: "OpenAI" = PrivateAttr(None) # type: ignore
_aclient: "AsyncOpenAI" = PrivateAttr(None) # type: ignore
def load(self) -> None:
"""Loads the `AsyncOpenAI` client to benefit from async requests."""
try:
from openai import AsyncOpenAI, OpenAI
except ImportError as ie:
raise ImportError(
"OpenAI Python client is not installed. Please install it using"
" `pip install 'distilabel[openai]'`."
) from ie
if self.api_key is None:
raise ValueError(
f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
)
self._client = OpenAI(
base_url=self.base_url,
api_key=self.api_key.get_secret_value(),
max_retries=self.max_retries, # type: ignore
timeout=self.timeout,
default_headers=self.default_headers,
)
self._aclient = AsyncOpenAI(
base_url=self.base_url,
api_key=self.api_key.get_secret_value(),
max_retries=self.max_retries, # type: ignore
timeout=self.timeout,
default_headers=self.default_headers,
)
if self.structured_output:
# This applies only to the LLMs.
result = self._prepare_structured_output(
structured_output=self.structured_output,
client=self._aclient,
framework="openai",
)
self._aclient = result.get("client") # type: ignore
if structured_output := result.get("structured_output"):
self.structured_output = structured_output
def unload(self) -> None:
"""Set clients to `None` as they both contain `thread._RLock` which cannot be pickled
in case an exception is raised and has to be handled in the main process"""
self._client = None # type: ignore
self._aclient = None # type: ignore
self.default_headers = None
self.structured_output = None
@property
def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model
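
# Illustrative sketch (hypothetical values, not part of the original module): by default
# the client reads `OPENAI_BASE_URL` and `OPENAI_API_KEY` from the environment, so any
# OpenAI-compatible server can be targeted by overriding `base_url`. Construction is
# local; `load()` is what builds the `OpenAI` / `AsyncOpenAI` clients (and needs the
# `openai` extra installed), raising a `ValueError` when no API key is available.
_oai_client = OpenAIBaseClient(
    model="gpt-4o-mini",                  # any model name the target server accepts
    base_url="http://localhost:8000/v1",  # hypothetical OpenAI-compatible server
    api_key="not-a-real-key",
    timeout=30,
)
# _oai_client.load()       # builds the sync and async clients
# _oai_client.model_name   # -> "gpt-4o-mini"
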
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distilabel.models.embeddings.base import Embeddings
from distilabel.models.embeddings.llamacpp import LlamaCppEmbeddings
from distilabel.models.embeddings.sentence_transformers import (
SentenceTransformerEmbeddings,
)
from distilabel.models.embeddings.vllm import vLLMEmbeddings
__all__ = [
"Embeddings",
"LlamaCppEmbeddings",
"SentenceTransformerEmbeddings",
"vLLMEmbeddings",
]