Commit 3a3655d6 authored by baberabb's avatar baberabb
Browse files

passed kwargs to client

parent fe358061
......@@ -15,6 +15,7 @@ def anthropic_completion(
max_tokens_to_sample: int,
temperature: float,
stop: List[str],
**kwargs: Any,
):
"""Query Anthropic API for completion.
......@@ -31,6 +32,7 @@ def anthropic_completion(
stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
max_tokens_to_sample=max_tokens_to_sample,
temperature=temperature,
**kwargs,
)
return response.completion
except anthropic.RateLimitError as e:
......@@ -56,22 +58,29 @@ class AnthropicLM(LM):
batch_size=None,
model: str = "claude-2.0",
max_tokens_to_sample: int = 256,
temperature: float = 0.0,
**kwargs: Any, # api_key, auth_token, etc.
temperature: float = 1.0, # defaults to 1
**kwargs: Any, # top_p, top_k, etc.
): # TODO: remove batch_size
"""Anthropic API wrapper.
:param model: str
Anthropic model e.g. 'claude-instant-v1', 'claude-2'
:param max_tokens_to_sample: int
Maximum number of tokens to sample from the model
:param temperature: float
Sampling temperature
:param kwargs: Any
Additional model_args to pass to the API client
"""
super().__init__()
self.model = model
# defaults to os.environ.get("ANTHROPIC_API_KEY")
self.client = anthropic.Anthropic(**kwargs)
self.client = anthropic.Anthropic()
self.temperature = temperature
self.max_tokens_to_sample = max_tokens_to_sample
self.tokenizer = self.client.get_tokenizer()
self.kwargs = kwargs
@property
def eot_token_id(self):
......@@ -123,6 +132,7 @@ class AnthropicLM(LM):
max_tokens_to_sample=self.max_tokens_to_sample,
temperature=self.temperature, # TODO: implement non-greedy sampling for Anthropic
stop=until,
**self.kwargs,
)
res.append(response)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment