Commit 471297ba authored by baberabb

fixed generation_kwargs handling; added dependency groups to CI testing

parent b8510001
@@ -55,7 +55,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
           # Install optional git dependencies
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
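The new extras let CI exercise the Anthropic and SentencePiece code paths during testing. Below is a minimal, illustrative sketch (not part of this commit) of how a test module could skip gracefully when those optional dependencies are absent; it assumes the anthropic and sentencepiece extras install distributions importable under the same names.

    import importlib.util

    import pytest

    # Assumption: the 'anthropic' and 'sentencepiece' extras install packages
    # importable under the same names.
    HAS_ANTHROPIC = importlib.util.find_spec("anthropic") is not None
    HAS_SENTENCEPIECE = importlib.util.find_spec("sentencepiece") is not None


    @pytest.mark.skipif(not HAS_ANTHROPIC, reason="anthropic extra not installed")
    def test_anthropic_importable():
        import anthropic  # noqa: F401


    @pytest.mark.skipif(not HAS_SENTENCEPIECE, reason="sentencepiece extra not installed")
    def test_sentencepiece_importable():
        import sentencepiece  # noqa: F401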
@@ -49,12 +49,12 @@ class AnthropicLM(LM):
     def __init__(
         self,
-        batch_size=None,
+        batch_size: int = 1,
         model: str = "claude-2.0",
         max_tokens_to_sample: int = 256,
-        temperature: float = 1.0,  # defaults to 1
-        **kwargs: Any,  # top_p, top_k, etc.
-    ):  # TODO: remove batch_size
+        temperature: float = 0,  # defaults to 1
+        **kwargs,  # top_p, top_k, etc.
+    ):
         """Anthropic API wrapper.

         :param model: str
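For context, a minimal usage sketch of the updated constructor. The import path and the ANTHROPIC_API_KEY environment variable are assumptions not shown in this diff; extra keyword arguments such as top_p are simply forwarded to the API through **kwargs.

    # Illustrative only; assumes the module path below and an ANTHROPIC_API_KEY
    # set in the environment.
    from lm_eval.models.anthropic_llms import AnthropicLM

    lm = AnthropicLM(
        model="claude-2.0",
        max_tokens_to_sample=256,
        temperature=0,  # new default favours deterministic decoding
        top_p=0.95,     # forwarded to the Anthropic API via **kwargs
    )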
@@ -119,13 +119,16 @@ class AnthropicLM(LM):
             try:
                 inp = request[0]
                 request_args = request[1]
-                until = request_args["until"]
+                # generation_kwargs
+                until = request_args.get("until")
+                max_gen_toks = request_args.get("max_gen_toks", self.max_length)
+                temperature = request_args.get("temperature", self.temperature)
                 response = anthropic_completion(
                     client=self.client,
                     model=self.model,
                     prompt=inp,
-                    max_tokens_to_sample=self.max_tokens_to_sample,
-                    temperature=self.temperature,  # TODO: implement non-greedy sampling for Anthropic
+                    max_tokens_to_sample=max_gen_toks,
+                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
                     stop=until,
                     **self.kwargs,
                 )
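The behavioural change in this hunk is that per-request generation_kwargs now override the instance defaults instead of being ignored. A standalone sketch of that fallback pattern (the literal values are illustrative only):

    # Per-request arguments win; instance-level settings are the fallback,
    # mirroring the request_args.get(...) calls added above.
    default_max_gen_toks = 256  # stands in for self.max_length
    default_temperature = 0.0   # stands in for self.temperature

    request_args = {"until": ["\n\n"], "max_gen_toks": 128}  # example request

    until = request_args.get("until")                                      # ["\n\n"]
    max_gen_toks = request_args.get("max_gen_toks", default_max_gen_toks)  # 128
    temperature = request_args.get("temperature", default_temperature)     # 0.0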