"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "df96438484b62516689d67c00d4d9188f42e29ca"
Unverified commit 08b46218, authored by Sylvain Gugger and committed by GitHub

Repurpose torchdynamo training args towards torch._dynamo (#20498)

* Repurpose torchdynamo training args towards torch._dynamo

* Add doc
parent 829374e4
@@ -720,16 +720,25 @@ Another use case for training on many GPUs is if the model does not fit on a sin
 ## Inference with torchdynamo
 
-TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose one option below for performance boost.
-```
-TrainingArguments(torchdynamo="eager")      #enable eager model GPU. No performance boost
-TrainingArguments(torchdynamo="nvfuser")    #enable nvfuser
-TrainingArguments(torchdynamo="fx2trt")     #enable tensorRT fp32
-TrainingArguments(torchdynamo="fx2trt-f16") #enable tensorRT fp16
-```
-This feature involves 3 different libraries. To install them, please follow the instructions below:
-- [Torchdynamo installation](https://github.com/pytorch/torchdynamo#requirements-and-setup)
-- [Functorch installation](https://github.com/pytorch/functorch#install)
-- [Torch-TensorRT(FX) installation](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation)
+TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. You can choose one option below for a performance boost.
+
+TorchDynamo has a growing list of backends, which can be found in [backends.py](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py)
+or by calling `torchdynamo.list_backends()`, each of which comes with its optional dependencies.
+
+Some of the most commonly used backends are:
+
+**Debugging backends**:
+* `dynamo.optimize("eager")` - Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo issues.
+* `dynamo.optimize("aot_eager")` - Uses AotAutograd with no compiler, i.e., just using PyTorch eager for AotAutograd's extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.
+
+**Training & inference backends**:
+* `dynamo.optimize("inductor")` - Uses the TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton kernels. [Read more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747)
+* `dynamo.optimize("nvfuser")` - nvFuser with TorchScript. [Read more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
+* `dynamo.optimize("aot_nvfuser")` - nvFuser with AotAutograd. [Read more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
+* `dynamo.optimize("aot_cudagraphs")` - cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757)
+
+**Inference-only backends**:
+* `dynamo.optimize("ofi")` - Uses TorchScript's `optimize_for_inference`. [Read more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html)
+* `dynamo.optimize("fx2trt")` - Uses NVIDIA TensorRT for inference optimizations. [Read more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst)
+* `dynamo.optimize("onnxrt")` - Uses ONNX Runtime for inference on CPU/GPU. [Read more](https://onnxruntime.ai/)
+* `dynamo.optimize("ipex")` - Uses IPEX for inference on CPU. [Read more](https://github.com/intel/intel-extension-for-pytorch)
@@ -144,7 +144,6 @@ from .utils import (
     is_ipex_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
-    is_torch_tensorrt_fx_available,
     is_torch_tpu_available,
     is_torchdynamo_available,
     logging,
@@ -637,32 +636,8 @@ class Trainer:
         self._memory_tracker.stop_and_update_metrics()
 
         # torchdynamo
-        if args.torchdynamo:
-            if not is_torchdynamo_available():
-                raise RuntimeError("Torchdynamo is not installed.")
-            import torchdynamo
-
-            from torchdynamo.optimizations import backends
-
-            def get_ctx():
-                # Normal
-                if args.torchdynamo == "eager":
-                    return torchdynamo.optimize("eager")
-                elif args.torchdynamo == "nvfuser":
-                    return torchdynamo.optimize("aot_nvfuser")
-                # TensorRT
-                if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
-                    if not is_torch_tensorrt_fx_available():
-                        raise RuntimeError("Torch-TensorRT FX path is not installed.")
-                    if args.torchdynamo == "fx2trt-fp16":
-                        return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
-                    elif args.torchdynamo == "fx2trt":
-                        return torchdynamo.optimize(backends.fx2trt_compiler)
-                else:
-                    raise RuntimeError(f"Torchdynamo backend {args.torchdynamo} is not supported.")
-
-            self.ctx_manager_torchdynamo = get_ctx()
-        else:
-            self.ctx_manager_torchdynamo = contextlib.nullcontext()
+        if args.torchdynamo is not None and not is_torchdynamo_available():
+            raise RuntimeError("Using torchdynamo requires a nightly install of PyTorch.")
 
     def add_callback(self, callback):
         """
@@ -1339,6 +1314,10 @@ class Trainer:
         return model
 
     def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.torchdynamo is not None:
+            import torch._dynamo as dynamo
+
+            model = dynamo.optimize(self.args.torchdynamo)(model)
         if self.args.use_ipex:
             dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
             model = self.ipex_optimize_model(model, training, dtype=dtype)
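For context on the `torch._dynamo` call added above, here is a small standalone sketch of the same wrapping pattern. The toy module and the `"eager"` backend choice are illustrative assumptions, and a PyTorch build that ships `torch._dynamo` is assumed.

```python
import torch
import torch._dynamo as dynamo


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyModel()
# dynamo.optimize(backend) returns a wrapper; applying it to a module gives back a module
# whose forward pass is captured and compiled by the chosen backend on first call.
compiled_model = dynamo.optimize("eager")(model)
out = compiled_model(torch.randn(2, 8))
```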
@@ -2494,18 +2473,7 @@ class Trainer:
         """
         A helper wrapper to group together context managers.
         """
-        return ContextManagers(
-            [
-                self.torchdynamo_smart_context_manager(),
-                self.autocast_smart_context_manager(),
-            ]
-        )
-
-    def torchdynamo_smart_context_manager(self):
-        """
-        A helper wrapper that creates an appropriate context manager for `torchdynamo`.
-        """
-        return self.ctx_manager_torchdynamo
+        return self.autocast_smart_context_manager()
 
     def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
         """
...
@@ -73,6 +73,20 @@ log_levels = logging.get_log_levels_dict().copy()
 trainer_log_levels = dict(**log_levels, passive=-1)
 
+DYNAMO_BACKENDS = [
+    "eager",
+    "aot_eager",
+    "inductor",
+    "nvfuser",
+    "aot_nvfuser",
+    "aot_cudagraphs",
+    "ofi",
+    "fx2trt",
+    "onnxrt",
+    "ipex",
+]
+
 
 def default_logdir() -> str:
     """
     Same default as PyTorch
@@ -485,8 +499,8 @@ class TrainingArguments:
             If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
             distributed training
         torchdynamo (`str`, *optional*):
-            The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
-            "nvfuser"]. This is an experimental API and subject to change.
+            If set, the backend compiler for TorchDynamo. Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`,
+            `"nvfuser"`, `"aot_nvfuser"`, `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
         ray_scope (`str`, *optional*, defaults to `"last"`):
             The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
             then use the last checkpoint of all trials, compare those, and select the best one. However, other options
@@ -969,15 +983,8 @@ class TrainingArguments:
     torchdynamo: Optional[str] = field(
         default=None,
         metadata={
-            "help": (
-                "Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
-                " make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
-                " before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
-                " and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
-                " are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
-                " nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
-            ),
-            "choices": ["eager", "nvfuser", "fx2trt", "fx2trt-fp16"],
+            "help": "Sets up the backend compiler for TorchDynamo.",
+            "choices": DYNAMO_BACKENDS,
         },
     )
     ray_scope: Optional[str] = field(
...
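A short sketch, not part of the commit, of how the `choices` metadata above surfaces when arguments are parsed from the command line; the script invocation and the `"inductor"` value are illustrative, and it assumes the usual `HfArgumentParser` behavior of turning dataclass fields into argparse options.

```python
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
# Roughly what `python train.py --output_dir out --torchdynamo inductor` would parse into;
# a value outside DYNAMO_BACKENDS is rejected by argparse because of the "choices" metadata.
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--torchdynamo", "inductor"]
)
print(training_args.torchdynamo)  # "inductor"
```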
@@ -445,7 +445,14 @@ def is_torch_tpu_available(check_device=True):
 
 def is_torchdynamo_available():
-    return importlib.util.find_spec("torchdynamo") is not None
+    if not is_torch_available():
+        return False
+    try:
+        import torch._dynamo as dynamo  # noqa: F401
+
+        return True
+    except Exception:
+        return False
 
 
 def is_torch_tensorrt_fx_available():
...
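As a usage note, not part of the diff, the helper can guard any direct `torch._dynamo` call the same way the Trainer now does; a minimal sketch:

```python
from transformers.utils import is_torchdynamo_available

# Only touch torch._dynamo on PyTorch builds that actually ship it
# (on older stable releases the import fails and the helper returns False).
if is_torchdynamo_available():
    import torch._dynamo as dynamo

    dynamo.reset()  # clear any previously captured graphs before switching backends
```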
@@ -1839,20 +1839,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         # 4. TorchDynamo fx2trt
         trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
         metrics = trainer.evaluate()
-        t1 = metrics["eval_loss"]
-        t2 = original_eval_loss
         self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
         torchdynamo.reset()
-
-        # 5. TorchDynamo fx2trt-fp16
-        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
-        metrics = trainer.evaluate()
-        t1 = metrics["eval_loss"]
-        t2 = original_eval_loss
-        # fp16 has accuracy degradation
-        self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
-        torchdynamo.reset()
 
     @require_torch_non_multi_gpu
     @require_torchdynamo
     def test_torchdynamo_memory(self):
...