Unverified Commit 56def33d authored by Ramiro R. C., committed by GitHub
Browse files

Custom request headers | trust_remote_code param fix (#3069)



* added headers and custom model name | fixed bug with trust_remote_code param

* linting

* removed custom model name | changed headers override

* add `header` to base TemplateAPI

* nit

---------
Co-authored-by: Baber <baber@hey.com>
parent e6ea0315
......@@ -21,7 +21,11 @@ When subclassing `TemplateAPI`, you need to implement the following methods:
1. `_create_payload`: Creates the JSON payload for API requests.
2. `parse_logprobs`: Parses log probabilities from API responses.
3. `parse_generations`: Parses generated text from API responses.
Optional Properties:
4. `header`: Returns the headers for the API request.
5. `api_key`: Returns the API key for authentication (if required).
You may also need to override other methods or properties depending on your API's specific requirements.
......@@ -97,6 +101,10 @@ When initializing a `TemplateAPI` instance or a subclass, you can provide severa
- Whether to validate the certificate of the API endpoint (if HTTPS).
- Default is True.
- `header` (dict, optional):
- Custom headers for API requests.
- If not provided, uses `{"Authorization": f"Bearer {self.api_key}"}` by default.
Example usage:
```python
......
......@@ -436,7 +436,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
args.model_args = args.model_args + ",trust_remote_code=True"
if isinstance(args.model_args, dict):
args.model_args["trust_remote_code"] = True
else:
args.model_args = args.model_args + ",trust_remote_code=True"
(
eval_logger.info(f"Selected Tasks: {task_names}")
if eval_logger.getEffectiveLevel() >= logging.INFO
......
......@@ -135,6 +135,7 @@ class TemplateAPI(TemplateLM):
eos_string: str = None,
# timeout in seconds
timeout: int = 300,
header: Optional[Dict[str, str]] = None,
max_images: int = 1,
**kwargs,
) -> None:
......@@ -152,6 +153,7 @@ class TemplateAPI(TemplateLM):
self.model = model or pretrained
self.base_url = base_url
self.tokenizer = tokenizer
self._header = header
if not isinstance(batch_size, int) and "auto" in batch_size:
eval_logger.warning(
"Automatic batch size is not supported for API models. Defaulting to batch size 1."
......@@ -296,7 +298,7 @@ class TemplateAPI(TemplateLM):
@cached_property
def header(self) -> dict:
    """Return the HTTP headers to send with each API request.

    If a custom ``header`` dict was supplied at construction time
    (stored as ``self._header``), it is used verbatim; otherwise a
    default Bearer-token Authorization header is built from
    ``self.api_key``. Subclasses may override this property for
    APIs with different authentication schemes.

    Note: cached after the first access (``functools.cached_property``),
    so later mutations of ``self._header`` are not picked up.
    """
    # The diff's stale (pre-commit) return line is dropped here: keeping
    # it would make the ``self._header`` override unreachable.
    return self._header or {"Authorization": f"Bearer {self.api_key}"}
@property
def tokenizer_name(self) -> str:
......
......@@ -16,8 +16,8 @@ eval_logger = logging.getLogger(__name__)
class LocalCompletionsAPI(TemplateAPI):
def __init__(
self,
base_url=None,
tokenizer_backend="huggingface",
base_url: str = None,
tokenizer_backend: str = "huggingface",
**kwargs,
):
super().__init__(
......@@ -108,9 +108,9 @@ class LocalCompletionsAPI(TemplateAPI):
class LocalChatCompletion(LocalCompletionsAPI):
def __init__(
self,
base_url=None,
tokenizer_backend=None,
tokenized_requests=False,
base_url: str = None,
tokenizer_backend: str = None,
tokenized_requests: bool = False,
**kwargs,
):
eval_logger.warning(
......@@ -236,6 +236,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
eval_logger.warning(
"o1 models do not support `stop` and only support temperature=1"
)
super().__init__(
base_url=base_url,
tokenizer_backend=tokenizer_backend,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment