Unverified Commit 897a8dd8 authored by Animesh Jain's avatar Animesh Jain Committed by GitHub
Browse files

Support compilation via Torchdynamo, AOT Autograd, NVFuser (#17308)



* Support compilation via Torchdynamo, AOT Autograd, NVFuser

* Address comments

* Lint

* Stas comments - missing quality test

* Lintere

* Quality test

* Doc lint

* Reset CUDA peak mem

* Add CustomTrainer

* require a single gpu
Co-authored-by: default avatarStas Bekman <stas@stason.org>
parent 31484afb
...@@ -70,6 +70,7 @@ from .utils import ( ...@@ -70,6 +70,7 @@ from .utils import (
is_torch_tf32_available, is_torch_tf32_available,
is_torch_tpu_available, is_torch_tpu_available,
is_torchaudio_available, is_torchaudio_available,
is_torchdynamo_available,
is_vision_available, is_vision_available,
) )
...@@ -464,6 +465,11 @@ else: ...@@ -464,6 +465,11 @@ else:
jax_device = None jax_device = None
def require_torchdynamo(test_case):
"""Decorator marking a test that requires TorchDynamo"""
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
def require_torch_gpu(test_case): def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch.""" """Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
......
...@@ -139,8 +139,10 @@ from .utils import ( ...@@ -139,8 +139,10 @@ from .utils import (
is_sagemaker_dp_enabled, is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled, is_sagemaker_mp_enabled,
is_torch_tpu_available, is_torch_tpu_available,
is_torchdynamo_available,
logging, logging,
) )
from .utils.generic import ContextManagers
_is_torch_generator_available = False _is_torch_generator_available = False
...@@ -2172,6 +2174,32 @@ class Trainer: ...@@ -2172,6 +2174,32 @@ class Trainer:
return inputs return inputs
def compute_loss_context_manager(self):
"""
A helper wrapper to group together context managers.
"""
return ContextManagers(
[
self.torchdynamo_smart_context_manager(),
self.autocast_smart_context_manager(),
]
)
def torchdynamo_smart_context_manager(self):
"""
A helper wrapper that creates an appropriate context manager for `torchdynamo`.
"""
ctx_manager = contextlib.nullcontext()
if is_torchdynamo_available():
import torchdynamo
from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
if self.args.torchdynamo == "eager":
ctx_manager = torchdynamo.optimize("eager")
elif self.args.torchdynamo == "nvfuser":
ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
return ctx_manager
def autocast_smart_context_manager(self): def autocast_smart_context_manager(self):
""" """
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
...@@ -2213,7 +2241,7 @@ class Trainer: ...@@ -2213,7 +2241,7 @@ class Trainer:
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler) loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
return loss_mb.reduce_mean().detach().to(self.args.device) return loss_mb.reduce_mean().detach().to(self.args.device)
with self.autocast_smart_context_manager(): with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs) loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1: if self.args.n_gpu > 1:
...@@ -2907,7 +2935,7 @@ class Trainer: ...@@ -2907,7 +2935,7 @@ class Trainer:
logits = smp_nested_concat(logits_mb) logits = smp_nested_concat(logits_mb)
else: else:
if has_labels: if has_labels:
with self.autocast_smart_context_manager(): with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach() loss = loss.mean().detach()
...@@ -2917,7 +2945,7 @@ class Trainer: ...@@ -2917,7 +2945,7 @@ class Trainer:
logits = outputs[1:] logits = outputs[1:]
else: else:
loss = None loss = None
with self.autocast_smart_context_manager(): with self.compute_loss_context_manager():
outputs = model(**inputs) outputs = model(**inputs)
if isinstance(outputs, dict): if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
......
...@@ -183,7 +183,7 @@ class Seq2SeqTrainer(Trainer): ...@@ -183,7 +183,7 @@ class Seq2SeqTrainer(Trainer):
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
with torch.no_grad(): with torch.no_grad():
with self.autocast_smart_context_manager(): with self.compute_loss_context_manager():
outputs = model(**inputs) outputs = model(**inputs)
if has_labels: if has_labels:
if self.label_smoother is not None: if self.label_smoother is not None:
......
...@@ -450,6 +450,9 @@ class TrainingArguments: ...@@ -450,6 +450,9 @@ class TrainingArguments:
full_determinism (`bool`, *optional*, defaults to `False`) full_determinism (`bool`, *optional*, defaults to `False`)
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
distributed training distributed training
torchdynamo (`str`, *optional*):
The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
"nvfuser]. This is an experimental API and subject to change.
""" """
output_dir: str = field( output_dir: str = field(
...@@ -881,6 +884,20 @@ class TrainingArguments: ...@@ -881,6 +884,20 @@ class TrainingArguments:
) )
}, },
) )
torchdynamo: Optional[str] = field(
default=None,
metadata={
"help": (
"Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
" make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
" before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
" and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
" are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
" nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
),
"choices": ["eager", "nvfuser"],
},
)
def __post_init__(self): def __post_init__(self):
# Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then). # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
......
...@@ -130,6 +130,7 @@ from .import_utils import ( ...@@ -130,6 +130,7 @@ from .import_utils import (
is_torch_tf32_available, is_torch_tf32_available,
is_torch_tpu_available, is_torch_tpu_available,
is_torchaudio_available, is_torchaudio_available,
is_torchdynamo_available,
is_training_run_on_sagemaker, is_training_run_on_sagemaker,
is_vision_available, is_vision_available,
requires_backends, requires_backends,
......
...@@ -376,6 +376,10 @@ def is_torch_tpu_available(): ...@@ -376,6 +376,10 @@ def is_torch_tpu_available():
return importlib.util.find_spec("torch_xla.core.xla_model") is not None return importlib.util.find_spec("torch_xla.core.xla_model") is not None
def is_torchdynamo_available():
return importlib.util.find_spec("torchdynamo") is not None
def is_datasets_available(): def is_datasets_available():
return _datasets_available return _datasets_available
......
...@@ -62,6 +62,7 @@ from transformers.testing_utils import ( ...@@ -62,6 +62,7 @@ from transformers.testing_utils import (
require_torch_non_multi_gpu, require_torch_non_multi_gpu,
require_torch_tf32, require_torch_tf32,
require_torch_up_to_2_gpus, require_torch_up_to_2_gpus,
require_torchdynamo,
require_wandb, require_wandb,
slow, slow,
) )
...@@ -1594,6 +1595,100 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -1594,6 +1595,100 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# perfect world: fp32_init/2 == fp16_eval # perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000) self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_full_eval(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
n_gpus = get_gpu_count()
bs = 8
eval_len = 16 * n_gpus
# make the params are somewhat big so that there will be enough RAM consumed to be able to
# measure things. We should get about 64KB for a+b in fp32
a = torch.ones(1000, bs) + 0.001
b = torch.ones(1000, bs) - 0.001
# 1. Default - without TorchDynamo
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len)
metrics = trainer.evaluate()
original_eval_loss = metrics["eval_loss"]
del trainer
# 2. TorchDynamo eager
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="eager")
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
del trainer
# 3. TorchDynamo nvfuser
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
class CustomTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
x = inputs["x"]
output = model(x)
if self.args.n_gpu == 1:
return output.mean()
return output
class MyModule(torch.nn.Module):
"""Simple module that does aggressive fusion"""
def __init__(self):
super().__init__()
def forward(self, x):
for _ in range(20):
x = torch.nn.functional.relu(x)
return x
mod = MyModule()
# 1. Default - without TorchDynamo
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
trainer = CustomTrainer(model=mod)
# warmup
for _ in range(10):
orig_loss = trainer.training_step(mod, {"x": a})
torch.cuda.reset_peak_memory_stats()
orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
del trainer
# Reset the peak for another measurement
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# 2. TorchDynamo nvfuser
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
args = TrainingArguments(output_dir="None", torchdynamo="nvfuser")
trainer = CustomTrainer(model=mod, args=args)
# warmup
for _ in range(10):
loss = trainer.training_step(mod, {"x": a})
torch.cuda.reset_peak_memory_stats()
loss = trainer.training_step(mod, {"x": a})
peak_mem = torch.cuda.max_memory_allocated()
del trainer
# Functional check
self.assertAlmostEqual(loss, orig_loss)
# AOT Autograd recomputaion and nvfuser recomputation optimization
# aggressively fuses the operations and reduce the memory footprint.
self.assertGreater(orig_peak_mem, peak_mem * 2)
@require_torch_gpu @require_torch_gpu
@require_torch_bf16 @require_torch_bf16
def test_bf16_full_eval(self): def test_bf16_full_eval(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment