Unverified Commit c3572e6b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add AzureOpenAiAgent (#24058)



* Add AzureOpenAiAgent

* quality

* Update src/transformers/tools/agents.py
Co-authored-by: default avatarLysandre Debut <lysandre.debut@reseau.eseo.fr>

---------
Co-authored-by: default avatarLysandre Debut <lysandre.debut@reseau.eseo.fr>
parent 5eb3d3c7
...@@ -38,6 +38,10 @@ We provide three types of agents: [`HfAgent`] uses inference endpoints for opens ...@@ -38,6 +38,10 @@ We provide three types of agents: [`HfAgent`] uses inference endpoints for opens
[[autodoc]] OpenAiAgent [[autodoc]] OpenAiAgent
### AzureOpenAiAgent
[[autodoc]] AzureOpenAiAgent
### Agent ### Agent
[[autodoc]] Agent [[autodoc]] Agent
......
...@@ -619,6 +619,7 @@ _import_structure = { ...@@ -619,6 +619,7 @@ _import_structure = {
], ],
"tools": [ "tools": [
"Agent", "Agent",
"AzureOpenAiAgent",
"HfAgent", "HfAgent",
"LocalAgent", "LocalAgent",
"OpenAiAgent", "OpenAiAgent",
...@@ -4410,6 +4411,7 @@ if TYPE_CHECKING: ...@@ -4410,6 +4411,7 @@ if TYPE_CHECKING:
# Tools # Tools
from .tools import ( from .tools import (
Agent, Agent,
AzureOpenAiAgent,
HfAgent, HfAgent,
LocalAgent, LocalAgent,
OpenAiAgent, OpenAiAgent,
......
...@@ -24,7 +24,7 @@ from ..utils import ( ...@@ -24,7 +24,7 @@ from ..utils import (
_import_structure = { _import_structure = {
"agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"], "agents": ["Agent", "AzureOpenAiAgent", "HfAgent", "LocalAgent", "OpenAiAgent"],
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"], "base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
} }
...@@ -46,7 +46,7 @@ else: ...@@ -46,7 +46,7 @@ else:
_import_structure["translation"] = ["TranslationTool"] _import_structure["translation"] = ["TranslationTool"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent from .agents import Agent, AzureOpenAiAgent, HfAgent, LocalAgent, OpenAiAgent
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
try: try:
......
...@@ -453,6 +453,132 @@ class OpenAiAgent(Agent): ...@@ -453,6 +453,132 @@ class OpenAiAgent(Agent):
return [answer["text"] for answer in result["choices"]] return [answer["text"] for answer in result["choices"]]
class AzureOpenAiAgent(Agent):
"""
Agent that uses Azure OpenAI to generate code. See the [official
documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an openAI
model on Azure
<Tip warning={true}>
The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
`"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.
</Tip>
Args:
deployment_id (`str`):
The name of the deployed Azure openAI model to use.
api_key (`str`, *optional*):
The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`.
resource_name (`str`, *optional*):
The name of your Azure OpenAI Resource. If unset, will look for the environment variable
`"AZURE_OPENAI_RESOURCE_NAME"`.
api_version (`str`, *optional*, default to `"2022-12-01"`):
The API version to use for this agent.
is_chat_mode (`bool`, *optional*):
Whether you are using a completion model or a chat model (see note above, chat models won't be as
efficient). Will default to `gpt` being in the `deployment_id` or not.
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`chat_prompt_template.txt` in this repo in this case.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`run_prompt_template.txt` in this repo in this case.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
from transformers import AzureOpenAiAgent
agent = AzureAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
```
"""
def __init__(
self,
deployment_id,
api_key=None,
resource_name=None,
api_version="2022-12-01",
is_chat_model=None,
chat_prompt_template=None,
run_prompt_template=None,
additional_tools=None,
):
if not is_openai_available():
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")
self.deployment_id = deployment_id
openai.api_type = "azure"
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
if api_key is None:
raise ValueError(
"You need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with "
"`os.environ['AZURE_OPENAI_API_KEY'] = xxx."
)
else:
openai.api_key = api_key
if resource_name is None:
resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME", None)
if resource_name is None:
raise ValueError(
"You need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with "
"`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx."
)
else:
openai.api_base = f"https://{resource_name}.openai.azure.com"
openai.api_version = api_version
if is_chat_model is None:
is_chat_model = "gpt" in deployment_id.lower()
self.is_chat_model = is_chat_model
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
def generate_many(self, prompts, stop):
if self.is_chat_model:
return [self._chat_generate(prompt, stop) for prompt in prompts]
else:
return self._completion_generate(prompts, stop)
def generate_one(self, prompt, stop):
if self.is_chat_model:
return self._chat_generate(prompt, stop)
else:
return self._completion_generate([prompt], stop)[0]
def _chat_generate(self, prompt, stop):
result = openai.ChatCompletion.create(
engine=self.deployment_id,
messages=[{"role": "user", "content": prompt}],
temperature=0,
stop=stop,
)
return result["choices"][0]["message"]["content"]
def _completion_generate(self, prompts, stop):
result = openai.Completion.create(
engine=self.deployment_id,
prompt=prompts,
temperature=0,
stop=stop,
max_tokens=200,
)
return [answer["text"] for answer in result["choices"]]
class HfAgent(Agent): class HfAgent(Agent):
""" """
Agent that uses an inference endpoint to generate code. Agent that uses an inference endpoint to generate code.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment