"torchvision/extension.py" did not exist on "579eebea61ba3f1ed4f76bf002be04e4271a8b69"
Commit 3a3655d6 authored by baberabb's avatar baberabb
Browse files

passed kwargs to client

parent fe358061
...@@ -15,6 +15,7 @@ def anthropic_completion( ...@@ -15,6 +15,7 @@ def anthropic_completion(
max_tokens_to_sample: int, max_tokens_to_sample: int,
temperature: float, temperature: float,
stop: List[str], stop: List[str],
**kwargs: Any,
): ):
"""Query Anthropic API for completion. """Query Anthropic API for completion.
...@@ -31,6 +32,7 @@ def anthropic_completion( ...@@ -31,6 +32,7 @@ def anthropic_completion(
stop_sequences=[anthropic.HUMAN_PROMPT] + stop, stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
max_tokens_to_sample=max_tokens_to_sample, max_tokens_to_sample=max_tokens_to_sample,
temperature=temperature, temperature=temperature,
**kwargs,
) )
return response.completion return response.completion
except anthropic.RateLimitError as e: except anthropic.RateLimitError as e:
...@@ -56,22 +58,29 @@ class AnthropicLM(LM): ...@@ -56,22 +58,29 @@ class AnthropicLM(LM):
batch_size=None, batch_size=None,
model: str = "claude-2.0", model: str = "claude-2.0",
max_tokens_to_sample: int = 256, max_tokens_to_sample: int = 256,
temperature: float = 0.0, temperature: float = 1.0, # defaults to 1
**kwargs: Any, # api_key, auth_token, etc. **kwargs: Any, # top_p, top_k, etc.
): # TODO: remove batch_size ): # TODO: remove batch_size
"""Anthropic API wrapper. """Anthropic API wrapper.
:param model: str :param model: str
Anthropic model e.g. 'claude-instant-v1', 'claude-2' Anthropic model e.g. 'claude-instant-v1', 'claude-2'
:param max_tokens_to_sample: int
Maximum number of tokens to sample from the model
:param temperature: float
Sampling temperature
:param kwargs: Any
Additional model_args to pass to the API client
""" """
super().__init__() super().__init__()
self.model = model self.model = model
# defaults to os.environ.get("ANTHROPIC_API_KEY") # defaults to os.environ.get("ANTHROPIC_API_KEY")
self.client = anthropic.Anthropic(**kwargs) self.client = anthropic.Anthropic()
self.temperature = temperature self.temperature = temperature
self.max_tokens_to_sample = max_tokens_to_sample self.max_tokens_to_sample = max_tokens_to_sample
self.tokenizer = self.client.get_tokenizer() self.tokenizer = self.client.get_tokenizer()
self.kwargs = kwargs
@property @property
def eot_token_id(self): def eot_token_id(self):
...@@ -123,6 +132,7 @@ class AnthropicLM(LM): ...@@ -123,6 +132,7 @@ class AnthropicLM(LM):
max_tokens_to_sample=self.max_tokens_to_sample, max_tokens_to_sample=self.max_tokens_to_sample,
temperature=self.temperature, # TODO: implement non-greedy sampling for Anthropic temperature=self.temperature, # TODO: implement non-greedy sampling for Anthropic
stop=until, stop=until,
**self.kwargs,
) )
res.append(response) res.append(response)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment