Commit 471297ba authored by baberabb's avatar baberabb
Browse files

fixed generation_kwargs; added dependency groups to testing on CI

parent b8510001
...@@ -55,7 +55,7 @@ jobs: ...@@ -55,7 +55,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies # Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
......
...@@ -49,12 +49,12 @@ class AnthropicLM(LM): ...@@ -49,12 +49,12 @@ class AnthropicLM(LM):
def __init__( def __init__(
self, self,
batch_size=None, batch_size: int = 1,
model: str = "claude-2.0", model: str = "claude-2.0",
max_tokens_to_sample: int = 256, max_tokens_to_sample: int = 256,
temperature: float = 1.0, # defaults to 1 temperature: float = 0, # defaults to 1
**kwargs: Any, # top_p, top_k, etc. **kwargs, # top_p, top_k, etc.
): # TODO: remove batch_size ):
"""Anthropic API wrapper. """Anthropic API wrapper.
:param model: str :param model: str
...@@ -119,13 +119,16 @@ class AnthropicLM(LM): ...@@ -119,13 +119,16 @@ class AnthropicLM(LM):
try: try:
inp = request[0] inp = request[0]
request_args = request[1] request_args = request[1]
until = request_args["until"] # generation_kwargs
until = request_args.get("until")
max_gen_toks = request_args.get("max_gen_toks", self.max_length)
temperature = request_args.get("temperature", self.temperature)
response = anthropic_completion( response = anthropic_completion(
client=self.client, client=self.client,
model=self.model, model=self.model,
prompt=inp, prompt=inp,
max_tokens_to_sample=self.max_tokens_to_sample, max_tokens_to_sample=max_gen_toks,
temperature=self.temperature, # TODO: implement non-greedy sampling for Anthropic temperature=temperature, # TODO: implement non-greedy sampling for Anthropic
stop=until, stop=until,
**self.kwargs, **self.kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment