Unverified Commit 9d9c3a04 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[PyTorch] Expose `multi_tensor_*` kernels (#907)



* expose multi_tensor_* kernels
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
parent 6a2dd785
......@@ -3,6 +3,15 @@
# See LICENSE for license information.
"""Fused optimizers and multi-tensor kernels."""
from transformer_engine_torch import (
multi_tensor_scale,
multi_tensor_l2norm,
multi_tensor_unscale_l2norm,
multi_tensor_adam,
multi_tensor_adam_capturable,
multi_tensor_adam_capturable_master,
multi_tensor_sgd,
)
from .fused_adam import FusedAdam
from .fused_sgd import FusedSGD
from .multi_tensor_apply import MultiTensorApply, multi_tensor_applier
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment