"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2eaaf17a0b0ab4c13cb1b1e87accd2d5dee47be4"
Unverified Commit 3a35937e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Remove backend check for torch.compile (#22140)



* Remove backend enforcement for torch.compile

* Update error

* Update src/transformers/training_args.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Style

---------
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 618697ef
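
This change drops the hard-coded `TORCH_COMPILE_BACKENDS` allow-list, so whatever backend string the installed PyTorch accepts is forwarded straight to `torch.compile`. A minimal sketch of the resulting behavior (the output path and backend choice below are placeholders, not part of this diff):

```python
from transformers import TrainingArguments

# After this change, the backend string is forwarded to `torch.compile` as-is;
# validity is decided by the installed PyTorch, not by a hard-coded list.
args = TrainingArguments(
    output_dir="./out",                # placeholder path
    torch_compile=True,                # compile with the default backend/mode
    torch_compile_backend="inductor",  # any backend your PyTorch build accepts
)
```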
@@ -672,7 +672,7 @@ class Trainer:
         # torch.compile
         if args.torch_compile and not is_torch_compile_available():
-            raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")
+            raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")

     def add_callback(self, callback):
         """
...
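In isolation, the guard above only checks that the `compile` attribute exists, since the nightlies that first shipped `torch.compile` were still numbered 1.x. A standalone sketch of that check:

```python
import torch

def require_torch_compile() -> None:
    # Presence check rather than a version check: `torch.compile` exists from
    # PyTorch 2.0 on (and in the late 1.14-numbered nightlies).
    if not hasattr(torch, "compile"):
        raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
```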
@@ -85,20 +85,6 @@ log_levels = logging.get_log_levels_dict().copy()
 trainer_log_levels = dict(**log_levels, passive=-1)

-TORCH_COMPILE_BACKENDS = [
-    "eager",
-    "aot_eager",
-    "inductor",
-    "nvfuser",
-    "aot_nvfuser",
-    "aot_cudagraphs",
-    "ofi",
-    "fx2trt",
-    "onnxrt",
-    "ipex",
-]
-

 def default_logdir() -> str:
     """
     Same default as PyTorch
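With the static constant gone, the set of valid backends can instead be queried from PyTorch itself. On PyTorch 2.0 this lives under a private module (`torch._dynamo`), so treat the sketch below as version-dependent:

```python
import torch._dynamo

# Lists the backends registered with the installed PyTorch, e.g.
# ['aot_eager', 'eager', 'inductor', ...]; the exact set varies by version.
print(torch._dynamo.list_backends())
```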
@@ -571,17 +557,24 @@ class TrainingArguments:
             Whether or not to compile the model using PyTorch 2.0
             [`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/) (requires a nightly install of PyTorch).
-            If set, the backend will default to `"inductor"` (can be customized with `torch_compile_backend`) and the
-            mode will default to `"default"` (can be customized with `torch_compile_mode`).
+            This will use the best defaults for the [`torch.compile`
+            API](https://pytorch.org/docs/2.0/generated/torch.compile.html?highlight=torch+compile#torch.compile).
+            You can customize the defaults with the arguments `torch_compile_backend` and `torch_compile_mode`, but
+            we don't guarantee any of them will work as support is progressively rolled out in PyTorch.
+
+            This flag and the whole compile API are experimental and subject to change in future releases.
         torch_compile_backend (`str`, *optional*):
             The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
-            Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`, `"nvfuser"`, `"aot_nvfuser"`,
-            `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
+
+            Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
+
+            This flag is experimental and subject to change in future releases.
         torch_compile_mode (`str`, *optional*):
             The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
-            Possible choices are `"default"`, `"reduce-overhead"` and `"max-autotune"`.
+
+            Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
+
+            This flag is experimental and subject to change in future releases.
         """

     framework = "pt"
@@ -1061,7 +1054,6 @@ class TrainingArguments:
         default=None,
         metadata={
             "help": "This argument is deprecated, use `--torch_compile_backend` instead.",
-            "choices": TORCH_COMPILE_BACKENDS,
         },
     )
     ray_scope: Optional[str] = field(
@@ -1090,14 +1082,12 @@ class TrainingArguments:
         default=None,
         metadata={
             "help": "Which backend to use with `torch.compile`, passing one will trigger a model compilation.",
-            "choices": TORCH_COMPILE_BACKENDS,
         },
     )
     torch_compile_mode: Optional[str] = field(
         default=None,
         metadata={
             "help": "Which mode to use with `torch.compile`, passing one will trigger a model compilation.",
-            "choices": ["default", "reduce-overhead", "max-autotune"],
         },
     )
...
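Because the `choices` constraints are removed, `HfArgumentParser` now accepts any backend or mode string and defers validation to `torch.compile`. A sketch (the argv list is illustrative):

```python
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "./out", "--torch_compile_backend", "inductor"]
)
# Per the docstring above, setting a backend also flips torch_compile to True.
assert args.torch_compile
```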
@@ -478,6 +478,8 @@ def is_torch_compile_available():
     import torch

+    # We don't do any version check here to support nightlies marked as 1.14. Ultimately this needs to check the
+    # version against 2.0, but let's do it later.
     return hasattr(torch, "compile")
...
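One possible shape for the version check that the new comment defers, assuming the `packaging` dependency is available: accept stable releases at or above 2.0 while still allowing the 1.14-numbered nightlies via the attribute check.

```python
import torch
from packaging import version

def is_torch_compile_available() -> bool:
    # Stable >= 2.0, or any build (e.g. a 1.14 nightly) that already ships
    # `torch.compile`.
    torch_version = version.parse(version.parse(torch.__version__).base_version)
    return torch_version >= version.parse("2.0.0") or hasattr(torch, "compile")
```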