inspect.cpp 4.24 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cuda_runtime.h>

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "../extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
                      Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
                      Result_Type output_buf) {
  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
  NVTE_CHECK(output_buf->untyped_data() != nullptr,
             "Output must be provided for inspect operation");
  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
             "Input and output must point to the same buffer for inspect operation");

  std::vector<uint8_t> input_data(input_buf.size_bytes());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(),
                                  input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream));

  float min_val{}, max_val{}, mean_val{}, std_val{};
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));

  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));

  int device;
  NVTE_CHECK_CUDA(cudaGetDevice(&device));

  // Write the tensor data to a file as a binary blob
  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
  std::ofstream file(filename, std::ios::binary);
  NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
  file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
  file.close();

  // Write out a metadata file
  std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
  std::ofstream meta_file(meta_filename);
  NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename);
  meta_file << "{";
  meta_file << "\"shape\": [";
  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
    meta_file << input_buf.dimensions()[i];
    if (i < input_buf.dimensions().size() - 1) {
      meta_file << ", ";
    }
  }
  meta_file << "], ";
  meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
  meta_file << ", \"min\": " << min_val;
  meta_file << ", \"max\": " << max_val;
  meta_file << ", \"mean\": " << mean_val;
  meta_file << ", \"std\": " << std_val;
  meta_file << "}";
  meta_file.close();

  // Log the tensor metadata to the console
  printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
    printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
    if (i < input_buf.dimensions().size() - 1) {
      printf(", ");
    }
  }
  printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
  printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);

  return ffi_with_cuda_error_check();
}

// Defines the XLA FFI handler symbol `InspectHandler`, backed by InspectFFI.
// The bindings below are listed in the same order as InspectFFI's parameters:
// one context value (the stream), five buffer arguments, and one buffer result.
// Keep the two in sync when either signature changes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // min
                                  .Arg<Buffer_Type>()      // max
                                  .Arg<Buffer_Type>()      // mean
                                  .Arg<Buffer_Type>()      // std
                                  .Ret<Buffer_Type>()      // output (must alias input, see InspectFFI)
);

}  // namespace jax
}  // namespace transformer_engine