Unverified Commit b65b9ca3 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #853 from baberabb/big-refactor_mps

[Refactor] float16 MPS works in torch nightly
parents b8d1cef9 bd81b8c0
import abc
import os
from typing import Union, List, Tuple
import torch
from typing import Union, List, Tuple, Optional, Type, TypeVar
from sqlitedict import SqliteDict
import json
import hashlib
......@@ -11,6 +12,8 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
T = TypeVar("T", bound="LM")
class LM(abc.ABC):
def __init__(self) -> None:
......@@ -111,11 +114,28 @@ class LM(abc.ABC):
pass
@classmethod
def create_from_arg_string(cls, arg_string, additional_config=None):
def create_from_arg_string(
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
) -> T:
"""
Creates an instance of the LM class using the given argument string and additional config.
Parameters:
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
- additional_config: Optional dictionary containing additional configuration parameters.
Returns:
- Instance of the LM class.
"""
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
if args2.get("device") == "mps" or args.get("device") == "mps":
# TODO: delete once float16 MPS is fixed in torch stable
if (
args2.get("device") in ("mps", "mps:0")
or args.get("device") in ("mps", "mps:0")
and "dev" not in torch.__version__
):
args["dtype"] = "float32"
return cls(**args, **args2)
......
......@@ -107,17 +107,20 @@ class HFLM(LM):
if not (parallelize or accelerator.num_processes > 1):
# use user-passed device
device_list = set(
["cuda", "cpu", "mps"]
["cuda", "cpu"]
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+ ["mps", "mps:0"]
)
if device:
if device not in device_list:
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
if device == "mps":
if device in ("mps", "mps:0") and "dev" not in torch.__version__:
eval_logger.info(
"MPS is still in beta and only supports float32; setting dtype to float32."
"MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
"PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cpu"
)
else:
eval_logger.info("Device not specified")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment