Commit 7bb147b5 authored by baberabb

change torch req for mps

parent dc5b3d5d
@@ -170,7 +170,7 @@ A number of other libraries contain scripts for calling the eval harness through
 ### Additional Features
-If you have a CUDA-compatible Mac GPU, you can run the eval harness using the MPS back-end by replacing `--device cuda:0` with `--device mps:0`. PyTorch does not currently support automatic mixed precision (AMP) for MPS, so we forcibly cast all weights to fp32 regardless of how they're stored. This is slower and has a larger memory footprint than we can achieve on Linux systems, but as PyTorch continues to improve its MPS support we hope to improve this further.
+If you have a Metal-compatible Mac, you can run the eval harness using the MPS back-end by replacing `--device cuda:0` with `--device mps` (requires PyTorch version 2.1 or higher).
 💡 **Tip**: You can inspect what the LM inputs look like by running the following command:
...
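Not part of the diff, but for readers wiring this up locally: a minimal pre-flight sketch of the new requirement. It only assumes the public `torch.backends.mps` API and `packaging.version`; the helper name `pick_device` is made up for illustration.

```python
# Illustrative pre-flight check, not code from this commit.
from packaging import version

import torch


def pick_device(requested: str = "mps") -> str:
    """Return `requested` if it is usable under the new MPS policy, else fall back to CPU."""
    if requested.startswith("mps"):
        # Mirrors the new harness requirement: MPS needs torch >= 2.1.
        if version.parse(torch.__version__) < version.parse("2.1"):
            raise RuntimeError(f"mps requires torch >= 2.1. You have {torch.__version__}")
        # MPS can still be unavailable (non-Apple hardware, CPU-only build).
        if not torch.backends.mps.is_available():
            return "cpu"
    return requested


print(pick_device("mps"))
```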
@@ -133,13 +133,6 @@ class LM(abc.ABC):
         additional_config = {} if additional_config is None else additional_config
         args = utils.simple_parse_args_string(arg_string)
         args2 = {k: v for k, v in additional_config.items() if v is not None}
-        # TODO: delete once float16 MPS is fixed in torch stable
-        if (
-            args2.get("device") in ("mps", "mps:0")
-            or args.get("device") in ("mps", "mps:0")
-            and "dev" not in torch.__version__
-        ):
-            args["dtype"] = "float32"
         return cls(**args, **args2)

     @property
...
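A side note on the deleted block (an observation, not something stated in the commit): Python's `and` binds tighter than `or`, so the removed condition grouped as `A or (B and C)` rather than `(A or B) and C`. A tiny sketch with made-up flag names standing in for the three sub-expressions:

```python
# Hypothetical booleans standing in for the sub-expressions of the deleted check.
device_from_config_is_mps = True   # args2.get("device") in ("mps", "mps:0")
device_from_args_is_mps = False    # args.get("device") in ("mps", "mps:0")
torch_is_stable = False            # "dev" not in torch.__version__  (False on a nightly build)

# Python evaluates `A or B and C` as `A or (B and C)` ...
as_written = device_from_config_is_mps or device_from_args_is_mps and torch_is_stable
# ... whereas the guard presumably meant `(A or B) and C`.
as_presumably_intended = (device_from_config_is_mps or device_from_args_is_mps) and torch_is_stable

print(as_written, as_presumably_intended)  # True False
```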
 import os
+from packaging import version
 import torch
 import transformers
 from transformers.models.auto.modeling_auto import (
@@ -118,11 +118,11 @@ class HFLM(LM):
                 device = int(device)
             self._device = torch.device(device)
             eval_logger.info(f"Using device '{device}'")
if device in ("mps", "mps:0") and "dev" not in torch.__version__: if device in ("mps", "mps:0") and version.parse(
eval_logger.info( torch.__version__
"MPS: Setting dtype to float32. To use float16 with MPS, please install a nightly build of " ) < version.parse("2.1"):
"PyTorch: pip3 install --pre torch torchvision torchaudio --index-url " raise RuntimeError(
"https://download.pytorch.org/whl/nightly/cpu" f"mps requires torch >= 2.1. You have {torch.__version__}"
) )
         else:
             eval_logger.info("Device not specified")
...
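The new guard compares parsed versions instead of substring-matching `"dev"` in `torch.__version__`. A standalone illustration of how `packaging.version` orders a few representative version strings against the `2.1` threshold (the version strings below are examples, not taken from the commit):

```python
from packaging import version

THRESHOLD = version.parse("2.1")  # the minimum the diff enforces for MPS

# Example torch version strings, purely illustrative.
for v in ["2.0.1", "2.1.0.dev20230901", "2.1.0", "2.2.0"]:
    allowed = version.parse(v) >= THRESHOLD
    print(f"{v:>18}  ->  {'ok for --device mps' if allowed else 'would raise RuntimeError'}")
```

Note that a `2.1.0.devYYYYMMDD` nightly parses as a pre-release of 2.1.0 and therefore sorts below `2.1`, so under this check nightly builds also trigger the error.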