inspect.py 4.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Experimental JAX array inspection utilities."""

from functools import partial

import jax
import jax.numpy as jnp
from jax import ffi

from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive

# Public names of this module; everything else defined below is not exported by
# ``from ... import *``.
__all__ = ["inspect_array", "load_array_dump"]


class InspectPrimitive(BasePrimitive):
    """
    Pass-through primitive that forwards an array and four summary statistics to the
    ``te_inspect_ffi`` FFI target.

    The output has the same abstract value as the first operand (see ``abstract``), and
    the lowering aliases operand 0 to result 0. What the FFI target does with the
    operands is implemented outside this module and is not described here.
    """

    # FFI call target name handed to ``ffi.ffi_lowering`` in ``lowering``.
    name = "te_inspect_ffi"
    multiple_results = False
    impl_static_args = ()
    # NOTE(review): presumably filled in by ``register_primitive`` -- confirm in base module.
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, x_min_aval, x_max_aval, x_mean_aval, x_std_aval):
        """
        Abstract evaluation: check the statistic operands and describe the output.

        Every statistic must have shape ``()`` and dtype float32, otherwise an
        ``AssertionError`` naming the offending operand is raised. The returned abstract
        value is ``x_aval`` itself.
        """
        # Checked in the same order as the operands appear in the signature.
        stat_avals = (
            ("x_min", x_min_aval),
            ("x_max", x_max_aval),
            ("x_mean", x_mean_aval),
            ("x_std", x_std_aval),
        )
        for stat_name, stat_aval in stat_avals:
            assert (
                stat_aval.shape == () and stat_aval.dtype == jnp.float32
            ), f"{stat_name} must be a scalar with dtype float32"
        return x_aval

    @staticmethod
    def lowering(ctx, x, x_min, x_max, x_mean, x_std):
        """
        Lowering rule: emit an FFI call to the target named ``InspectPrimitive.name``
        with all five operands.
        """
        # Alias operand 0 to result 0, i.e. donate the input buffer to the output buffer.
        emit_ffi_call = ffi.ffi_lowering(
            InspectPrimitive.name,
            operand_output_aliases={0: 0},
        )
        return emit_ffi_call(ctx, x, x_min, x_max, x_mean, x_std)

    @staticmethod
    def impl(x, x_min, x_max, x_mean, x_std):
        """
        Implementation rule: bind the inner primitive with all five operands and return
        its result.
        """
        inner = InspectPrimitive.inner_primitive
        assert inner is not None
        return inner.bind(x, x_min, x_max, x_mean, x_std)


# Register the primitive at import time.
# NOTE(review): this is presumably what populates InspectPrimitive.inner_primitive and
# InspectPrimitive.outer_primitive, which the code below asserts are not None -- confirm
# against register_primitive in cpp_extensions/base.py.
register_primitive(InspectPrimitive)


def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
    """Bind the outer inspect primitive on ``x`` together with its summary statistics.

    The statistics are reductions over the whole array, each passed as a float32
    scalar: min and max are reduced in the input dtype and then cast, while mean and
    std are computed on a float32 copy of ``x``.

    Args:
        x (jnp.ndarray): The array to pass through the primitive.

    Returns:
        jnp.ndarray: The result of binding the primitive.

    Raises:
        AssertionError: If ``InspectPrimitive.outer_primitive`` has not been set.
    """
    outer = InspectPrimitive.outer_primitive
    assert outer is not None, (
        "InspectPrimitive FFI is not registered. Please ensure the C++ extension is properly built"
        " and registered."
    )

    x_f32 = x.astype(jnp.float32)
    x_min = jnp.min(x).astype(jnp.float32)
    x_max = jnp.max(x).astype(jnp.float32)
    x_mean = jnp.mean(x_f32)
    x_std = jnp.std(x_f32)
    return outer.bind(x, x_min, x_max, x_mean, x_std)


@partial(jax.custom_vjp, nondiff_argnums=())
def _inspect(x):
    """Pass ``x`` through the inspect primitive, with a custom VJP attached.

    The primal computation reuses ``_inspect_fwd_rule`` and discards the residuals it
    returns.
    """
    primal_out, _residuals = _inspect_fwd_rule(x)
    return primal_out


def _inspect_fwd_rule(x):
    """Forward rule for the custom VJP of ``_inspect``.

    Returns:
        A pair ``(output, residuals)`` where ``output`` is the result of
        ``_inspect_array_inner(x)`` and ``residuals`` is an empty tuple.
    """
    inspected = _inspect_array_inner(x)
    residuals = ()
    return inspected, residuals


def _inspect_bwd_rule(ctx, grad):
    """Backward rule for the custom VJP of ``_inspect``.

    The incoming cotangent is returned unchanged as the cotangent of the single input.

    Args:
        ctx: Residuals from the forward rule; unused.
        grad: Cotangent of the output.

    Returns:
        A 1-tuple containing ``grad``.
    """
    del ctx  # Nothing from the forward pass is needed.
    input_cotangents = (grad,)
    return input_cotangents


# Attach the forward and backward rules to the custom-VJP function ``_inspect``.
_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)


def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
    """Route a JAX array through the inspection primitive and return the result.

    Use the returned array in place of ``x`` in the surrounding computation.

    Args:
        x (jnp.ndarray): The JAX array to inspect.
        name (str): Identifier for the array. It is accepted but currently discarded,
            so it has no effect on the result.

    Returns:
        jnp.ndarray: The output of the inspection primitive for ``x``.
    """
    # Name is currently unused, but can be included in the future for more informative output
    del name
    inspected = _inspect(x)
    return inspected


def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:
    """Read a binary file and reinterpret its bytes as a JAX array.

    The entire file content is interpreted as consecutive elements of ``dtype`` and
    then reshaped, so the file must hold exactly the element data for ``shape``.

    Args:
        filename (str): The path to the binary file containing the array data.
        shape (tuple): The shape of the array to be loaded.
        dtype (jnp.dtype): The data type of the array to be loaded.

    Returns:
        jnp.ndarray: The loaded JAX array with the requested shape and dtype.
    """
    with open(filename, "rb") as dump_file:
        raw_bytes = dump_file.read()
    flat = jnp.frombuffer(raw_bytes, dtype=dtype)
    return jnp.reshape(flat, shape)