Commit bd81b8c0 authored by baberabb's avatar baberabb
Browse files

Update device list and dtype detection for MPS

parent b8d1cef9
import abc
import os
from typing import Union, List, Tuple
import torch
from typing import Union, List, Tuple, Optional, Type, TypeVar
from sqlitedict import SqliteDict
import json
import hashlib
......@@ -11,6 +12,8 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
T = TypeVar("T", bound="LM")
class LM(abc.ABC):
def __init__(self) -> None:
......@@ -111,11 +114,28 @@ class LM(abc.ABC):
pass
@classmethod
def create_from_arg_string(cls, arg_string, additional_config=None):
def create_from_arg_string(
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
) -> T:
"""
Creates an instance of the LM class using the given argument string and additional config.
Parameters:
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
- additional_config: Optional dictionary containing additional configuration parameters.
Returns:
- Instance of the LM class.
"""
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
if args2.get("device") == "mps" or args.get("device") == "mps":
# TODO: delete once float16 MPS is fixed in torch stable
if (
args2.get("device") in ("mps", "mps:0")
or args.get("device") in ("mps", "mps:0")
and "dev" not in torch.__version__
):
args["dtype"] = "float32"
return cls(**args, **args2)
......
......@@ -107,17 +107,20 @@ class HFLM(LM):
if not (parallelize or accelerator.num_processes > 1):
# use user-passed device
device_list = set(
["cuda", "cpu", "mps"]
["cuda", "cpu"]
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+ ["mps", "mps:0"]
)
if device:
if device not in device_list:
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
if device == "mps":
if device in ("mps", "mps:0") and "dev" not in torch.__version__:
eval_logger.info(
"MPS is still in beta and only supports float32; setting dtype to float32."
"MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
"PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cpu"
)
else:
eval_logger.info("Device not specified")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment