Unverified Commit 78d57e0f authored by Kiersten Stokes's avatar Kiersten Stokes Committed by GitHub
Browse files

Add support for token-based auth for watsonx models (#2796)

* Add support for token-based auth for watsonx models

* Fix lint

* Move dotenv import to inner scope

* Improve readability of _verify_credentials
parent 6fbebb4b
......@@ -28,21 +28,32 @@ def _verify_credentials(creds: Any) -> None:
Args:
creds (Any): A dictionary containing the credentials.
Raises:
ValueError: If any of the necessary credentials are missing, with guidance on which environment variables need to be set.
ValueError: If any of the necessary credentials are missing, with guidance
on which environment variables need to be set.
"""
required_keys = ["apikey", "url", "project_id"]
env_var_mapping = {
env_var_map = {
"apikey": "WATSONX_API_KEY",
"token": "WATSONX_TOKEN",
"url": "WATSONX_URL",
"project_id": "WATSONX_PROJECT_ID",
}
missing_keys = [key for key in required_keys if key not in creds or not creds[key]]
if missing_keys:
missing_env_vars = [env_var_mapping[key] for key in missing_keys]
raise ValueError(
f"Missing required credentials: {', '.join(missing_keys)}. Please set the following environment variables: {', '.join(missing_env_vars)}"
)
auth_req_keys = ["apikey", "token"] # one of these is required
other_req_keys = ["url", "project_id"] # both of these are required
missing_auth = all(not creds.get(key) for key in auth_req_keys) # flag any missing
missing_keys = [
f"{key} ({env_var_map[key]})" for key in other_req_keys if not creds.get(key)
] # list all missing
if missing_keys or missing_auth:
error_msg = f"Missing required credentials: {', '.join(missing_keys)}"
if missing_auth:
error_msg += ", and " if missing_keys else ""
error_msg += f"either {' or '.join([f'{key} ({env_var_map[key]})' for key in auth_req_keys])}"
error_msg += ". Please set the environment variables indicated in parenthesis."
raise ValueError(error_msg)
@lru_cache(maxsize=None)
......@@ -51,13 +62,24 @@ def get_watsonx_credentials() -> Dict[str, str]:
Retrieves Watsonx API credentials from environmental variables.
Returns:
Dict[str, str]: A dictionary containing the credentials necessary for authentication, including
keys such as `apikey`, `url`, and `project_id`.
keys such as `apikey` or `token`, `url`, and `project_id`.
Raises:
AssertionError: If the credentials format is invalid or any of the necessary credentials are missing.
"""
try:
from dotenv import load_dotenv
except ImportError:
raise ImportError(
"Could not import dotenv: Please install lm_eval[ibm_watsonx_ai] package."
)
# This function attempts to load a file named .env starting from the CWD and working backwards
# towards root. KV pairs are parsed and stored as env vars iff not already set
load_dotenv()
credentials = {
"apikey": os.getenv("WATSONX_API_KEY", None),
"token": os.getenv("WATSONX_TOKEN", None),
"url": os.getenv("WATSONX_URL", None),
"project_id": os.getenv("WATSONX_PROJECT_ID", None),
}
......
......@@ -63,7 +63,7 @@ dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy", "unitxt"]
deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
gptq = ["auto-gptq[triton]>=0.6.0"]
hf_transfer = ["hf_transfer"]
ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22"]
ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
neuronx = ["optimum[neuronx]"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment