# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""API definition for nvidia-dlframework-inspect."""

import copy
import warnings
from typing import Dict, Union, Tuple, Optional

import torch

from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper
from nvdlfw_inspect.registry import Registry

from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.tensor import (
    QuantizedTensor,
    Quantizer,
    get_all_tensor_types,
)
class TEConfigAPIMapper(BaseConfigAPIMapper):
    """Class responsible for determining which NV DLFW Inspect API should be run for each tensor and gemm."""

    def parse_config_and_api(self, config, **kwargs):
        """Process the config and returns True if the config and api args match, along with processed config.

        ``gemm_parsing`` / ``tensor_parsing`` in *kwargs* select the parsing mode;
        ``gemm`` and ``tensor_name`` identify the API call being routed.
        Returns ``(False, None)`` if the config does not apply to this call.
        """
        processed_config = None
        # Deep-copy so that the caller's config dict is never mutated by the
        # in-place processing done by the helpers below.
        config_copy = copy.deepcopy(config)
        gemm_parsing = kwargs.get("gemm_parsing", False)
        tensor_parsing = kwargs.get("tensor_parsing", False)

        if gemm_parsing:
            # parse with GEMM and/or tensor
            processed_config = self._process_transformer_engine_config(config_copy, **kwargs)
        elif tensor_parsing:
            # parse with only tensor
            processed_config = self._process_tensor_config(config_copy, kwargs["tensor_name"])

        if not processed_config:
            return False, None

        # "enabled" is a routing-only flag; strip it from the returned config.
        if "enabled" in processed_config:
            processed_config.pop("enabled")
        return True, processed_config

    def _validate_gemm(self, gemm):
        """Assert that *gemm* is one of the supported GEMM names."""
        assert gemm in ["fprop", "wgrad", "dgrad"], (
            f"[NVTORCH INSPECT ERROR] Invalid gemm: {gemm}. It must be one of the ['fprop',"
            " 'wgrad', 'dgrad']."
        )

    def _process_transformer_engine_config(self, config, **kwargs):
        """
        Return config specific to a particular tensor name and gemm that matches the api args.

        Supports two mutually exclusive config layouts:
        - ``gemms_struct``: a list of per-GEMM dicts (each may carry its own tensor config);
        - ``gemms``: a flat list of GEMM names sharing the surrounding config.
        Returns ``None`` when the requested gemm/tensor is not covered by the config.
        """
        if "gemms_struct" in config:
            for cfg in config["gemms_struct"]:
                self._validate_gemm(cfg["gemm"])
                if cfg["gemm"] == kwargs["gemm"]:
                    if kwargs["tensor_parsing"]:
                        cfg = self._process_tensor_config(cfg, kwargs["tensor_name"])
                        if not cfg:
                            return None
                    # Flatten the matched per-GEMM entry into the top-level config.
                    cfg_copy = copy.deepcopy(cfg)
                    config.pop("gemms_struct")
                    assert (
                        "enabled" not in cfg_copy
                    ), "[NVTORCH INSPECT ERROR] Enabled field should not be part of gemms_struct"
                    config.update(cfg_copy)
                    return config
            return None
        if "gemms" in config:
            for gemm in config["gemms"]:
                self._validate_gemm(gemm)
            if kwargs["gemm"] in config["gemms"]:
                if kwargs["tensor_parsing"]:
                    # NOTE: _process_tensor_config mutates `config` in place, so the
                    # subsequent use of `config` reflects the tensor filtering.
                    cfg = self._process_tensor_config(config, kwargs["tensor_name"])
                    if not cfg:
                        return None
                config["gemm"] = kwargs["gemm"]
                config.pop("gemms")
                return config
            return None
        raise ValueError(
            "[NVTORCH INSPECT ERROR] Provide 'gemms_struct: List[Dict]' or 'gemms: List[str]'"
            " in the config yaml"
        )


# Keyword arguments that must accompany each TE API call so it can be routed
# unambiguously; "default" applies to any API name not listed explicitly.
required_kwargs = {
    "fp8_gemm_enabled": ["gemm"],
    "modify_tensor_enabled": ["tensor_name", "gemm"],
    "modify_tensor": ["tensor_name", "gemm"],
    "inspect_tensor": ["tensor_name"],
    "inspect_tensor_postquantize": ["tensor_name"],
    "inspect_tensor_enabled": ["tensor_name"],
    "inspect_tensor_postquantize_enabled": ["tensor_name"],
    "default": ["tensor_name", "gemm"],
}


# pylint: disable=unused-argument
class TEDefaultFeatures:
    """Transformer Engine API calls default behavior.

    Each method is the fallback implementation invoked when no feature from the
    config overrides the corresponding API for a given layer/tensor/gemm.
    """

    def fp8_gemm_enabled(
        self,
        config: Dict,
        layer_name: str,
        gemm: str,
        iteration: int,
    ) -> bool | Tuple[bool, Optional[int]]:
        """
        If the tensor is not processed using *modify_tensor* and the fp8 recipe is enabled,
        then the decision whether to cast it to fp8 is based on the value returned by the call *fp8_gemm_enabled*.
        If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled,
        the result of this call does not matter.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be disabled.
        It can return (bool, None) if the feature will never be enabled for that layer and gemm.
        Returning the next enabled iteration can help optimize CPU usage.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        gemm: str
            one of [`fprop`, `dgrad`, `wgrad`],
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.

        Returns
        -------

        Union[bool, Tuple[bool, Optional[int]]] - default is (True, None)
        """
        return True, None  # if it is false, fp8_gemm will be turned off. Otherwise nothing happens.

    def modify_tensor_enabled(
        self,
        config: Dict,
        layer_name: str,
        gemm: str,
        tensor_name: str,
        iteration: int,
    ) -> bool | Tuple[bool, Optional[int]]:
        """
        It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name.
        It has **higher priority** than fp8_gemm; if *modify_tensor_enabled* returns True or (True, next_enabled_iter),
        then modify_tensor call is invoked for the respective tensor no matter what.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
        It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor.
        Returning the next enabled iteration can help optimize CPU usage, especially when the interval between modify_tensor is large.
        Returning only a bool is deprecated.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        gemm: str
            one of [`fprop`, `dgrad`, `wgrad`],
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.

        Returns
        -------

        Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
        """
        return False, None

    def modify_tensor(
        self,
        config: Dict,
        layer_name: str,
        gemm: str,
        tensor_name: str,
        tensor: torch.Tensor,
        default_quantizer: Quantizer,
        iteration: int,
        out: Union[torch.Tensor, QuantizedTensor],
    ) -> torch.Tensor | QuantizedTensor | None:
        """
        It allows tensor modification.
        For example, feature `FakeQuant` uses it to emulate casting to FP8.
        It can be invoked at most once for each tensor within a given GEMM operation.

        This call is invoked if `modify_tensor_enabled` returns `True` and the feature is enabled for the *tensor_name* and *gemm*.
        Then it is called **instead of** the default quantization.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        tensor: torch.Tensor
            tensor in high precision,
        gemm: str
            one of [`fprop`, `dgrad`, `wgrad`],
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        default_quantizer : Quantizer
            quantizer which is used to cast the tensor to lower precision
            if *modify_tensor* is not invoked. For example,
            feature per tensor scale uses it to obtain FP8 dtype of the tensor.
            If the recipe indicates that the tensor is not cast - for example,
            if running without FP8 autocast, then `default_quantizer=None`,
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.
        out: Union[torch.Tensor, QuantizedTensor]
            output tensor, used in the weight caching mechanism.


        Returns
        -------

        Union[torch.Tensor, transformer_engine.pytorch.QuantizerTensor, None]
            can be `torch.Tensor` or one of the Transformer Engine's `QuantizedTensor` -
            the rule is that both tensors returned for each GEMM should have the same type.
            If both are `Float8Tensor`, then GEMM is run in FP8.
            If both are `torch.Tensor`, GEMM is run in high precision.
            Please take that into account especially if only one tensor of the GEMM
            is processed by the `modify_tensor()`. For example, `FakeQuant`
            disabled FP8 GEMM to ensure that the second tensor is also in high precision.
            If the tensor is not the input for any GEMM - namely  `output`,
            `wgrad` and `dgrad` - the return type would match the input type.
        Should return `None` if `out` is not `None`.

        """
        # Default implementation is unreachable by design: routing only calls
        # modify_tensor when a feature claimed it via modify_tensor_enabled.
        raise NotImplementedError(
            "modify_tensor_enabled() returned True, modify_tensor() was invoked, but it is not"
            " handled by any API."
        )

    def inspect_tensor(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        tensor: torch.Tensor,
        rowwise_quantized_tensor: Optional[torch.Tensor],
        columnwise_quantized_tensor: Optional[torch.Tensor],
        quantizer: Optional[Quantizer],
        iteration: int,
        tp_group: torch.distributed.ProcessGroup,
    ) -> None:
        """
        The feature is invoked if *inspect_tensor_enabled* returns `True`. It can be used to obtain information on the high precision tensor. For example, it is run by the `LogTensorStats` feature.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        tensor: torch.Tensor
            tensor in high precision,
        rowwise_quantized_tensor: Optional[torch.Tensor]
            rowwise quantized tensor,
        columnwise_quantized_tensor: Optional[torch.Tensor]
            columnwise quantized tensor,
        quantizer: Optional[Quantizer]
            quantizer,
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.
        tp_group: torch.distributed.ProcessGroup
            process group for the tensor parallel group. This is used for weight statistics reduction.
            This is not reduction group from debug_api.

        Returns
        -------

        Should return nothing.
        """

    def inspect_tensor_postquantize(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        tensor: torch.Tensor,
        iteration: int,
        tp_group: torch.distributed.ProcessGroup,
        rowwise: bool,
    ) -> None:
        """
        This is deprecated call, we advise to use *inspect_tensor* instead.

        Similar to *inspect_tensor*, but is run after one of the: fp8 cast, modify_tensor if they are run. If none of the fp8 cast or modify_tensor is invoked, then *inspect_tensor_postquantize* is also not invoked. The feature LogFp8Stats uses this call to collect FP8 statistics after the quantization.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        tensor: torch.Tensor
            tensor in fp8 or processed tensor after the modify_tensor call,
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.
        tp_group: torch.distributed.ProcessGroup
            process group for the tensor parallel group. This is used for weight statistics reduction.
            This is not reduction group from debug_api.

        Returns
        -------

        Should return nothing.
        """

    def inspect_tensor_enabled(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        iteration: int,
    ) -> bool | Tuple[bool, Optional[int]]:
        """
        It is a routing call, which is run at the initialization of the layer.
        Determines if *inspect_tensor* for a given GEMM and tensor will be invoked.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
        It can return (bool, None) if the feature will never be enabled for that layer and tensor.
        Returning the next enabled iteration can help optimize CPU usage, especially when the interval between inspect_tensor is large.
        Returning only a bool is deprecated.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`].
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.

        Returns
        -------

        Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
        """
        return False, None

    def inspect_tensor_postquantize_enabled(
        self,
        config: Dict,
        layer_name: str,
        gemm: str,
        tensor_name: str,
        iteration: int,
    ) -> bool | Tuple[bool, Optional[int]]:
        """
        This is deprecated call, we advise to use *inspect_tensor* and *inspect_tensor_enabled* instead.

        It is a routing call, which is run at the initialization of the layer.
        Determines if *inspect_tensor_postquantize* for a given GEMM and tensor will be invoked.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
        It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor name.
        Returning the next enabled iteration can help optimize CPU usage,
        especially when the interval between inspect_tensor_postquantize is large.
        Returning only a bool is deprecated.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        gemm: str
            one of [`fprop`, `dgrad`, `wgrad`],
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.

        Returns
        -------

        Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
        """
        return False, None


@Registry.register_namespace_api(namespace="transformer_engine")
class TransformerEngineAPI(BaseNamespaceAPI):
    """
    Transformer Engine API class that contains default APIs that are invoked when a config is not provided
    or a layer is not selected in the config.
    TransformerEngine specific features must override these APIs wherever required.
    The overridden APIs will be invoked whenever the corresponding feature is enabled in the config.
    """

    def __init__(self):
        BaseNamespaceAPI.__init__(self)
        # Fallback implementations used when no feature overrides an API call.
        self._default_api_impl = TEDefaultFeatures()
        # Kwargs that participate in the routing-cache key for each API call.
        self._cacheable_api_kwargs_map = {
            "fp8_gemm": ["gemm"],
            "modify_tensor": ["tensor_name", "gemm"],
            "inspect_tensor": ["tensor_name"],
            "inspect_tensor_postquantize": ["tensor_name"],
            "inspect_tensor_enabled": ["tensor_name", "iteration"],
            "inspect_tensor_postquantize_enabled": ["tensor_name", "iteration"],
            "modify_tensor_enabled": ["tensor_name"],
        }

    def is_multiple_feature_invocation_allowed(self, api_name):
        """
        Check if API allows executing multiple features for a single call
        """
        return api_name in {
            "fp8_gemm_enabled",
            "inspect_tensor",
            "inspect_tensor_postquantize",
            "inspect_tensor_enabled",
            "inspect_tensor_postquantize_enabled",
        }

    def input_assertions_hook(self, api_name, **kwargs):
        """
        These args must be passed as kwargs in the API call for all TransformerEngine specific APIs.
        """
        # Unknown API names fall back to the "default" requirement list.
        for kwarg in required_kwargs.get(api_name, required_kwargs["default"]):
            assert kwarg in kwargs, (
                f"[NVTORCH INSPECT ERROR] Cannot route API, too ambiguous. Provide {kwarg} in"
                f" {api_name}."
            )

    def routing_condition(self, api_name, config, _, feature_obj, **kwargs):
        """
        Overridden APIs are selected based on the GEMM name in the config and kwargs.
        """
        tensor_parsing = "tensor_name" in required_kwargs[api_name]
        gemm_parsing = "gemm" in required_kwargs[api_name]
        status, modified_config = feature_obj.parse_config_and_api(
            config, gemm_parsing=gemm_parsing, tensor_parsing=tensor_parsing, **kwargs
        )
        return status, modified_config

    def output_assertions_hook(self, api_name, ret, **kwargs):
        """Output hooks used to check correctness of the outputs of the API calls."""
        if "enabled" in api_name or api_name == "fp8_gemm":
            # *_enabled calls may return bool or (bool, Optional[int]).
            assert isinstance(ret, (bool, tuple))
        if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
            assert ret is None
        if api_name == "modify_tensor":
            assert type(ret) in get_all_tensor_types()
            if (
                type(ret) == torch.Tensor  # pylint: disable=unidiomatic-typecheck
                and "dtype" in kwargs
            ):
                if kwargs["dtype"] is not None:
                    assert ret.dtype == kwargs["dtype"]

    def call_feature(self, call, feat_config, layer_name, **kwargs):
        """
        For backward compatibility, remove kwargs that are not needed for the call
        """
        if call.__name__ == "inspect_tensor":
            # Older feature implementations of inspect_tensor may not accept the
            # quantizer-related kwargs; drop the ones the callee does not declare.
            kwargs_copy = kwargs.copy()
            for k in ["quantizer", "columnwise_quantized_tensor", "rowwise_quantized_tensor"]:
                if k not in call.__code__.co_varnames:
                    kwargs_copy.pop(k)
        else:
            kwargs_copy = kwargs

        if call.__name__ == "inspect_tensor_postquantize":
            warnings.warn(
                "inspect_tensor_postquantize is deprecated, use inspect_tensor instead.",
                DeprecationWarning,
            )

        return call(feat_config, layer_name, **kwargs_copy)

    def handle_multi_feature_output(
        self, api_name, multi_feature_outputs, features_to_invoke, **kwargs
    ):
        """
        Handle multi-tensor output of the API calls.
        """
        if "enabled" in api_name:
            # *_enabled feature calls can return bool, or tuple (bool, Optional[int]).
            # If any of them returns bool, then we return bool - this means that we cannot state anything
            # about enablement in the next steps.
            # If all of them return a tuple (bool, Optional[int]), we return the minimum value,
            # representing the number of steps after the feature will be enabled next time.
            # If the second value is None, that means that the feature will never be enabled.
            all_ret_tuple = all(
                isinstance(feature_output, tuple) for feature_output in multi_feature_outputs
            )
            if all_ret_tuple:
                run_current = any(feature_output[0] for feature_output in multi_feature_outputs)
                next_iter = None
                for feature_output in multi_feature_outputs:
                    if next_iter is None:
                        next_iter = feature_output[1]
                    elif feature_output[1] is not None:
                        next_iter = min(next_iter, feature_output[1])
                return run_current, next_iter
            run_current = any(feature_output for feature_output in multi_feature_outputs)
            return run_current, None
        return super().handle_multi_feature_output(
            api_name, multi_feature_outputs, features_to_invoke, **kwargs
        )

    def step(self):
        """This function is called by the nvidia-dlframework-inspect after every debug_api.step()"""
        STATS_BUFFERS.log_stats()

    def end_debug(self):
        """This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()"""
        TEDebugState._reset()