"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6083c1566e261668a5de73cfe484c171ce232812"
Unverified Commit 9cc65f87 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Migrate torchdynamo to torch.compile (#20634)

* Migrate torchdynamo to torch.compile

* Add docstring and generic option

* Properly use the function...

* Reorg args
parent da95f6ca
......@@ -718,11 +718,11 @@ For some applications, such as pretraining large language models, applying all t
Another use case for training on many GPUs is if the model does not fit on a single GPU with all the mentioned tricks. There are still more methods we can apply although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimization to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many).
## Inference with torchdynamo
## Using torch.compile
TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. You can choose one option below for performance boost.
PyTorch 2.0 introduces a new compile function, you can learn more about it [in their documentation](https://pytorch.org/get-started/pytorch-2.0/). It uses Python’s frame evaluation API to automatically create a graph from existing PyTorch programs. After capturing the graph, different backends can be deployed to lower the graph to an optimized engine. You can choose one option below for performance boost.
TorchDynamo has a growing list of backends, which can be found in [backends.py](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py)
`torch.compile` has a growing list of backends, which can be found in [backends.py](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py)
or `torchdynamo.list_backends()` each of which with its optional dependencies.
Some of the most commonly used backends are
......
......@@ -144,8 +144,8 @@ from .utils import (
is_ipex_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_compile_available,
is_torch_tpu_available,
is_torchdynamo_available,
logging,
)
from .utils.generic import ContextManagers
......@@ -642,9 +642,9 @@ class Trainer:
# very last
self._memory_tracker.stop_and_update_metrics()
# torchdynamo
if args.torchdynamo is not None and not is_torchdynamo_available():
raise RuntimeError("Using torchdynamo requires a nighly install of PyTorch.")
# torch.compile
if args.torch_compile and not is_torch_compile_available():
raise RuntimeError("Using torch.compile requires a nighly install of PyTorch.")
def add_callback(self, callback):
"""
......@@ -1321,10 +1321,9 @@ class Trainer:
return model
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torchdynamo is not None:
import torch._dynamo as dynamo
if self.args.torch_compile:
model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode)
model = dynamo.optimize(self.args.torchdynamo)(model)
if self.args.use_ipex:
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
model = self.ipex_optimize_model(model, training, dtype=dtype)
......
......@@ -73,7 +73,7 @@ log_levels = logging.get_log_levels_dict().copy()
trainer_log_levels = dict(**log_levels, passive=-1)
DYNAMO_BACKENDS = [
TORCH_COMPILE_BACKENDS = [
"eager",
"aot_eager",
"inductor",
......@@ -514,6 +514,21 @@ class TrainingArguments:
information.
use_mps_device (`bool`, *optional*, defaults to `False`):
Whether to use Apple Silicon chip based `mps` device.
torch_compile (`bool`, *optional*, defaults to `False`):
Whether or not to compile the model using PyTorch 2.0
[`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/) (requires a nighlty install of PyTorch).
If set, the backend will default to `"inductor"` (can be customized with `torch_compile_backend`) and the
mode will default to `"default"` (can be customized with `torch_compile_mode`).
torch_compile_backend (`str`, *optional*):
The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`, `"nvfuser"`, `"aot_nvfuser"`,
`"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
torch_compile_mode (`str`, *optional*):
The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
Possible choices are `"default"`, `"reduce-overhead"` and `"max-autotune"`.
"""
framework = "pt"
......@@ -983,8 +998,8 @@ class TrainingArguments:
torchdynamo: Optional[str] = field(
default=None,
metadata={
"help": "Sets up the backend compiler for TorchDynamo.",
"choices": DYNAMO_BACKENDS,
"help": "This argument is deprecated, use `--torch_compile_backend` instead.",
"choices": TORCH_COMPILE_BACKENDS,
},
)
ray_scope: Optional[str] = field(
......@@ -1006,6 +1021,23 @@ class TrainingArguments:
"help": "Overrides the default timeout for distributed training (value should be given in seconds)."
},
)
torch_compile: bool = field(
default=False, metadata={"help": "If set to `True`, the model will be wrapped in `torch.compile`."}
)
torch_compile_backend: Optional[str] = field(
default=None,
metadata={
"help": "Which backend to use with `torch.compile`, passing one will trigger a model compilation.",
"choices": TORCH_COMPILE_BACKENDS,
},
)
torch_compile_mode: Optional[str] = field(
default=None,
metadata={
"help": "Which mode to use with `torch.compile`, passing one will trigger a model compilation.",
"choices": ["default", "reduce-overhead", "max-autotune"],
},
)
def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
......@@ -1148,10 +1180,24 @@ class TrainingArguments:
" (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
)
if self.framework == "pt" and is_torch_available() and self.torchdynamo is not None:
if self.torchdynamo is not None:
warnings.warn(
"`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
" `torch_compile_backend` instead",
FutureWarning,
)
self.torch_compile_backend = self.torchdynamo
if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
self.torch_compile = True
if self.torch_compile and self.torch_compile_backend is None:
self.torch_compile_backend = "inductor"
if self.framework == "pt" and is_torch_available() and self.torch_compile:
if is_torch_tf32_available():
if self.tf32 is None and not self.fp16 or self.bf16:
logger.info("Setting TF32 in CUDA backends to speedup torchdynamo.")
logger.info(
"Setting TF32 in CUDA backends to speedup torch compile, you won't see any improvement"
" otherwise."
)
torch.backends.cuda.matmul.allow_tf32 = True
else:
logger.warning(
......
......@@ -148,6 +148,7 @@ from .import_utils import (
is_torch_bf16_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_compile_available,
is_torch_cuda_available,
is_torch_fx_available,
is_torch_fx_proxy,
......
......@@ -455,6 +455,15 @@ def is_torchdynamo_available():
return False
def is_torch_compile_available():
if not is_torch_available():
return False
import torch
return hasattr(torch, "compile")
def is_torch_tensorrt_fx_available():
if importlib.util.find_spec("torch_tensorrt") is None:
return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment