"docs/vscode:/vscode.git/clone" did not exist on "e4d8f517b93db57fd0ad2fa80a74549cf9e42488"
Commit 2564f0c2 authored by Wang, Yi, committed by GitHub

fix jit trace error for model forward sequence is not aligned with jit.trace tuple input sequence, update related doc (#19891)

* fix jit trace error for classification use case, update related doc
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add implementation in torch 1.14.0
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update_doc
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update_doc
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 737bff6a
@@ -22,17 +22,27 @@ For a gentle introduction to TorchScript, see the Introduction to [PyTorch Torch
### IPEX Graph Optimization with JIT-mode
Intel® Extension for PyTorch provides further optimizations in jit mode for Transformers series models. It is highly recommended for users to take advantage of Intel® Extension for PyTorch with jit mode. Some frequently used operator patterns from Transformers models are already supported in Intel® Extension for PyTorch with jit mode fusions. Fusion patterns such as Multi-head-attention fusion, Concat Linear, Linear+Add, Linear+Gelu, and Add+LayerNorm fusion are enabled and perform well. The benefit of the fusion is delivered to users in a transparent fashion. According to the analysis, ~70% of the most popular NLP tasks in question-answering, text-classification, and token-classification can get performance benefits with these fusion patterns for both Float32 precision and BFloat16 Mixed precision.
-Check more detailed information for [IPEX Graph Optimization](https://intel.github.io/intel-extension-for-pytorch/1.11.200/tutorials/features/graph_optimization.html).
+Check more detailed information for [IPEX Graph Optimization](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/features/graph_optimization.html).
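As a rough, hedged sketch of the workflow this section describes (not part of this commit): apply `ipex.optimize` to the model, then trace and freeze it so the fusion passes can take effect. The checkpoint name is reused from the question-answering example further down, and `example_kwarg_inputs` assumes a PyTorch version that supports dict inputs in `jit.trace` (see the Tip below).

```python
# Hedged sketch: IPEX graph optimization combined with TorchScript tracing.
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

name = "csarron/bert-base-uncased-squad-v1"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForQuestionAnswering.from_pretrained(name).eval()

# ipex.optimize applies weight prepacking and operator-level optimizations;
# the fusion patterns listed above kick in once the model is traced and frozen.
model = ipex.optimize(model, dtype=torch.float32)

inputs = tokenizer("Who develops IPEX?", "Intel develops IPEX.", return_tensors="pt")
with torch.no_grad():
    traced = torch.jit.trace(model, example_kwarg_inputs=dict(inputs), strict=False)
    traced = torch.jit.freeze(traced)
    traced(**inputs)  # warm-up runs let the fusion passes take effect
    traced(**inputs)
```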
#### IPEX installation:
IPEX releases follow PyTorch; check the approaches for [IPEX installation](https://intel.github.io/intel-extension-for-pytorch/).
### Usage of JIT-mode
-To enable jit mode in Trainer, users should add `jit_mode_eval` in Trainer command arguments.
+To enable JIT-mode in Trainer for evaluation or prediction, users should add `jit_mode_eval` in Trainer command arguments.
+<Tip warning={true}>
+For PyTorch >= 1.14.0, JIT-mode could benefit any model for prediction and evaluation since dict input is supported in jit.trace.
+For PyTorch < 1.14.0, JIT-mode could benefit a model whose forward parameter order matches the tuple input order in jit.trace, such as a question-answering model.
+In the case where the forward parameter order does not match the tuple input order in jit.trace, such as a text-classification model, jit.trace will fail and we capture this with an exception to make it fall back. Logging is used to notify users.
+</Tip>
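A minimal, self-contained sketch of the mismatch described in this Tip (the toy module and tensors are made up for illustration; `example_kwarg_inputs` is only available on newer PyTorch):

```python
import torch
from torch import nn


class Toy(nn.Module):
    # forward signature order: input_ids, attention_mask, token_type_ids
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        out = input_ids.float()
        if attention_mask is not None:
            out = out * attention_mask
        if token_type_ids is not None:
            out = out + token_type_ids
        return out


model = Toy().eval()
# A classification-style batch whose key order differs from the forward signature.
batch = {
    "input_ids": torch.ones(1, 4, dtype=torch.long),
    "token_type_ids": torch.zeros(1, 4, dtype=torch.long),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
}

# Tuple inputs bind positionally, so token_type_ids lands in the attention_mask slot;
# on a real model this mismatch is what makes jit.trace fail.
mismatched = torch.jit.trace(model, tuple(batch[k] for k in batch), strict=False)

# With PyTorch >= 1.14.0, keyword inputs bind by name, so the batch key order no longer matters.
matched = torch.jit.trace(model, example_kwarg_inputs=dict(batch), strict=False)
```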
Take an example of the use cases on [Transformers question-answering](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering)
- Inference using jit mode on CPU:
<pre>python run_qa.py \
--model_name_or_path csarron/bert-base-uncased-squad-v1 \
......
@@ -19,7 +19,7 @@ IPEX is optimized for CPUs with AVX-512 or above, and functionally works for CPU
The low precision data type BFloat16 has been natively supported on the 3rd Generation Xeon® Scalable Processors (aka Cooper Lake) with the AVX512 instruction set and will be supported on the next generation of Intel® Xeon® Scalable Processors with the Intel® Advanced Matrix Extensions (Intel® AMX) instruction set with further boosted performance. The Auto Mixed Precision for the CPU backend has been enabled since PyTorch-1.10. At the same time, support for Auto Mixed Precision with BFloat16 for CPU and BFloat16 optimization of operators has been extensively enabled in Intel® Extension for PyTorch, and partially upstreamed to the PyTorch master branch. Users can get better performance and user experience with IPEX Auto Mixed Precision.
-Check more detailed information for [Auto Mixed Precision](https://intel.github.io/intel-extension-for-pytorch/1.11.200/tutorials/features/amp.html).
+Check more detailed information for [Auto Mixed Precision](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/features/amp.html).
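For reference, a minimal sketch (not from this commit) of CPU Auto Mixed Precision with BFloat16, available from PyTorch 1.10 onward; the linear layer is just a placeholder model:

```python
import torch
from torch import nn

model = nn.Linear(16, 4).eval()
x = torch.randn(2, 16)

# Under torch.cpu.amp.autocast, matmul-heavy ops such as nn.Linear run in BFloat16,
# while numerically sensitive ops stay in Float32.
with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)  # torch.bfloat16
```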
### IPEX installation:
@@ -37,7 +37,13 @@ For PyTorch-1.11:
pip install intel_extension_for_pytorch==1.11.200+cpu -f https://software.intel.com/ipex-whl-stable
```
-Check more approaches for [IPEX installation](https://intel.github.io/intel-extension-for-pytorch/1.11.200/tutorials/installation.html).
+For PyTorch-1.12:
+```
+pip install intel_extension_for_pytorch==1.12.300+cpu -f https://software.intel.com/ipex-whl-stable
+```
+Check more approaches for [IPEX installation](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/installation.html).
### Usage in Trainer
To enable auto mixed precision with IPEX in Trainer, users should add `use_ipex`, `bf16` and `no_cuda` in training command arguments.
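For example, a hedged variant of the question-answering command shown earlier (the doc's own full command follows in the elided portion; the remaining task-specific arguments are omitted here):
<pre>python run_qa.py \
--model_name_or_path csarron/bert-base-uncased-squad-v1 \
--use_ipex \
--bf16 \
--no_cuda \
...</pre>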
......
@@ -1251,20 +1251,34 @@ class Trainer:
        if dataloader is None:
            logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
            return model
-        jit_inputs = []
        example_batch = next(iter(dataloader))
-        for key in example_batch:
-            example_tensor = torch.ones_like(example_batch[key])
-            jit_inputs.append(example_tensor)
-        jit_inputs = tuple(jit_inputs)
+        example_batch = self._prepare_inputs(example_batch)
        try:
            jit_model = model.eval()
-            with ContextManagers([self.autocast_smart_context_manager(), torch.no_grad()]):
-                jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
+            with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
+                if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
+                    if isinstance(example_batch, dict):
+                        jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
+                    else:
+                        jit_model = torch.jit.trace(
+                            jit_model,
+                            example_kwarg_inputs={key: example_batch[key] for key in example_batch},
+                            strict=False,
+                        )
+                else:
+                    jit_inputs = []
+                    for key in example_batch:
+                        example_tensor = torch.ones_like(example_batch[key])
+                        jit_inputs.append(example_tensor)
+                    jit_inputs = tuple(jit_inputs)
+                    jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
            jit_model = torch.jit.freeze(jit_model)
            jit_model(**example_batch)
+            jit_model(**example_batch)
            model = jit_model
-        except (RuntimeError, TypeError) as e:
+            self.use_cpu_amp = False
+            self.use_cuda_amp = False
+        except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
            logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
        return model
@@ -1296,9 +1310,6 @@ class Trainer:
            dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
            model = self.ipex_optimize_model(model, training, dtype=dtype)
-        if self.args.jit_mode_eval:
-            model = self.torch_jit_model_eval(model, dataloader, training)
        if is_sagemaker_mp_enabled():
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
@@ -1321,6 +1332,9 @@ class Trainer:
        if self.args.n_gpu > 1:
            model = nn.DataParallel(model)
+        if self.args.jit_mode_eval:
+            model = self.torch_jit_model_eval(model, dataloader, training)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
        if not training:
@@ -2460,7 +2474,7 @@ class Trainer:
        """
        return self.ctx_manager_torchdynamo
-    def autocast_smart_context_manager(self):
+    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = None):
        """
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
        arguments, depending on the situation.
@@ -2468,9 +2482,9 @@ class Trainer:
        if self.use_cuda_amp or self.use_cpu_amp:
            if is_torch_greater_or_equal_than_1_10:
                ctx_manager = (
-                    torch.cpu.amp.autocast(dtype=self.amp_dtype)
+                    torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
                    if self.use_cpu_amp
-                    else torch.cuda.amp.autocast(dtype=self.amp_dtype)
+                    else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
                )
            else:
                ctx_manager = torch.cuda.amp.autocast()
......
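For context, a small sketch of how the new `cache_enabled` flag is presumably used by the JIT path above: the autocast weight cache is disabled while tracing (mirroring `autocast_smart_context_manager(cache_enabled=False)` in `torch_jit_model_eval`), so cast results are not cached into the trace. The toy module and CPU BFloat16 autocast are assumptions made for illustration.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 2)).eval()
example = torch.randn(1, 8)

# Trace under autocast with the weight cache disabled, as the Trainer change does
# via autocast_smart_context_manager(cache_enabled=False).
with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.bfloat16, cache_enabled=False):
    traced = torch.jit.trace(model, example, strict=False)

traced = torch.jit.freeze(traced)
with torch.no_grad():
    traced(example)  # warm-up call lets freeze-time optimizations run
```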