"src/include/Sequence.hpp" did not exist on "2a48812edb1a7c3e280159637fa89b7a0bbfb86b"
Commit bd81b8c0 authored by baberabb

Update device list and dtype detection for MPS

parent b8d1cef9
@@ -1,6 +1,7 @@
 import abc
 import os
-from typing import Union, List, Tuple
+import torch
+from typing import Union, List, Tuple, Optional, Type, TypeVar
 from sqlitedict import SqliteDict
 import json
 import hashlib
@@ -11,6 +12,8 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.logger import eval_logger
 
+T = TypeVar("T", bound="LM")
+
 
 class LM(abc.ABC):
     def __init__(self) -> None:
@@ -111,11 +114,28 @@ class LM(abc.ABC):
         pass
 
     @classmethod
-    def create_from_arg_string(cls, arg_string, additional_config=None):
+    def create_from_arg_string(
+        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
+    ) -> T:
+        """
+        Creates an instance of the LM class using the given argument string and additional config.
+
+        Parameters:
+        - arg_string: A string containing arguments in the format key1=value1,key2=value2.
+        - additional_config: Optional dictionary containing additional configuration parameters.
+
+        Returns:
+        - Instance of the LM class.
+        """
         additional_config = {} if additional_config is None else additional_config
         args = utils.simple_parse_args_string(arg_string)
         args2 = {k: v for k, v in additional_config.items() if v is not None}
-        if args2.get("device") == "mps" or args.get("device") == "mps":
+        # TODO: delete once float16 MPS is fixed in torch stable
+        if (
+            args2.get("device") in ("mps", "mps:0")
+            or args.get("device") in ("mps", "mps:0")
+        ) and "dev" not in torch.__version__:
            args["dtype"] = "float32"
         return cls(**args, **args2)
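To make the new guard concrete, here is a minimal, self-contained sketch of the same logic. It is not code from this repository: parse_args_string stands in for lm_eval.utils.simple_parse_args_string, and resolve_model_args is a hypothetical helper that isolates the MPS fallback. It assumes PyTorch nightlies carry "dev" in torch.__version__ (e.g. "2.2.0.dev20230907"), which stable builds do not.

from typing import Optional

import torch


def parse_args_string(arg_string: str) -> dict:
    # Stand-in for lm_eval.utils.simple_parse_args_string:
    # "key1=value1,key2=value2" -> {"key1": "value1", "key2": "value2"}
    if not arg_string:
        return {}
    return dict(pair.split("=", 1) for pair in arg_string.split(","))


def resolve_model_args(arg_string: str, additional_config: Optional[dict] = None) -> dict:
    # Hypothetical helper mirroring the guard in create_from_arg_string.
    additional_config = additional_config or {}
    args = parse_args_string(arg_string)
    args2 = {k: v for k, v in additional_config.items() if v is not None}
    on_mps = args2.get("device") in ("mps", "mps:0") or args.get("device") in ("mps", "mps:0")
    # Stable PyTorch builds lack float16 support on MPS, so the dtype is
    # forced to float32 unless a nightly ("dev") build is installed.
    if on_mps and "dev" not in torch.__version__:
        args["dtype"] = "float32"
    return {**args, **args2}


# On stable torch: {'pretrained': 'gpt2', 'device': 'mps', 'dtype': 'float32'}
print(resolve_model_args("pretrained=gpt2,device=mps"))

The second diff below applies the same device aliases inside HFLM.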
@@ -107,17 +107,20 @@ class HFLM(LM):
             if not (parallelize or accelerator.num_processes > 1):
                 # use user-passed device
                 device_list = set(
-                    ["cuda", "cpu", "mps"]
+                    ["cuda", "cpu"]
                     + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+                    + ["mps", "mps:0"]
                 )
                 if device:
                     if device not in device_list:
                         device = int(device)
                     self._device = torch.device(device)
                     eval_logger.info(f"Using device '{device}'")
-                    if device == "mps":
+                    if device in ("mps", "mps:0") and "dev" not in torch.__version__:
                         eval_logger.info(
-                            "MPS is still in beta and only supports float32; setting dtype to float32."
+                            "MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
+                            "PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
+                            "https://download.pytorch.org/whl/nightly/cpu"
                         )
                 else:
                     eval_logger.info("Device not specified")
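A similar sketch of the device resolution follows, again with hypothetical helper names (build_device_list, select_device), since the original logic is inline in HFLM's constructor.

import torch


def build_device_list() -> set:
    # Mirrors the accepted device values after this commit: CPU,
    # generic and indexed CUDA devices, and both MPS spellings.
    return set(
        ["cuda", "cpu"]
        + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        + ["mps", "mps:0"]
    )


def select_device(device: str) -> torch.device:
    # Hypothetical helper; the real check lives inline in HFLM.
    if device not in build_device_list():
        # Unknown strings fall back to an integer CUDA ordinal,
        # matching the `device = int(device)` branch in the diff.
        device = int(device)
    if device in ("mps", "mps:0") and "dev" not in torch.__version__:
        # Stable PyTorch cannot run float16 on MPS, hence the
        # float32 advice logged by HFLM.
        print("MPS on stable PyTorch: set dtype to float32")
    return torch.device(device)

For example, select_device("mps") on a stable PyTorch install logs the float32 note and returns torch.device("mps"), while on a nightly build it returns the device silently.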