Unverified Commit 4259a6d4 authored by Nikodem Szwast, committed by GitHub

IBM watsonx_llm fixes & refactor (#2464)

* refactor code, fix config path bug

* update types to be from typing lib

* add pre-commit formatting

* specify version of ibm_watsonx_ai package

* adjust get_watsonx_credentials() function, add minor refactor to address PR review comments

* change missing-installation hint from ibm_watsonx_ai to lm_eval[ibm_watsonx_ai]
parent 67db63a5
import json
import copy
import os
from configparser import ConfigParser
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
@@ -18,57 +16,46 @@ class LogLikelihoodResult(NamedTuple):
is_greedy: bool
@lru_cache(maxsize=None)
def get_watsonx_credentials(
env_name: str = "YP_QA",
config_path: str = "config.ini",
) -> Dict[str, str]:
def _verify_credentials(creds: Any) -> None:
"""
Retrieves Watsonx API credentials from environmental variables or from a configuration file.
Verifies that all required keys are present in the credentials dictionary.
Args:
env_name (str, optional): The name of the environment from which to retrieve credentials. Defaults to "YP_QA".
config_path (str, optional): The file path to the `config.ini` configuration file. Defaults to "config.ini".
creds (Any): A dictionary containing the credentials.
Raises:
ValueError: If any of the necessary credentials are missing, with guidance on which environment variables need to be set.
"""
required_keys = ["apikey", "url", "project_id"]
env_var_mapping = {
"apikey": "WATSONX_API_KEY",
"url": "WATSONX_URL",
"project_id": "WATSONX_PROJECT_ID",
}
missing_keys = [key for key in required_keys if key not in creds or not creds[key]]
if missing_keys:
missing_env_vars = [env_var_mapping[key] for key in missing_keys]
raise ValueError(
f"Missing required credentials: {', '.join(missing_keys)}. Please set the following environment variables: {', '.join(missing_env_vars)}"
)
@lru_cache(maxsize=None)
def get_watsonx_credentials() -> Dict[str, str]:
"""
Retrieves Watsonx API credentials from environment variables.
Returns:
dict[str, str]: A dictionary containing the credentials necessary for authentication, including
Dict[str, str]: A dictionary containing the credentials necessary for authentication, including
keys such as `apikey`, `url`, and `project_id`.
Raises:
FileNotFoundError: If the specified configuration file does not exist.
AssertionError: If the credentials format is invalid.
ValueError: If the credentials format is invalid or any of the necessary credentials are missing.
"""
def _verify_credentials(creds: Any) -> None:
assert isinstance(creds, dict) and all(
key in creds.keys() for key in ["apikey", "url", "project_id"]
), "Wrong configuration for credentials."
credentials = {
"apikey": os.getenv("WATSONX_API_KEY", None),
"url": os.getenv("WATSONX_URL", None),
"project_id": os.getenv("WATSONX_PROJECT_ID", None),
}
if any(credentials.get(key) is None for key in ["apikey", "url", "project_id"]):
eval_logger.warning(
"One or more required environment variables are missing, trying to load config.ini file."
)
config_path = "config.ini" if not config_path else config_path
if not Path(config_path).is_absolute():
config_path = os.path.join(
Path(__file__).parent.parent.absolute(), config_path
)
if not os.path.exists(config_path):
raise FileNotFoundError(
f"Provided config file path {config_path} does not exist. "
"You need to specify credentials in config.ini file under specified location."
)
config = ConfigParser()
config.read(config_path)
credentials = json.loads(config.get(env_name))
_verify_credentials(credentials)
return credentials
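
The refactored flow above reads everything from the environment and fails fast with a ValueError naming any unset variable. A minimal usage sketch (the placeholder values, including the us-south endpoint URL, are assumptions, not taken from this diff):

import os

os.environ["WATSONX_API_KEY"] = "<your-api-key>"                  # placeholder
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"   # illustrative endpoint
os.environ["WATSONX_PROJECT_ID"] = "<your-project-id>"            # placeholder

credentials = get_watsonx_credentials()
# -> {"apikey": "...", "url": "...", "project_id": "..."}
# Note: the result is cached by @lru_cache, so later changes to the
# environment variables are not picked up within the same process.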
@@ -84,7 +71,7 @@ class WatsonxLLM(LM):
def create_from_arg_string(
cls: Type["WatsonxLLM"],
arg_string: str,
config_path: Optional[str] = None,
additional_config: Optional[Dict] = None,
) -> "WatsonxLLM":
"""
Allow the user to specify model parameters (TextGenerationParameters) in CLI arguments.
@@ -97,6 +84,8 @@ class WatsonxLLM(LM):
)
args = simple_parse_args_string(arg_string)
args.update(additional_config)
model_id = args.pop("model_id", None)
if model_id is None:
raise ValueError("'model_id' is required, please pass it in 'model_args'")
@@ -107,7 +96,7 @@ class WatsonxLLM(LM):
args["top_k"] = None
args["seed"] = None
cls.generate_params = {
generate_params = {
GenParams.DECODING_METHOD: (
"greedy" if not args.get("do_sample", None) else "sample"
),
@@ -130,12 +119,10 @@ class WatsonxLLM(LM):
},
}
generate_params = {
k: v for k, v in cls.generate_params.items() if v is not None
}
generate_params = {k: v for k, v in generate_params.items() if v is not None}
return cls(
watsonx_credentials=get_watsonx_credentials(config_path),
watsonx_credentials=get_watsonx_credentials(),
model_id=model_id,
generate_params=generate_params,
)
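
For context, a hedged sketch of how these arguments arrive from the CLI; the model name watsonx_llm and the model_id value are illustrative assumptions, not taken from this diff:

# Hypothetical invocation (names are assumptions):
#   lm_eval --model watsonx_llm \
#     --model_args model_id=ibm/granite-13b-instruct-v2,do_sample=False,temperature=0

args = simple_parse_args_string(
    "model_id=ibm/granite-13b-instruct-v2,do_sample=False,temperature=0"
)
model_id = args.pop("model_id")  # required, as enforced above
# The remaining keys feed generate_params; entries whose value ends up None
# (the lines above set top_k and seed to None in some configurations) are
# dropped by the dict comprehension before the class is instantiated.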
@@ -158,7 +145,7 @@ class WatsonxLLM(LM):
project_id = watsonx_credentials.get("project_id", None)
deployment_id = watsonx_credentials.get("deployment_id", None)
client.set.default_project(project_id)
self.generate_params = generate_params or {}
self.generate_params = generate_params
self.model = ModelInference(
model_id=model_id,
deployment_id=deployment_id,
@@ -220,9 +207,9 @@ class WatsonxLLM(LM):
"""
Calculates the log likelihood of the generated tokens compared to the context tokens.
Args:
input_tokens (List[dict[str, float]]): A List of token dictionaries, each containing
input_tokens (List[Dict[str, float]]): A List of token dictionaries, each containing
token information like `text` and `logprob`.
context_tokens (List[dict[str, float]]): A List of token dictionaries representing
context_tokens (List[Dict[str, float]]): A List of token dictionaries representing
the input context.
Returns:
LogLikelihoodResult: An object containing the calculated log likelihood and a boolean
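
The body of _get_log_likelihood sits outside this hunk; the following is only a reconstruction of the documented contract (sum the logprobs of the tokens after the context and check whether each was the greedy choice), reusing the typing imports at the top of the file. The "rank" field name is an assumption, not confirmed by this diff:

def _get_log_likelihood_sketch(
    input_tokens: List[Dict[str, Any]],
    context_tokens: List[Dict[str, Any]],
) -> LogLikelihoodResult:
    # Score only the continuation, i.e. the tokens past the context prefix.
    continuation = input_tokens[len(context_tokens):]
    log_likelihood = sum(token["logprob"] for token in continuation)
    # Greedy iff every continuation token was the model's top-1 candidate.
    is_greedy = all(token.get("rank", 1) == 1 for token in continuation)
    return LogLikelihoodResult(log_likelihood, is_greedy)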
@@ -252,27 +239,24 @@ class WatsonxLLM(LM):
Returns:
List[str]: A List of generated responses.
"""
requests = [request.args[0] for request in requests]
requests = [request.args for request in requests]
results = []
batch_size = 5
for i in tqdm(
range(0, len(requests), batch_size),
desc=f"Running generate_until function with batch size {batch_size}",
for request in tqdm(
requests,
desc="Running generate_until function ...",
):
batch = requests[i : i + batch_size]
context, continuation = request
try:
responses = self.model.generate_text(batch, self.generate_params)
response = self.model.generate_text(context, self.generate_params)
except Exception as exp:
eval_logger.error(f"Error while generating text {exp}")
continue
for response, context in zip(responses, batch):
results.append(response)
self.cache_hook.add_partial("generated_text", context, response)
eval_logger.error("Error while generating text.")
raise exp
eval_logger.info("Cached responses")
results.append(response)
self.cache_hook.add_partial(
"generate_until", (context, continuation), response
)
return results
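
For orientation, the per-request call above maps onto the ibm_watsonx_ai client roughly as follows; a self-contained sketch with placeholder credentials and an illustrative model_id (constructor keywords should be verified against the installed package version):

from ibm_watsonx_ai.foundation_models import ModelInference

model = ModelInference(
    model_id="ibm/granite-13b-instruct-v2",  # illustrative model_id
    credentials={"apikey": "<key>", "url": "https://us-south.ml.cloud.ibm.com"},
    project_id="<project-id>",
)
# Mirrors self.model.generate_text(context, self.generate_params) above.
text = model.generate_text("Answer briefly: what is 2+2?", {"max_new_tokens": 8})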
@@ -284,7 +268,7 @@ class WatsonxLLM(LM):
2. a target string on which the loglikelihood of the LM producing this target,
conditioned on the input, will be returned.
Returns:
tuple (loglikelihood, is_greedy) for each request according to the input order:
Tuple (loglikelihood, is_greedy) for each request according to the input order:
loglikelihood: probability of generating the target string conditioned on the input
is_greedy: True if and only if the target string would be generated by greedy sampling from the LM
"""
@@ -295,54 +279,48 @@ class WatsonxLLM(LM):
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
)
self._check_model_logprobs_support()
self.generate_params[GenParams.MAX_NEW_TOKENS] = 1
generate_params = copy.copy(self.generate_params)
generate_params[GenParams.MAX_NEW_TOKENS] = 1
requests = [request.args for request in requests]
results: List[LogLikelihoodResult] = []
batch_size = 5
for i in tqdm(
range(0, len(requests), batch_size),
desc=f"Running loglikelihood function with batch size {batch_size}",
# Note: We're not using batching due to the (current) indeterminism of loglikelihood values when sending a batch of requests
for request in tqdm(
requests,
desc="Running loglikelihood function ...",
):
batch = requests[i : i + batch_size]
context, continuation = request
try:
tokenized_contexts = [
self.model.tokenize(prompt=context, return_tokens=True)["result"][
"tokens"
]
for context, _ in batch
]
tokenized_context = self.model.tokenize(
prompt=context, return_tokens=True
)["result"]["tokens"]
except Exception as exp:
eval_logger.error(f"Error while model tokenize:\n {exp}")
continue
eval_logger.error("Error while model tokenize.")
raise exp
input_prompts = [context + continuation for context, continuation in batch]
input_prompt = context + continuation
try:
responses = self.model.generate_text(
prompt=input_prompts, params=self.generate_params, raw_response=True
response = self.model.generate_text(
prompt=input_prompt, params=generate_params, raw_response=True
)
except Exception as exp:
eval_logger.error(f"Error while model generate text:\n {exp}")
continue
for response, tokenized_context, (context, continuation) in zip(
responses, tokenized_contexts, batch
):
log_likelihood_response = self._get_log_likelihood(
response["results"][0]["input_tokens"], tokenized_context
)
results.append(log_likelihood_response)
self.cache_hook.add_partial(
"loglikelihood",
(context, continuation),
(
log_likelihood_response.log_likelihood,
log_likelihood_response.is_greedy,
),
)
eval_logger.info("Cached batch")
eval_logger.error("Error while model generate text.")
raise exp
log_likelihood_response = self._get_log_likelihood(
response["results"][0]["input_tokens"], tokenized_context
)
results.append(log_likelihood_response)
self.cache_hook.add_partial(
"loglikelihood",
(context, continuation),
(
log_likelihood_response.log_likelihood,
log_likelihood_response.is_greedy,
),
)
return cast(List[Tuple[float, bool]], results)
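
The switch from mutating self.generate_params in place (as the removed lines did) to copy.copy is what keeps the MAX_NEW_TOKENS=1 override local to loglikelihood scoring instead of leaking into later generate_until calls. A small sketch of the pattern (key names are illustrative):

import copy

shared_params = {"max_new_tokens": 256, "temperature": 0}

local_params = copy.copy(shared_params)
local_params["max_new_tokens"] = 1  # score the prompt, generate one step only

assert shared_params["max_new_tokens"] == 256  # shared config unchanged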
@@ -350,10 +328,10 @@ class WatsonxLLM(LM):
"""
Used to evaluate perplexity on a data distribution.
Args:
requests: Each request contains Instance.args : tuple[str] containing an input string to the model whose
requests: Each request contains Instance.args : Tuple[str] containing an input string to the model whose
entire loglikelihood, conditioned on purely the EOT token, will be calculated.
Returns:
tuple (loglikelihood,) for each request according to the input order:
Tuple (loglikelihood,) for each request according to the input order:
loglikelihood: solely the probability of producing each piece of text given no starting input.
"""
try:
@@ -363,47 +341,34 @@ class WatsonxLLM(LM):
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
)
self._check_model_logprobs_support()
self.generate_params[GenParams.MAX_NEW_TOKENS] = 1
generate_params = copy.deepcopy(self.generate_params)
generate_params[GenParams.MAX_NEW_TOKENS] = 1
requests = [request.args[0] for request in requests]
requests = [request.args for request in requests]
results: List[LogLikelihoodResult] = []
batch_size = 5
for i in tqdm(
range(0, len(requests), batch_size),
desc=f"Running loglikelihood_rolling function with batch size {batch_size}",
# Note: We're not using batching due to the (current) indeterminism of loglikelihood values when sending a batch of requests
for request in tqdm(
requests,
desc="Running loglikelihood_rolling function ...",
):
batch = requests[i : i + batch_size]
context, continuation = request
try:
responses = self.model.generate_text(
prompt=batch, params=self.generate_params, raw_response=True
response = self.model.generate_text(
prompt=context, params=generate_params, raw_response=True
)
except Exception as exp:
eval_logger.error(f"Error while model generate text:\n {exp}")
continue
for response, context in zip(responses, batch):
try:
log_likelihood_response = self._get_log_likelihood(
response["results"][0]["input_tokens"], []
)
results.append(log_likelihood_response)
self.cache_hook.add_partial(
"loglikelihood_rolling",
context,
(
log_likelihood_response.log_likelihood,
log_likelihood_response.is_greedy,
),
)
except Exception as exp:
eval_logger.error(
f"Error during log likelihood calculation:\n {exp}"
)
continue
eval_logger.info("Cached batch")
eval_logger.error("Error while model generate text.")
raise exp
log_likelihood_response = self._get_log_likelihood(
response["results"][0]["input_tokens"], []
)
results.append(log_likelihood_response)
self.cache_hook.add_partial(
"loglikelihood_rolling",
(context, continuation),
log_likelihood_response.log_likelihood,
)
return cast(List[Tuple[float, bool]], results)
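
Note the small asymmetry: loglikelihood copies the params with copy.copy while loglikelihood_rolling uses copy.deepcopy. For flat scalar values the two behave identically, but only deepcopy isolates nested dict values such as the nested parameter block built in create_from_arg_string above. A sketch of the difference (the length_penalty key is an illustrative assumption):

import copy

params = {"max_new_tokens": 256, "length_penalty": {"decay_factor": 2.5}}

shallow = copy.copy(params)
shallow["length_penalty"]["decay_factor"] = 1.0  # mutates params too!
assert params["length_penalty"]["decay_factor"] == 1.0

params["length_penalty"]["decay_factor"] = 2.5
deep = copy.deepcopy(params)
deep["length_penalty"]["decay_factor"] = 1.0     # independent nested dict
assert params["length_penalty"]["decay_factor"] == 2.5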
@@ -62,7 +62,7 @@ dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
gptq = ["auto-gptq[triton]>=0.6.0"]
hf_transfer = ["hf_transfer"]
ibm_watsonx_ai = ["ibm_watsonx_ai"]
ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22"]
ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
neuronx = ["optimum[neuronx]"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
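
With the version now pinned, the hint in the error messages above ("Please install lm_eval[ibm_watsonx_ai] package") resolves to:

pip install "lm_eval[ibm_watsonx_ai]"  # pulls in ibm_watsonx_ai>=1.1.22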