Unverified Commit 1da9e4e8 authored by Nikodem Szwast's avatar Nikodem Szwast Committed by GitHub
Browse files

Update authentications methods, add support for deployment_id for IBM watsonx_ai (#2877)

* update authnentications methods, add support for deployment_id

* run pre-commit on changed file
parent 22bd2bcb
......@@ -2,6 +2,7 @@ import copy
import json
import logging
import os
import warnings
from functools import lru_cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
......@@ -22,37 +23,49 @@ class LogLikelihoodResult(NamedTuple):
is_greedy: bool
def _verify_credentials(creds: Any) -> None:
def _verify_credentials(creds: dict) -> None:
"""
Verifies that all required keys are present in the credentials dictionary.
Args:
creds (Any): A dictionary containing the credentials.
Raises:
ValueError: If any of the necessary credentials are missing, with guidance
on which environment variables need to be set.
Validate credentials for APIClient authentication.
Required conditions:
- Either ("username" and "password") or "apikey" must be present.
- "url" is mandatory.
- Either "project_id" or "space_id" must be present.
"""
env_var_map = {
"apikey": "WATSONX_API_KEY",
"token": "WATSONX_TOKEN",
"url": "WATSONX_URL",
"project_id": "WATSONX_PROJECT_ID",
"space_id": "WATSONX_SPACE_ID",
"username": "WATSONX_USERNAME",
"password": "WATSONX_PASSWORD",
}
auth_req_keys = ["apikey", "token"] # one of these is required
other_req_keys = ["url", "project_id"] # both of these are required
missing_auth = all(not creds.get(key) for key in auth_req_keys) # flag any missing
missing_keys = [
f"{key} ({env_var_map[key]})" for key in other_req_keys if not creds.get(key)
] # list all missing
if missing_keys or missing_auth:
error_msg = f"Missing required credentials: {', '.join(missing_keys)}"
if missing_auth:
error_msg += ", and " if missing_keys else ""
error_msg += f"either {' or '.join([f'{key} ({env_var_map[key]})' for key in auth_req_keys])}"
# Check authentication: Either ("username" and "password") or "apikey" must be provided
has_auth = all(creds.get(key) for key in ["username", "password"]) or creds.get(
"apikey"
)
# Check required fields: "url" must be present
has_url = "url" in creds and creds["url"]
# Check project/space ID requirement: Either "project_id" or "space_id" must be present
has_project_or_space_id = any(creds.get(key) for key in ["project_id", "space_id"])
if not (has_auth and has_url and has_project_or_space_id):
missing_keys = []
if not has_auth:
missing_keys.append(
f"either ('username' and 'password') or 'apikey' ({env_var_map['apikey']})"
)
if not has_url:
missing_keys.append(f"url ({env_var_map['url']})")
if not has_project_or_space_id:
missing_keys.append(
f"either 'project_id' ({env_var_map['project_id']}) or 'space_id' ({env_var_map['space_id']})"
)
error_msg += ". Please set the environment variables indicated in parenthesis."
error_msg = f"Missing required credentials: {', '.join(missing_keys)}. "
error_msg += "Please set the environment variables indicated in parentheses."
raise ValueError(error_msg)
......@@ -78,12 +91,23 @@ def get_watsonx_credentials() -> Dict[str, str]:
load_dotenv()
credentials = {
"username": os.getenv("WATSONX_USERNAME", None),
"password": os.getenv("WATSONX_PASSWORD", None),
"apikey": os.getenv("WATSONX_API_KEY", None),
"token": os.getenv("WATSONX_TOKEN", None),
"url": os.getenv("WATSONX_URL", None),
"project_id": os.getenv("WATSONX_PROJECT_ID", None),
"space_id": os.getenv("WATSONX_SPACE_ID", None),
}
if "cloud.ibm.com" not in credentials["url"]:
credentials["instance_id"] = "openshift"
if all(credentials.get(key) for key in ["username", "password", "apikey"]):
warnings.warn(
"You're passing `username`, `password`, and `apikey` at the same time, "
"which might cause issues. More info on authentication in different scenarios "
"can be found in the docs: https://ibm.github.io/watsonx-ai-python-sdk/setup_cpd.html"
)
_verify_credentials(credentials)
return credentials
......@@ -115,8 +139,11 @@ class WatsonxLLM(LM):
args.update(additional_config)
model_id = args.pop("model_id", None)
if model_id is None:
raise ValueError("'model_id' is required, please pass it in 'model_args'")
deployment_id = args.pop("deployment_id", None)
if model_id is None and deployment_id is None:
raise ValueError(
"'model_id' or 'deployment_id' is required, please pass it in 'model_args'"
)
if not args.get("do_sample", None):
args["temperature"] = None
......@@ -152,6 +179,7 @@ class WatsonxLLM(LM):
return cls(
watsonx_credentials=get_watsonx_credentials(),
model_id=model_id,
deployment_id=deployment_id,
generate_params=generate_params,
)
......@@ -159,6 +187,7 @@ class WatsonxLLM(LM):
self,
watsonx_credentials: Dict,
model_id,
deployment_id,
generate_params: Optional[Dict[Any, Any]] = None,
) -> None:
try:
......@@ -171,7 +200,6 @@ class WatsonxLLM(LM):
super().__init__()
client = APIClient(watsonx_credentials)
project_id = watsonx_credentials.get("project_id", None)
deployment_id = watsonx_credentials.get("deployment_id", None)
client.set.default_project(project_id)
self.generate_params = generate_params
self.model = ModelInference(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment