# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE base custom ops"""

import os
import re
7
8
from abc import ABCMeta, abstractmethod
from functools import partial
9
from packaging import version
10

11
from jax.extend import core
12
13
14
15
16
from jax.interpreters import xla, mlir
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src import dispatch

17
18
19
20
21
22
23
24
import jax
import transformer_engine_jax

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

25
26
27
28
29
30

class BasePrimitive(metaclass=ABCMeta):
    """
    Abstract description of a JAX primitive backed by a TE custom call.

    Subclasses set `name` and implement the abstract hooks below. Besides the
    members declared here, `register_primitive` also reads `multiple_results`
    and `impl_static_args` from the subclass, so those must be defined there.
    """

    # Primitive name; also the string matched against NVTE_JAX_CUSTOM_CALLS_RE.
    name = None

    @classmethod
    def enabled(cls):
        """
        Return True if this custom call is enabled.

        A custom call is marked as disabled if the `cls.name` does not fully match the
        `NVTE_JAX_CUSTOM_CALLS_RE` pattern.
        By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
        For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`.

        The environment variable is read on every call, so changes take effect
        immediately.
        """
        pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
        # fullmatch: the whole name must match, not just a prefix or substring.
        return re.fullmatch(pattern, cls.name) is not None

    @staticmethod
    @abstractmethod
    def abstract():
        """
        to describe computing graph

        Registered as the abstract-eval rule of the inner primitive.
        """
        return NotImplemented

    @classmethod
    def outer_abstract(cls, *args, **kwargs):
        """
        optional abstract wrapper to eliminate workspace tensors

        The default simply forwards all arguments to `cls.abstract`; it is
        registered as the abstract-eval rule of the outer (wrapper) primitive.
        """
        return cls.abstract(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def lowering():
        """
        to describe MLIR

        Registered as the "cuda" platform lowering of the inner primitive.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def impl():
        """
        to describe implementation

        Used as the impl of the outer primitive and as the function wrapped by
        `custom_partitioning` for its lowering.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def batcher():
        """
        to describe batch rules for vmap
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def infer_sharding_from_operands():
        """
        to describe infer_sharding_from_operands for custom_partitioning
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def partition():
        """
        to describe partition for custom_partitioning
        """
        return NotImplemented

101

102
103
104
105
106
107
108
109
110
111
112
113
114
def register_primitive(cls):
    """
    register jax primitive

    Creates two JAX primitives for `cls` and stores them back on the class:

    * `cls.inner_primitive`, named `cls.name`: abstract-eval is `cls.abstract`
      and the "cuda" lowering is `cls.lowering`.
    * `cls.outer_primitive`, named `cls.name + "_wrapper"`: impl is `cls.impl`,
      abstract-eval is `cls.outer_abstract`, the batching rule is `cls.batcher`,
      and the lowering goes through `custom_partitioning` configured with
      `cls.infer_sharding_from_operands` and `cls.partition`.

    `cls` must provide `name`, `multiple_results`, `impl_static_args` and the
    hooks listed above. Returns None.
    """
    # Inner primitive: the custom call itself.
    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform="cuda")
    cls.inner_primitive = inner_p

    # Outer primitive: wrapper that adds batching and sharding rules.
    outer_p = core.Primitive(cls.name + "_wrapper")
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.outer_abstract)
    batching.primitive_batchers[outer_p] = cls.batcher

    # The outer primitive lowers by tracing cls.impl under custom_partitioning.
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(
        infer_sharding_from_operands=cls.infer_sharding_from_operands,
        partition=cls.partition,
    )
    mlir.register_lowering(
        outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
    )
    cls.outer_primitive = outer_p
132
133
134
135


# Import-time side effect: register every (name, target) pair returned by
# transformer_engine_jax.registrations() as a JAX FFI target for the "CUDA"
# platform. `ffi` is whichever module the version check above selected.
for _name, _value in transformer_engine_jax.registrations().items():
    ffi.register_ffi_target(_name, _value, platform="CUDA")