Unverified Commit e0d82534 authored by Aymeric Roucher, committed by GitHub

Agents use grammar (#31735)

* Allow optional use of grammars to constrain generation
parent c54a6f99
@@ -119,10 +119,12 @@ def llm_engine(messages, stop_sequences=["Task"]) -> str:
```
You could use any `llm_engine` method as long as:
-1. it follows the [messages format](./chat_templating.md) for its input (`List[Dict[str, str]]`) and returns a `str`
-2. it stops generating outputs at the sequences passed in the argument `stop`
-You also need a `tools` argument which accepts a list of `Tools`. You can provide an empty list for `tools`, but use the default toolbox with the optional argument `add_base_tools=True`.
+1. it follows the [messages format](./chat_templating.md) (`List[Dict[str, str]]`) for its input `messages`, and it returns a `str`.
+2. it stops generating outputs at the sequences passed in the argument `stop_sequences`
+
+Additionally, `llm_engine` can also take a `grammar` argument. If you specify a `grammar` upon agent initialization, it will be passed along on every call to `llm_engine`, enabling [constrained generation](https://huggingface.co/docs/text-generation-inference/conceptual/guidance) to force properly-formatted agent outputs.
+
+You will also need a `tools` argument which accepts a list of `Tools` - it can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.

Now you can create an agent, like [`CodeAgent`], and run it. For convenience, we also provide the [`HfEngine`] class that uses `huggingface_hub.InferenceClient` under the hood.
......
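For illustration, a minimal setup satisfying the contract above might look like the following sketch (the model name and the task string are placeholders, and the imports assume the `transformers.agents` exports of this branch):

```py
from transformers.agents import CodeAgent, HfEngine

# HfEngine satisfies the llm_engine contract: it accepts a list of
# {"role": ..., "content": ...} dicts plus stop_sequences (and, with this
# change, an optional grammar), and returns a str.
llm_engine = HfEngine(model="meta-llama/Meta-Llama-3.1-8B-Instruct")

agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
agent.run("What is the result of 2 to the power of 3.7384?")
```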
@@ -328,7 +328,7 @@ class Agent:
        self,
        tools: Union[List[Tool], Toolbox],
        llm_engine: Callable = HfEngine(),
-       system_prompt=DEFAULT_REACT_JSON_SYSTEM_PROMPT,
+       system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
        tool_description_template=None,
        additional_args={},
        max_iterations: int = 6,
@@ -336,6 +336,7 @@ class Agent:
        add_base_tools: bool = False,
        verbose: int = 0,
        memory_verbose: bool = False,
+       grammar: Dict[str, str] = None,
    ):
        self.agent_name = self.__class__.__name__
        self.llm_engine = llm_engine
@@ -347,6 +348,7 @@ class Agent:
        self.max_iterations = max_iterations
        self.logger = logger
        self.tool_parser = tool_parser
+       self.grammar = grammar
        if isinstance(tools, Toolbox):
            self._toolbox = tools
@@ -533,6 +535,7 @@ class CodeAgent(Agent):
        llm_engine: Callable = HfEngine(),
        system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
        tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+       grammar: Dict[str, str] = None,
        additional_authorized_imports: Optional[List[str]] = None,
        **kwargs,
    ):
@@ -541,6 +544,7 @@ class CodeAgent(Agent):
            llm_engine=llm_engine,
            system_prompt=system_prompt,
            tool_description_template=tool_description_template,
+           grammar=grammar,
            **kwargs,
        )
@@ -599,7 +603,9 @@ class CodeAgent(Agent):
        self.prompt = [prompt_message, task_message]
        self.logger.info("====Executing with this prompt====")
        self.logger.info(self.prompt)
-       llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"])
+
+       additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+       llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
        if return_generated_code:
            return llm_output
@@ -652,6 +658,7 @@ class ReactAgent(Agent):
        llm_engine: Callable = HfEngine(),
        system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
        tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+       grammar: Dict[str, str] = None,
        plan_type: Literal[tuple(SUPPORTED_PLAN_TYPES)] = SUPPORTED_PLAN_TYPES[0],
        planning_interval: Optional[int] = None,
        **kwargs,
@@ -662,6 +669,7 @@ class ReactAgent(Agent):
            llm_engine=llm_engine,
            system_prompt=system_prompt,
            tool_description_template=tool_description_template,
+           grammar=grammar,
            **kwargs,
        )
        self.planning_interval = planning_interval
@@ -881,6 +889,7 @@ class ReactJsonAgent(ReactAgent):
        llm_engine: Callable = HfEngine(),
        system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
        tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+       grammar: Dict[str, str] = None,
        planning_interval: Optional[int] = None,
        **kwargs,
    ):
@@ -889,6 +898,7 @@ class ReactJsonAgent(ReactAgent):
            llm_engine=llm_engine,
            system_prompt=system_prompt,
            tool_description_template=tool_description_template,
+           grammar=grammar,
            planning_interval=planning_interval,
            **kwargs,
        )
@@ -912,7 +922,10 @@ class ReactJsonAgent(ReactAgent):
        self.logger.info(self.prompt[-1])
        try:
-           llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>", "Observation:"])
+           additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+           llm_output = self.llm_engine(
+               self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+           )
        except Exception as e:
            raise AgentGenerationError(f"Error in generating llm output: {e}.")
        self.logger.debug("===== Output message of the LLM: =====")
@@ -982,6 +995,7 @@ class ReactCodeAgent(ReactAgent):
        llm_engine: Callable = HfEngine(),
        system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
        tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
+       grammar: Dict[str, str] = None,
        additional_authorized_imports: Optional[List[str]] = None,
        planning_interval: Optional[int] = None,
        **kwargs,
@@ -991,6 +1005,7 @@ class ReactCodeAgent(ReactAgent):
            llm_engine=llm_engine,
            system_prompt=system_prompt,
            tool_description_template=tool_description_template,
+           grammar=grammar,
            planning_interval=planning_interval,
            **kwargs,
        )
@@ -1028,7 +1043,10 @@ class ReactCodeAgent(ReactAgent):
        self.logger.info(self.prompt[-2:])
        try:
-           llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>", "Observation:"])
+           additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
+           llm_output = self.llm_engine(
+               self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args
+           )
        except Exception as e:
            raise AgentGenerationError(f"Error in generating llm output: {e}.")
......
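As the hunks above show, the agents only forward `grammar` to `llm_engine` when one was set, so a custom engine may either handle it or accept and ignore it. A minimal sketch of such a grammar-tolerant engine (the echo body is a placeholder for a real model call):

```py
from typing import Dict, List, Optional


def my_llm_engine(
    messages: List[Dict[str, str]],
    stop_sequences: List[str] = [],
    grammar: Optional[Dict[str, str]] = None,
) -> str:
    # `grammar` arrives only if the agent was initialized with one; a backend
    # without constrained decoding can simply ignore it.
    last_message = messages[-1]["content"]
    return f"Thought: echoing the last message.\n{last_message}"  # placeholder
```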
@@ -16,7 +16,7 @@
 # limitations under the License.
from copy import deepcopy
from enum import Enum
-from typing import Dict, List
+from typing import Dict, List, Optional

from huggingface_hub import InferenceClient
@@ -66,16 +66,24 @@ llama_role_conversions = {
class HfEngine:
-   def __init__(self, model: str = "meta-llama/Meta-Llama-3-8B-Instruct"):
+   def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"):
        self.model = model
-       self.client = InferenceClient(model=self.model, timeout=120)
+       self.client = InferenceClient(self.model, timeout=120)

-   def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
+   def __call__(
+       self, messages: List[Dict[str, str]], stop_sequences: List[str] = [], grammar: Optional[str] = None
+   ) -> str:
        # Get clean message list
        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

        # Get LLM output
-       response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
+       if grammar is not None:
+           response = self.client.chat_completion(
+               messages, stop=stop_sequences, max_tokens=1500, response_format=grammar
+           )
+       else:
+           response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
        response = response.choices[0].message.content

        # Remove stop sequences from LLM output
@@ -83,3 +91,14 @@ class HfEngine:
            if response[-len(stop_seq) :] == stop_seq:
                response = response[: -len(stop_seq)]
        return response
+
+
+DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
+}
+
+DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
+    "type": "regex",
+    "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_action>",
+}
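These defaults can be passed directly at agent initialization; the dict travels through `llm_engine` into `chat_completion`'s `response_format`, so text-generation-inference constrains decoding to the regex. A sketch of wiring one up (the `transformers.agents.llm_engine` import path is an assumption about this branch's layout):

```py
from transformers.agents import HfEngine, ReactCodeAgent
from transformers.agents.llm_engine import DEFAULT_CODEAGENT_REGEX_GRAMMAR  # assumed path

agent = ReactCodeAgent(
    tools=[],
    llm_engine=HfEngine(),
    grammar=DEFAULT_CODEAGENT_REGEX_GRAMMAR,  # forwarded on every llm_engine call
)
agent.run("What is the result of the following operation: 5 + 3 + 1294.678?")
```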
...@@ -63,7 +63,7 @@ Examples: ...@@ -63,7 +63,7 @@ Examples:
--- ---
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French." Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image. Thought: I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
Code: Code:
```py ```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English") translated_question = translator(question=question, src_lang="French", tgt_lang="English")
...@@ -75,7 +75,7 @@ final_answer(f"The answer is {answer}") ...@@ -75,7 +75,7 @@ final_answer(f"The answer is {answer}")
--- ---
Task: "Identify the oldest person in the `document` and create an image showcasing the result." Task: "Identify the oldest person in the `document` and create an image showcasing the result."
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer. Thought: I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
Code: Code:
```py ```py
answer = document_qa(document, question="What is the oldest person?") answer = document_qa(document, question="What is the oldest person?")
...@@ -87,7 +87,7 @@ final_answer(image) ...@@ -87,7 +87,7 @@ final_answer(image)
--- ---
Task: "Generate an image using the text given in the variable `caption`." Task: "Generate an image using the text given in the variable `caption`."
I will use the following tool: `image_generator` to generate an image. Thought: I will use the following tool: `image_generator` to generate an image.
Code: Code:
```py ```py
image = image_generator(prompt=caption) image = image_generator(prompt=caption)
...@@ -97,7 +97,7 @@ final_answer(image) ...@@ -97,7 +97,7 @@ final_answer(image)
--- ---
Task: "Summarize the text given in the variable `text` and read it out loud." Task: "Summarize the text given in the variable `text` and read it out loud."
I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud. Thought: I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
Code: Code:
```py ```py
summarized_text = summarizer(text) summarized_text = summarizer(text)
...@@ -109,7 +109,7 @@ final_answer(audio_summary) ...@@ -109,7 +109,7 @@ final_answer(audio_summary)
--- ---
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image." Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer. Thought: I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
Code: Code:
```py ```py
answer = text_qa(text=text, question=question) answer = text_qa(text=text, question=question)
...@@ -121,7 +121,7 @@ final_answer(image) ...@@ -121,7 +121,7 @@ final_answer(image)
--- ---
Task: "Caption the following `image`." Task: "Caption the following `image`."
I will use the following tool: `image_captioner` to generate a caption for the image. Thought: I will use the following tool: `image_captioner` to generate a caption for the image.
Code: Code:
```py ```py
caption = image_captioner(image) caption = image_captioner(image)
...@@ -292,7 +292,6 @@ print(answer) ...@@ -292,7 +292,6 @@ print(answer)
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
Thought: I will now generate an image showcasing the oldest person. Thought: I will now generate an image showcasing the oldest person.
Code: Code:
```py ```py
image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.") image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
...@@ -303,7 +302,6 @@ final_answer(image) ...@@ -303,7 +302,6 @@ final_answer(image)
Task: "What is the result of the following operation: 5 + 3 + 1294.678?" Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool
Code: Code:
```py ```py
result = 5 + 3 + 1294.678 result = 5 + 3 + 1294.678
......
@@ -30,7 +30,7 @@ def get_new_path(suffix="") -> str:
    return os.path.join(directory, str(uuid.uuid4()) + suffix)

-def fake_react_json_llm(messages, stop_sequences=None) -> str:
+def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str:
    prompt = str(messages)
    if "special_marker" not in prompt:
@@ -53,7 +53,7 @@ Action:
"""

-def fake_react_code_llm(messages, stop_sequences=None) -> str:
+def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str:
    prompt = str(messages)
    if "special_marker" not in prompt:
        return """
@@ -119,7 +119,7 @@ final_answer(res)
"""

-def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
+def fake_code_llm_oneshot(messages, stop_sequences=None, grammar=None) -> str:
    return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
@@ -130,7 +130,7 @@ final_answer(result)
"""

-def fake_code_llm_no_return(messages, stop_sequences=None) -> str:
+def fake_code_llm_no_return(messages, stop_sequences=None, grammar=None) -> str:
    return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
......