Commit 7bb147b5 authored by baberabb's avatar baberabb
Browse files

change torch req for mps

parent dc5b3d5d
......@@ -170,7 +170,7 @@ A number of other libraries contain scripts for calling the eval harness through
### Additional Features
If you have a CUDA-compatible Mac GPU, you can run the eval harness using the MPS back-end by replaicng `--device cuda:0` with `--device mps:0`. PyTorch does not currently support automatic mixed precision (AMP) for MPS, so we forcibly cast all weights to fp32 regardless of how they're stored. This is slower and has a larger memory footprint than we can achieve on Linux systems, but as PyTorch continues to improve its MPS support we hope to continue to improve it.
If you have a Metal compatible Mac, you can run the eval harness using the MPS back-end by replacing `--device cuda:0` with `--device mps` (requires PyTorch version 2.1 or higher).
💡 **Tip**: You can inspect what the LM inputs look like by running the following command:
......
......@@ -133,13 +133,6 @@ class LM(abc.ABC):
additional_config = {} if additional_config is None else additional_config
args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None}
# TODO: delete once float16 MPS is fixed in torch stable
if (
args2.get("device") in ("mps", "mps:0")
or args.get("device") in ("mps", "mps:0")
and "dev" not in torch.__version__
):
args["dtype"] = "float32"
return cls(**args, **args2)
@property
......
import os
from packaging import version
import torch
import transformers
from transformers.models.auto.modeling_auto import (
......@@ -118,11 +118,11 @@ class HFLM(LM):
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
if device in ("mps", "mps:0") and "dev" not in torch.__version__:
eval_logger.info(
"MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of "
"PyTorch: pip3 install --pre torch torchvision torchaudio --index-url "
"https://download.pytorch.org/whl/nightly/cpu"
if device in ("mps", "mps:0") and version.parse(
torch.__version__
) < version.parse("2.1"):
raise RuntimeError(
f"mps requires torch >= 2.1. You have {torch.__version__}"
)
else:
eval_logger.info("Device not specified")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment