Unverified commit 0a1499fa authored by Paweł Gadziński, committed by GitHub

[Pytorch] Dynamo ONNX export support (#1497)



* some initial code
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* onnx support
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* mxfp8 support
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixed returning layernorm etc
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* formatting
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* license fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tests passing
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* lint
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* added pip install to test.sh
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/pytorch/export.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* float8currentscaling quantizer exception
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* added to wheels
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* onnx versions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* installations in tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: root <root@prenyx0221.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: root <pgadzinski@nvidia.com>

* fixes
Signed-off-by: root <pgadzinski@nvidia.com>

* fixes
Signed-off-by: root <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update setup.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* onnxscript version change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@gmail.com>

* Update build.yml
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update pytorch.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: root <root@prenyx0221.a51.clusters.nvidia.com>
Signed-off-by: root <pgadzinski@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: root <root@prenyx0221.a51.clusters.nvidia.com>
Co-authored-by: Pawel Gadzinski <pgadzinski@gmail.com>
parent c0c12e20
@@ -136,6 +136,34 @@ class MXFP8Quantizer(Quantizer):
         # TODO(ksivamani): No calibration needed for mxfp8?
         pass
 
+    def create_tensor_from_data(
+        self,
+        data: torch.Tensor,
+        scale_inv: torch.Tensor,
+        fake_dtype: torch.dtype,
+        fp8_dtype: TE_DType = tex.DType.kFloat8E4M3,
+    ) -> MXFP8Tensor:
+        """Create a new MXFP8Tensor from data and scale_inv."""
+        return MXFP8Tensor(
+            shape=data.shape,
+            dtype=fake_dtype,
+            rowwise_data=data,
+            rowwise_scale_inv=scale_inv,
+            columnwise_data=None,
+            columnwise_scale_inv=None,
+            fp8_dtype=fp8_dtype,
+            quantizer=self,
+        )
+
+    def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
+        if tensor.dtype != torch.float32:
+            tensor = tensor.to(dtype=torch.float32)
+        data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor)
+        return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32)
+
+    def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor:
+        return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv)
+
+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return MXFP8BlockScaling
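The `onnx_quantize`/`onnx_dequantize` pair above round-trips a tensor through MXFP8: `torch.ops.tex.mxfp8_quantize` returns the FP8 data together with the per-block `scale_inv` that `create_tensor_from_data` stores on the `MXFP8Tensor`. As a rough illustration of that bookkeeping (this is a toy sketch, not TE's kernel — the block size, E4M3 constant, and all function names here are illustrative assumptions), block-wise power-of-two scaling with a 3-bit-mantissa grid looks like:

```python
import math

BLOCK = 32          # MXFP8 shares one power-of-two scale per block of values
E4M3_MAX = 448.0    # largest finite E4M3 magnitude

def _round_e4m3(x):
    """Toy E4M3 rounding: keep sign, exponent, and a 3-bit mantissa."""
    if x == 0.0:
        return 0.0
    sign = math.copysign(1.0, x)
    e = math.floor(math.log2(abs(x)))
    m = round(abs(x) / 2.0**e * 8.0) / 8.0   # 3 mantissa bits
    return sign * min(m * 2.0**e, E4M3_MAX)

def quantize_block(values):
    """Scale the block so its max magnitude lands near E4M3_MAX,
    round to the toy E4M3 grid, and return (data, scale_inv)."""
    amax = max(abs(v) for v in values) or 1.0
    scale = 2.0 ** math.floor(math.log2(E4M3_MAX / amax))
    data = [_round_e4m3(v * scale) for v in values]
    return data, 1.0 / scale

def dequantize_block(data, scale_inv):
    """Mirror of quantize_block: undo the block scale."""
    return [d * scale_inv for d in data]
```

Because the scale is a power of two, `scale_inv` is exact, and the round-trip error comes only from the 3-bit mantissa rounding (at most 1/16 of each value's magnitude in this toy model).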
@@ -250,6 +250,12 @@ class Quantizer(abc.ABC):
         """Create shallow copy"""
         return copy.copy(self)
 
+    def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
+        """Symbolic function for ONNX export"""
+
+    def onnx_dequantize(self, tensor) -> torch.Tensor:
+        """Symbolic function for ONNX export"""
+
     @abc.abstractmethod
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         """Returns recipe class that is compatible with this quantizer"""
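Note the pattern in the base class: the ONNX hooks have docstring-only bodies (so they silently return `None` unless a recipe-specific quantizer overrides them), while `_get_compatible_recipe` is abstract and must be implemented. A minimal stdlib-only sketch of that pattern, with a hypothetical `ToyQuantizer` standing in for a real recipe:

```python
import abc

class Quantizer(abc.ABC):
    def onnx_quantize(self, tensor):
        """Symbolic function for ONNX export; overridden per recipe."""
        # Docstring-only body: falls through and returns None.

    @abc.abstractmethod
    def _get_compatible_recipe(self):
        """Return the recipe class compatible with this quantizer."""

class ToyQuantizer(Quantizer):
    def onnx_quantize(self, tensor):
        # A subclass supplies the real export-time behavior;
        # here we just round each element as a stand-in.
        return [round(x) for x in tensor]

    def _get_compatible_recipe(self):
        return None
```

The docstring-only default keeps the base class importable without every recipe supporting export, at the cost of a `None` return (rather than an exception) if an unsupported quantizer is hit during export.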
@@ -33,6 +33,7 @@ from transformer_engine.pytorch.constants import (
     dist_group_type,
 )
 from transformer_engine.pytorch.distributed import get_distributed_world_size
+from transformer_engine.pytorch.export import is_in_onnx_export_mode
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
@@ -814,7 +815,12 @@ class TransformerLayer(torch.nn.Module):
             return output
 
     def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
-        if drop_path is None and bias is not None and bias.numel() != 0:
+        if (
+            drop_path is None
+            and bias is not None
+            and bias.numel() != 0
+            and not is_in_onnx_export_mode()
+        ):
             if self.bias_dropout_fusion:
                 if self.training:
                     bias_dropout_add_func = bias_dropout_add_fused_train
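The guard added to `_bias_dropout_add` keeps the fused bias-dropout-add kernels out of the traced graph, so the exporter only sees plain, exportable ops. The flag machinery behind `is_in_onnx_export_mode` (a process-wide mode toggled by a context manager for the duration of the export) can be sketched roughly as follows; the context-manager name and implementation are illustrative guesses, not TE's actual code:

```python
import contextlib

_ONNX_EXPORT_MODE = False

def is_in_onnx_export_mode() -> bool:
    """True while we are inside an ONNX export context."""
    return _ONNX_EXPORT_MODE

@contextlib.contextmanager
def onnx_export(enabled: bool = True):
    """Temporarily mark the process as exporting to ONNX,
    restoring the previous mode on exit (even on error)."""
    global _ONNX_EXPORT_MODE
    saved = _ONNX_EXPORT_MODE
    _ONNX_EXPORT_MODE = enabled
    try:
        yield
    finally:
        _ONNX_EXPORT_MODE = saved

def bias_dropout_add(hidden, bias, use_fusion=True):
    # Mirrors the diff's guard: never take the fused path during
    # export, even if fusion is enabled for normal execution.
    if use_fusion and not is_in_onnx_export_mode():
        return "fused"
    return "unfused"
```

Branching on a mode flag (rather than rewriting the module for export) lets the same `TransformerLayer` code run both the fast fused path in training and the plain path under the exporter.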