"vscode:/vscode.git/clone" did not exist on "1f72865726f7f8ca7d0202bb8cd2e487394f8c83"
Unverified Commit 40ed18ae authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add an option to log result from the Agent (#23454)

parent f69589d1
......@@ -207,6 +207,7 @@ class Agent:
self.chat_prompt_template = CHAT_PROMPT_TEMPLATE if chat_prompt_template is None else chat_prompt_template
self.run_prompt_template = RUN_PROMPT_TEMPLATE if run_prompt_template is None else run_prompt_template
self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
self.log = print
if additional_tools is not None:
if isinstance(additional_tools, (list, tuple)):
additional_tools = {t.name: t for t in additional_tools}
......@@ -244,6 +245,15 @@ class Agent:
prompt = prompt.replace("<<prompt>>", task)
return prompt
def set_stream(self, streamer):
"""
Set the function use to stream results (which is `print` by default).
Args:
streamer (`callable`): The function to call when streaming results from the LLM.
"""
self.log = streamer
def chat(self, task, *, return_code=False, remote=False, **kwargs):
"""
Sends a new request to the agent in a chat. Will use the previous ones in its history.
......@@ -273,12 +283,12 @@ class Agent:
self.chat_history = prompt + result.strip() + "\n"
explanation, code = clean_code_for_chat(result)
print(f"==Explanation from the agent==\n{explanation}")
self.log(f"==Explanation from the agent==\n{explanation}")
if code is not None:
print(f"\n\n==Code generated by the agent==\n{code}")
self.log(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
print("\n\n==Result==")
self.log("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
self.chat_state.update(kwargs)
return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
......@@ -320,11 +330,11 @@ class Agent:
result = self.generate_one(prompt, stop=["Task:"])
explanation, code = clean_code_for_run(result)
print(f"==Explanation from the agent==\n{explanation}")
self.log(f"==Explanation from the agent==\n{explanation}")
print(f"\n\n==Code generated by the agent==\n{code}")
self.log(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
print("\n\n==Result==")
self.log("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
return evaluate(code, self.cached_tools, state=kwargs.copy())
else:
......@@ -487,7 +497,7 @@ class HfAgent(Agent):
response = requests.post(self.url_endpoint, json=inputs, headers=headers)
if response.status_code == 429:
print("Getting rate-limited, waiting a tiny bit before trying again.")
logger.info("Getting rate-limited, waiting a tiny bit before trying again.")
time.sleep(1)
return self._generate_one(prompt)
elif response.status_code != 200:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment