Unverified Commit cf432008 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add local agent (#23438)

* Add local agent

* Document LocalAgent
parent db136341
...@@ -24,12 +24,16 @@ contains the API docs for the underlying classes. ...@@ -24,12 +24,16 @@ contains the API docs for the underlying classes.
## Agents ## Agents
We provide two types of agents: [`HfAgent`] uses inference endpoints for opensource models and [`OpenAiAgent`] uses OpenAI closed models. We provide three types of agents: [`HfAgent`] uses inference endpoints for opensource models, [`LocalAgent`] uses a model of your choice locally and [`OpenAiAgent`] uses OpenAI closed models.
### HfAgent ### HfAgent
[[autodoc]] HfAgent [[autodoc]] HfAgent
### LocalAgent
[[autodoc]] LocalAgent
### OpenAiAgent ### OpenAiAgent
[[autodoc]] OpenAiAgent [[autodoc]] OpenAiAgent
......
...@@ -614,6 +614,7 @@ _import_structure = { ...@@ -614,6 +614,7 @@ _import_structure = {
"tools": [ "tools": [
"Agent", "Agent",
"HfAgent", "HfAgent",
"LocalAgent",
"OpenAiAgent", "OpenAiAgent",
"PipelineTool", "PipelineTool",
"RemoteTool", "RemoteTool",
...@@ -4361,7 +4362,17 @@ if TYPE_CHECKING: ...@@ -4361,7 +4362,17 @@ if TYPE_CHECKING:
) )
# Tools # Tools
from .tools import Agent, HfAgent, OpenAiAgent, PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool from .tools import (
Agent,
HfAgent,
LocalAgent,
OpenAiAgent,
PipelineTool,
RemoteTool,
Tool,
launch_gradio_demo,
load_tool,
)
# Trainer # Trainer
from .trainer_callback import ( from .trainer_callback import (
......
...@@ -24,7 +24,7 @@ from ..utils import ( ...@@ -24,7 +24,7 @@ from ..utils import (
_import_structure = { _import_structure = {
"agents": ["Agent", "HfAgent", "OpenAiAgent"], "agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"],
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"], "base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
} }
...@@ -46,7 +46,7 @@ else: ...@@ -46,7 +46,7 @@ else:
_import_structure["translation"] = ["TranslationTool"] _import_structure["translation"] = ["TranslationTool"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .agents import Agent, HfAgent, OpenAiAgent from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
try: try:
......
...@@ -24,6 +24,8 @@ from typing import Dict ...@@ -24,6 +24,8 @@ from typing import Dict
import requests import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces from huggingface_hub import HfFolder, hf_hub_download, list_spaces
from ..generation import StoppingCriteria, StoppingCriteriaList
from ..models.auto import AutoModelForCausalLM, AutoTokenizer
from ..utils import is_openai_available, logging from ..utils import is_openai_available, logging
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
...@@ -492,3 +494,114 @@ class HfAgent(Agent): ...@@ -492,3 +494,114 @@ class HfAgent(Agent):
if result.endswith(stop_seq): if result.endswith(stop_seq):
return result[: -len(stop_seq)] return result[: -len(stop_seq)]
return result return result
class LocalAgent(Agent):
"""
Agent that uses a local model and tokenizer to generate code.
Args:
model ([`PreTrainedModel`]):
The model to use for the agent.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer to use for the agent.
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent
checkpoint = "bigcode/starcoder"
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
agent = LocalAgent(model, tokenizer)
agent.run("Draw me a picture of rivers and lakes.")
```
"""
def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
self.model = model
self.tokenizer = tokenizer
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Convenience method to build a `LocalAgent` from a pretrained checkpoint.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
kwargs:
Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].
Example:
```py
import torch
from transformers import LocalAgent
agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
agent.run("Draw me a picture of rivers and lakes.")
```
"""
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(model, tokenizer)
@property
def _model_device(self):
if hasattr(self.model, "hf_device_map"):
return list(self.model.hf_device_map.values())[0]
for param in self.mode.parameters():
return param.device
def generate_one(self, prompt, stop):
encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
src_len = encoded_inputs["input_ids"].shape[1]
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
outputs = self.model.generate(
encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
)
result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq):
result = result[: -len(stop_seq)]
return result
class StopSequenceCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever a sequence of tokens is encountered.
Args:
stop_sequences (`str` or `List[str]`):
The sequence (or list of sequences) on which to stop execution.
tokenizer:
The tokenizer used to decode the model outputs.
"""
def __init__(self, stop_sequences, tokenizer):
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
self.stop_sequences = stop_sequences
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment