/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>

#include "../extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

// Copies the tensor and its precomputed statistics from device to host, dumps the raw tensor
// bytes to a per-GPU binary file alongside a JSON metadata file, and logs a summary to stdout.
Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
                      Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
                      Result_Type output_buf) {
  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
  NVTE_CHECK(output_buf->untyped_data() != nullptr,
             "Output must be provided for inspect operation");
  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
             "Input and output must point to the same buffer for inspect operation");

  // Copy the raw tensor contents from device to host memory.
  std::vector<uint8_t> input_data(input_buf.size_bytes());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(),
                                  input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream));

  // Copy the precomputed scalar statistics from device to host memory.
  float min_val{}, max_val{}, mean_val{}, std_val{};
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));

  int device;
  NVTE_CHECK_CUDA(cudaGetDevice(&device));

  // Write the tensor data to a file as a binary blob
  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
  std::ofstream file(filename, std::ios::binary);
  NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
  file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
  file.close();

  // Write out a metadata file
  std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
  std::ofstream meta_file(meta_filename);
  NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename);
  meta_file << "{";
  meta_file << "\"shape\": [";
  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
    meta_file << input_buf.dimensions()[i];
    if (i < input_buf.dimensions().size() - 1) {
      meta_file << ", ";
    }
  }
  meta_file << "], ";
  meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
  meta_file << ", \"min\": " << min_val;
  meta_file << ", \"max\": " << max_val;
  meta_file << ", \"mean\": " << mean_val;
  meta_file << ", \"std\": " << std_val;
  meta_file << "}";
  meta_file.close();

  // Log the tensor metadata to the console
  printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
    printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
    if (i < input_buf.dimensions().size() - 1) {
      printf(", ");
    }
  }
  printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
  printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // min
                                  .Arg<Buffer_Type>()      // max
                                  .Arg<Buffer_Type>()      // mean
                                  .Arg<Buffer_Type>()      // std
                                  .Ret<Buffer_Type>()      // output
);

}  // namespace jax
}  // namespace transformer_engine