"vllm/model_executor/parallel_utils/utils.py" did not exist on "2f49f155858faaf82bfd076a821497e41e961658"
__init__.py 1.64 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""Transformer Engine bindings for JAX.

This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
high-performance transformer operations with mixed precision and quantization
support. It includes implementations of key transformer components like attention,
linear layers, and layer normalization, optimized for NVIDIA GPUs.

The module exports various transformer operations and utilities:
- Attention mechanisms (self-attention, cross-attention)
- Linear transformations with optional quantization
- Layer normalization operations
- Activation functions
- Softmax operations
- Sharding utilities for distributed training

All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation.
"""
22

23
# pylint: disable=wrong-import-position
24

25
26
27
28
# This unused import is needed because the top level `transformer_engine/__init__.py`
# file catches an `ImportError` as a guard for cases where the given framework's
# extensions are not available.
import jax
29

30
from transformer_engine.common import load_framework_extension
31

32
# Load the Transformer Engine framework extension for JAX before importing the
# submodules of this package. The imports that follow are deliberately placed
# after this call (hence the `wrong-import-position` pylint suppression above).
# NOTE(review): presumably those submodules depend on the extension already
# being loaded -- confirm before reordering.
load_framework_extension("jax")
33

34
from . import flax
35
36
from . import quantize

37
from .quantize import autocast, fp8_autocast, update_collections
38
from .quantize import NVTE_FP8_COLLECTION_NAME
39

40
from .sharding import MeshResource
41

42
43
44
from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum

45
46

__all__ = [
47
    "NVTE_FP8_COLLECTION_NAME",
48
    "autocast",
49
    "fp8_autocast",
50
    "update_collections",
51
52
    "MeshResource",
    "flax",
53
    "quantize",
54
]