# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Top level package"""

# pylint: disable=unused-import
import os
from importlib import metadata

import transformer_engine.common

# Substring of the FileNotFoundError message that signals a framework
# extension's shared object could not be located (as opposed to any other
# missing-file error, which must not be hidden).
_MISSING_SO_MESSAGE = "Could not find shared object file"


def _handle_missing_extension(err, framework, display_name):
    """Handle a FileNotFoundError raised while importing a framework submodule.

    We only get here when the framework itself (e.g. `torch` / `jax`) imported
    fine but Transformer Engine's extension library for it could not be loaded.
    This is legitimate when a user only wants one framework on a system that has
    both installed, so by default we only warn. If the user explicitly requested
    this framework via NVTE_FRAMEWORK, the failure is fatal.

    Args:
        err: The FileNotFoundError caught from the submodule import.
        framework: Lower-case framework key used in NVTE_FRAMEWORK and the pip
            extra (e.g. "pytorch", "jax").
        display_name: Human-readable framework name for the warning text.

    Raises:
        FileNotFoundError: ``err`` itself, when its message is not the
            missing-shared-object message (an unexpected error), or when
            NVTE_FRAMEWORK names this framework or "all".
    """
    import warnings

    if _MISSING_SO_MESSAGE not in str(err):
        # Unrelated missing file: surface it instead of silently dropping it.
        raise err

    requested = os.getenv("NVTE_FRAMEWORK")
    if not requested:
        # Unset (or empty): the user expressed no preference, so just warn.
        warnings.warn(
            f"Detected a {display_name} installation but could not find the shared object file for the "
            f"Transformer Engine {display_name} extension library. If this is not intentional, please "
            f"reinstall Transformer Engine with `pip install transformer_engine[{framework}]` or "
            f"build from source with `NVTE_FRAMEWORK={framework}`.",
            category=RuntimeWarning,
        )
        return

    # NVTE_FRAMEWORK may list several frameworks, e.g. "pytorch,jax".
    frameworks = {name.strip() for name in requested.split(",")}
    if framework in frameworks or "all" in frameworks:
        raise err


try:
    from . import pytorch
except ImportError:
    # `torch` (or the submodule's dependencies) is not importable: skip PyTorch support.
    pass
except FileNotFoundError as e:
    _handle_missing_extension(e, "pytorch", "PyTorch")

try:
    from . import jax
except ImportError:
    # `jax` (or the submodule's dependencies) is not importable: skip JAX support.
    pass
except FileNotFoundError as e:
    _handle_missing_extension(e, "jax", "Jax")

__version__ = str(metadata.version("transformer_engine"))