# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE base custom ops"""
import os
import re
7
import warnings
8
9
from abc import ABCMeta, abstractmethod
from functools import partial
10
from packaging import version
11

12
from jax.extend import core
13
14
15
16
17
from jax.interpreters import xla, mlir
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src import dispatch

18
19
20
21
22
23
24
25
import jax
import transformer_engine_jax

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

26
27
28
29
30
31

class BasePrimitive(metaclass=ABCMeta):
    """
    Base class for Transformer Engine JAX custom primitives.

    Subclasses supply the static hooks below (abstract evaluation, MLIR lowering,
    implementation, batching and sharding rules). ``register_primitive`` turns a
    subclass into an inner and an outer JAX primitive. Besides the hooks defined
    here, ``register_primitive`` also reads ``multiple_results`` and
    ``impl_static_args`` from the subclass, so those must be defined as class
    attributes.
    """

    # Base name of the JAX primitive. register_primitive uses it for the inner
    # primitive and ``name + "_wrapper"`` for the outer primitive.
    name = None

    # Fallback enabled state, used only when no environment variable decides.
    # set_enabled() assigns on the class it is called on, so a subclass call
    # shadows this default for that subclass only.
    _is_enabled = True

    # Default list of primitives to disable for all recipes
    _default_disable_names = []

    @classmethod
    def enabled(cls):
        """
        Determines if a custom call is enabled based on a state variable and environment variables.

        Checks `NVTE_JAX_CUSTOM_CALLS` (key/value format) first. If it does not decide
        for this class, falls back to the deprecated `NVTE_JAX_CUSTOM_CALLS_RE` (regex
        pattern), and finally to the internal state `_is_enabled`.

        Environment Variables:
            1. `NVTE_JAX_CUSTOM_CALLS`: Preferred key/value format to enable/disable specific
               primitives, or a single value 'true' or 'false' to enable/disable all primitives.
               - Example 1 (global enable): 'true' enables all primitives.
               - Example 2 (global disable): 'false' disables all primitives.
               - Example 3 (specific settings): 'DBiasQuantizePrimitive=false,GemmPrimitive=true'
                 disables DBiasQuantizePrimitive and enables GemmPrimitive, leaving others at
                 their default state. Note that the default state is set at class level based
                 on _default_disable_names.
            2. `NVTE_JAX_CUSTOM_CALLS_RE`: Deprecated regex pattern to match primitive names.
               - Example: 'DBiasQuantizePrimitive' or '^(?!DBiasQuantizePrimitive$).+$' to
                 enable/disable DBiasQuantizePrimitive.
               - A deprecation warning is raised if used; it will be removed in future releases.

        Behavior:
            1. If `NVTE_JAX_CUSTOM_CALLS` is 'true' or 'false' (case-insensitive, surrounding
               whitespace ignored), that value is returned. Otherwise it is parsed as
               comma-separated key=value pairs, and if this class's name is a key, that
               setting is returned. Empty entries are skipped. Entries without '=' are
               ignored with a UserWarning. Values other than 'true'/'false' are treated as
               False, with a UserWarning.
            2. Otherwise, if `NVTE_JAX_CUSTOM_CALLS_RE` is set, a DeprecationWarning is emitted
               and the result is whether the pattern fully matches the class name.
            3. Otherwise, `_is_enabled` is returned.

        Returns:
            bool: Whether this primitive's custom call is enabled.

        Raises:
            re.error: If `NVTE_JAX_CUSTOM_CALLS_RE` is set to an invalid regular expression.
        """
        # Check new key/value environment variable first
        custom_calls_str = os.getenv("NVTE_JAX_CUSTOM_CALLS")
        if custom_calls_str is not None:
            custom_calls_str = custom_calls_str.strip()
            lowered = custom_calls_str.lower()
            if lowered == "true":
                return True
            if lowered == "false":
                return False

            # Parse key=value pairs
            settings = {}
            for pair in custom_calls_str.split(","):
                pair = pair.strip()
                if not pair:
                    # Tolerate empty entries, e.g. from a trailing comma.
                    continue
                if "=" not in pair:
                    # Surface typos such as a bare primitive name instead of dropping them.
                    warnings.warn(
                        f"Ignoring malformed NVTE_JAX_CUSTOM_CALLS entry {pair!r}; expected"
                        " 'PrimitiveName=true' or 'PrimitiveName=false'.",
                        stacklevel=2,
                    )
                    continue
                key, value = pair.split("=", 1)
                value = value.strip().lower()
                if value not in ("true", "false"):
                    # Anything but 'true' still disables; warn so misspellings are visible.
                    warnings.warn(
                        f"NVTE_JAX_CUSTOM_CALLS entry {pair!r} has a value other than"
                        " 'true'/'false'; treating it as 'false'.",
                        stacklevel=2,
                    )
                settings[key.strip()] = value == "true"
            if cls.__name__ in settings:
                return settings[cls.__name__]

        # Check old regex environment variable (deprecated)
        pattern_str = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE")
        if pattern_str is not None:
            warnings.warn(
                "NVTE_JAX_CUSTOM_CALLS_RE is deprecated and will be removed in future releases. Use"
                " NVTE_JAX_CUSTOM_CALLS with key=value format instead (e.g.,"
                " 'DBiasQuantizePrimitive=false').",
                DeprecationWarning,
                stacklevel=2,
            )
            pattern = re.compile(pattern_str)
            return pattern.fullmatch(cls.__name__) is not None

        # If no environment variable is set, fall back to the internal state
        return cls._is_enabled

    @classmethod
    def set_enabled(cls, enabled: bool):
        """
        Sets the enabled state for this primitive.

        The flag is stored on the class this is called on. Environment variables
        consulted by `enabled()` still take precedence over it.
        """
        cls._is_enabled = enabled

    @staticmethod
    @abstractmethod
    def abstract():
        """
        Abstract-evaluation rule (output shapes/dtypes) for the inner primitive.

        register_primitive installs it with ``def_abstract_eval`` on the inner primitive.
        """
        return NotImplemented

    @classmethod
    def outer_abstract(cls, *args, **kwargs):
        """
        Abstract-evaluation rule for the outer (wrapper) primitive.

        Defaults to `abstract`. Override it to drop workspace tensors that only the
        inner primitive exposes.
        """
        return cls.abstract(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def lowering():
        """
        MLIR lowering rule for the inner primitive.

        register_primitive registers it for the "cuda" platform.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def impl():
        """
        Implementation of the primitive.

        register_primitive wraps it with ``custom_partitioning`` to lower the outer
        primitive. It is also the default for `outer_impl`.
        """
        return NotImplemented

    @classmethod
    def outer_impl(cls, *args, **kwargs):
        """
        Eager implementation for the outer primitive; defaults to `impl`.
        """
        return cls.impl(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def batcher():
        """
        Batching rule for ``jax.vmap``, registered for the outer primitive.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def infer_sharding_from_operands():
        """
        ``infer_sharding_from_operands`` callback passed to ``custom_partitioning.def_partition``.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def partition():
        """
        ``partition`` callback passed to ``custom_partitioning.def_partition``.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def shardy_sharding_rule(*args):
        """
        Returns the Shardy sharding rule for this primitive.

        The result is passed as ``sharding_rule`` to ``custom_partitioning.def_partition``.
        The base version ignores its arguments and returns the pass-through rule
        "... -> ...". Because it is abstract, subclasses must still define it, but they
        may delegate here.
        """
        del args
        return "... -> ..."


# Registry to store all registered primitive classes, keyed by class name.
# Filled by register_primitive and read by manage_primitives.
_primitive_registry = {}


def register_primitive(cls):
    """
    Register a JAX primitive and add it to the internal registry.

    For the BasePrimitive subclass ``cls`` this:
      * records ``cls`` in ``_primitive_registry`` under ``cls.__name__``;
      * disables it by default if its name is in
        ``BasePrimitive._default_disable_names``;
      * builds ``cls.inner_primitive`` (named ``cls.name``), which uses
        ``cls.abstract`` for abstract evaluation and ``cls.lowering`` as its CUDA
        MLIR lowering;
      * builds ``cls.outer_primitive`` (named ``cls.name + "_wrapper"``), which uses
        ``cls.outer_impl`` for eager execution, ``cls.outer_abstract`` for abstract
        evaluation and ``cls.batcher`` for vmap. It is lowered by tracing
        ``cls.impl`` through ``custom_partitioning`` with the class's sharding
        callbacks.

    ``cls`` must also define ``multiple_results`` and ``impl_static_args``.

    Args:
        cls: The primitive class to register.

    Returns:
        ``cls`` itself, so this function can also be used as a class decorator.
    """
    _primitive_registry[cls.__name__] = cls

    # Set default disabled state at class level based on _default_disable_names
    if cls.__name__ in BasePrimitive._default_disable_names:
        cls.set_enabled(False)

    # Inner primitive: the actual custom call, lowered directly on CUDA.
    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform="cuda")
    cls.inner_primitive = inner_p

    # Outer primitive: wraps cls.impl so it can be batched and custom-partitioned.
    outer_p = core.Primitive(cls.name + "_wrapper")
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.outer_impl)
    outer_p.def_abstract_eval(cls.outer_abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(
        infer_sharding_from_operands=cls.infer_sharding_from_operands,
        partition=cls.partition,
        sharding_rule=cls.shardy_sharding_rule,
    )
    mlir.register_lowering(
        outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
    )
    cls.outer_primitive = outer_p

    return cls


def _register_ffi_targets():
    """Register every FFI target returned by ``transformer_engine_jax.registrations()`` for CUDA.

    Kept in a function so the loop variables do not leak into the module namespace.
    """
    for target_name, target in transformer_engine_jax.registrations().items():
        ffi.register_ffi_target(target_name, target, platform="CUDA")


# Runs at import time, as before, so the targets exist before any primitive lowers.
_register_ffi_targets()


def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False):
    """
    Helper function to manage primitive states by name without modifying environment variables.
    Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
    This helper is used in the get_quantize_config().initialize() methods.

    Args:
        enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
        disable_names: List of strings, each representing the name of a primitive class to disable. Defaults to None.
        disable_all_first: Boolean, if True, disables all primitives before applying enable/disable lists. Defaults to False.

    Raises:
        ValueError: If a name in `enable_names` or `disable_names` is not a registered
            BasePrimitive subclass. All names are validated before any state is changed,
            so a failing call leaves every primitive's enabled state untouched.

    Note:
        1. If `disable_all_first` is True, all primitives are disabled first, then `enable_names` is applied.
        2. Conflicts (a primitive in both enable and disable lists) are resolved by applying disable last.
        3. This only changes the internal `_is_enabled` state; environment variables read by
           `BasePrimitive.enabled()` still take precedence over it.
    """
    enable_set = set(enable_names or [])
    disable_set = set(disable_names or [])

    def _is_primitive_class(candidate):
        return isinstance(candidate, type) and issubclass(candidate, BasePrimitive)

    # Validate every requested name before mutating anything. Otherwise an unknown
    # name would raise after some primitives were already toggled, and because the
    # sets are unordered, which ones were toggled would be nondeterministic.
    for name in (*enable_set, *disable_set):
        if not _is_primitive_class(_primitive_registry.get(name)):
            raise ValueError(f"Primitive not found in registry: {name}")

    if disable_all_first:
        for cls in _primitive_registry.values():
            if _is_primitive_class(cls) and cls is not BasePrimitive:
                cls.set_enabled(False)

    # Apply enables
    for name in enable_set:
        _primitive_registry[name].set_enabled(True)

    # Apply disables (overrides enables if there's a conflict)
    for name in disable_set:
        _primitive_registry[name].set_enabled(False)