Unverified Commit 7e48fa1b authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Debugging inspect utility (#2651)



* initial debug of inspect ffi
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* writing binary dumps of tensors works
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* loading works
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add tensor statistics
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Add cuda error check and tests
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Add __init__.py to debug folder
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Address greptile comments
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Gate tests behind fp8 support
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent fa68781c
......@@ -1921,3 +1921,37 @@ class TestGroupedDense:
assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
class TestDebugInspectFFI:
    """End-to-end test of the experimental inspect FFI: a jitted function dumps an
    intermediate tensor to disk via ``inspect_array``; the dump is then reloaded
    with ``load_array_dump`` and compared against the expected intermediate value."""

    @pytest_parametrize_wrapper("shape", [(256, 128)])
    @pytest_parametrize_wrapper(
        "dtype",
        [
            jnp.float32,
            jnp.bfloat16,
            jnp.float16,
            # Note: fp4 currently doesn't work
            # jnp.float4_e2m1fn
        ]
        + ([jnp.float8_e4m3fn, jnp.float8_e5m2] if is_fp8_supported else []),
    )
    def test_debug_inspect_ffi(self, shape, dtype):
        from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump

        def bump_and_dump(val):
            # The dump happens between the two adds, so the file holds x + 1.
            val = inspect_array(val + 1, "my_array")
            return val + 1

        rng = jax.random.PRNGKey(0)
        x = jax.random.uniform(rng, shape, jnp.float32).astype(dtype)
        _ = jax.jit(bump_and_dump)(x)

        # inspect_array writes "my_tensor_gpu<device>.bin"; this test reads device 0's dump.
        actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)
        assert_allclose(actual, x + 1, dtype=dtype)
......@@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);
// Inspect
XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler);
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
......
......@@ -5,8 +5,6 @@
************************************************************************/
#include <cuda_runtime.h>
#include <iostream>
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <fstream>
#include <iostream>
#include "../extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
                      Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
                      Result_Type output_buf) {
  // Debug no-op that dumps the tensor contents plus pre-computed scalar statistics
  // (min/max/mean/std) to disk. The tensor itself passes through unchanged: the JAX
  // lowering aliases input 0 to output 0 (operand_output_aliases), hence the pointer
  // equality check below.
  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
  NVTE_CHECK(output_buf->untyped_data() != nullptr,
             "Output must be provided for inspect operation");
  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
             "Input and output must point to the same buffer for inspect operation");
  // Guard the statistics buffers too; they are dereferenced by the copies below.
  NVTE_CHECK(min_buf.untyped_data() != nullptr && max_buf.untyped_data() != nullptr &&
                 mean_buf.untyped_data() != nullptr && std_buf.untyped_data() != nullptr,
             "Statistics buffers must be provided for inspect operation");

  // Stage the tensor payload and the four scalar statistics on the host. The copies
  // are enqueued on `stream` and are only read after the synchronize below.
  std::vector<uint8_t> input_data(input_buf.size_bytes());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(),
                                  input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream));
  float min_val{}, max_val{}, mean_val{}, std_val{};
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));

  int device;
  NVTE_CHECK_CUDA(cudaGetDevice(&device));
  const auto &dims = input_buf.dimensions();
  const std::string file_stem = "my_tensor_gpu" + std::to_string(device);

  // Write the tensor data to a file as a binary blob, and verify the write succeeded.
  const std::string filename = file_stem + ".bin";
  std::ofstream file(filename, std::ios::binary);
  NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
  file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
  NVTE_CHECK(file.good(), "Failed to write tensor data to file: ", filename);
  file.close();

  // Write out a metadata file (shape, dtype enum value, and statistics) as JSON.
  const std::string meta_filename = file_stem + "_meta.json";
  std::ofstream meta_file(meta_filename);
  NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename);
  meta_file << "{";
  meta_file << "\"shape\": [";
  for (size_t i = 0; i < dims.size(); ++i) {
    meta_file << dims[i];
    if (i + 1 < dims.size()) {  // i+1 form avoids size()-1 underflow for empty dims
      meta_file << ", ";
    }
  }
  meta_file << "], ";
  meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
  meta_file << ", \"min\": " << min_val;
  meta_file << ", \"max\": " << max_val;
  meta_file << ", \"mean\": " << mean_val;
  meta_file << ", \"std\": " << std_val;
  meta_file << "}";
  NVTE_CHECK(meta_file.good(), "Failed to write metadata to file: ", meta_filename);
  meta_file.close();

  // Log the tensor metadata to the console.
  printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
  for (size_t i = 0; i < dims.size(); ++i) {
    printf("%zu", static_cast<size_t>(dims[i]));
    if (i + 1 < dims.size()) {
      printf(", ");
    }
  }
  printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
  printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);
  return ffi_with_cuda_error_check();
}
// Register InspectFFI with XLA's FFI machinery. The argument order (stream ctx,
// tensor, then the four scalar statistics, then the aliased output) must match the
// operand order passed by the JAX lowering of the te_inspect_ffi primitive.
XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // min
                                  .Arg<Buffer_Type>()      // max
                                  .Arg<Buffer_Type>()      // mean
                                  .Arg<Buffer_Type>()      // std
                                  .Ret<Buffer_Type>()      // output
);
} // namespace jax
} // namespace transformer_engine
......@@ -81,6 +81,9 @@ pybind11::dict Registrations() {
pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));
dict["te_inspect_ffi"] =
pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));
return dict;
}
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.
This API is experimental and may change or be removed without deprecation in future releases.
"""
__all__ = [
"experimental",
]
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.
This API is experimental and may change or be removed without deprecation in future releases.
"""
from .inspect import inspect_array, load_array_dump
__all__ = [
"inspect_array",
"load_array_dump",
]
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Experimental JAX array inspection utilities."""
from functools import partial
import jax
import jax.numpy as jnp
from jax import ffi
from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
__all__ = ["inspect_array", "load_array_dump"]
class InspectPrimitive(BasePrimitive):
    """No-op primitive used to inspect array values.

    The array passes through unchanged; the C++ FFI handler dumps its contents and
    the pre-computed scalar statistics (min/max/mean/std) to disk as a side effect.
    """

    name = "te_inspect_ffi"
    multiple_results = False
    impl_static_args = ()
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        x_min_aval,
        x_max_aval,
        x_mean_aval,
        x_std_aval,
    ):
        """Abstract rule: the output aval equals the input aval.

        Validates that the four statistics avals are float32 scalars.
        """
        # Consolidated validation of the four statistics instead of four
        # copy-pasted assert blocks; the error messages are unchanged.
        for aval, label in (
            (x_min_aval, "x_min"),
            (x_max_aval, "x_max"),
            (x_mean_aval, "x_mean"),
            (x_std_aval, "x_std"),
        ):
            assert (
                aval.shape == () and aval.dtype == jnp.float32
            ), f"{label} must be a scalar with dtype float32"
        return x_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        x_min,
        x_max,
        x_mean,
        x_std,
    ):
        """Lowering rule: emit the te_inspect_ffi custom call.

        Input 0 is aliased to output 0 so the tensor buffer is not copied.
        """
        return ffi.ffi_lowering(
            InspectPrimitive.name,
            operand_output_aliases={0: 0},  # donate input buffer to output buffer
        )(
            ctx,
            x,
            x_min,
            x_max,
            x_mean,
            x_std,
        )

    @staticmethod
    def impl(
        x,
        x_min,
        x_max,
        x_mean,
        x_std,
    ):
        """Implementation: bind the inner primitive and return the (unchanged) array."""
        assert InspectPrimitive.inner_primitive is not None
        # multiple_results is False, so bind returns the array directly.
        # (The original `(x) = ...` parentheses looked like tuple unpacking
        # but were a no-op; plain assignment is the intended semantics.)
        return InspectPrimitive.inner_primitive.bind(
            x,
            x_min,
            x_max,
            x_mean,
            x_std,
        )
# Register the primitive with TE's JAX plumbing; per the asserts in this module,
# registration is what populates InspectPrimitive.inner_primitive/outer_primitive.
register_primitive(InspectPrimitive)
def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
    """Bind the inspect primitive on ``x`` together with its scalar statistics.

    min/max are computed in the array's own dtype and then cast to float32;
    mean/std are computed after casting to float32 to avoid low-precision accumulation.
    """
    assert InspectPrimitive.outer_primitive is not None, (
        "InspectPrimitive FFI is not registered. Please ensure the C++ extension is properly built"
        " and registered."
    )
    x_f32 = x.astype(jnp.float32)
    stats = (
        jnp.min(x).astype(jnp.float32),
        jnp.max(x).astype(jnp.float32),
        jnp.mean(x_f32),
        jnp.std(x_f32),
    )
    return InspectPrimitive.outer_primitive.bind(x, *stats)
@partial(jax.custom_vjp, nondiff_argnums=())
def _inspect(
    x,
):
    """Differentiable wrapper around the inspect primitive (identity on values)."""
    result, _ = _inspect_fwd_rule(
        x,
    )
    return result
def _inspect_fwd_rule(
    x,
):
    """Forward rule: run the inspect primitive; no residuals are saved."""
    return _inspect_array_inner(x), ()
def _inspect_bwd_rule(
ctx,
grad,
):
""""""
del ctx
return (grad,)
# Wire the custom VJP so autodiff treats inspect as identity (gradient passes through).
_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
    """Inspect a JAX array: dump its contents and statistics as a jit-safe side effect.

    Args:
        x (jnp.ndarray): The JAX array to inspect.
        name (str): The name of the array for identification in the output.

    Returns:
        jnp.ndarray: ``x`` unchanged (identity, differentiable).
    """
    del name  # currently unused; reserved for more informative output later
    return _inspect(x)
def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:
    """Load a JAX array previously dumped as a raw binary blob.

    Args:
        filename (str): The path to the binary file containing the array data.
        shape (tuple): The shape of the array to be loaded.
        dtype (jnp.dtype): The data type of the array to be loaded.

    Returns:
        jnp.ndarray: The loaded JAX array.
    """
    with open(filename, "rb") as dump_file:
        raw = dump_file.read()
    # Reinterpret the flat byte buffer as the requested dtype, then restore the shape.
    return jnp.frombuffer(raw, dtype=dtype).reshape(shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment