ffi.cpp 1.79 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "extensions/ffi.h"

#include <iostream>

#include "common/util/logging.h"

namespace transformer_engine {
namespace jax {

// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
  switch (type) {
18
19
20
21
22
23
24
25
    case xla::ffi::DataType::U8:
      return DType::kByte;
      break;
    case xla::ffi::DataType::S32:
      return DType::kInt32;
      break;
    case xla::ffi::DataType::S64:
      return DType::kInt64;
26
27
28
29
      break;
    case xla::ffi::DataType::F32:
      return DType::kFloat32;
      break;
30
31
32
    case xla::ffi::DataType::F16:
      return DType::kFloat16;
      break;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
    case xla::ffi::DataType::BF16:
      return DType::kBFloat16;
      break;
    case xla::ffi::DataType::F8E5M2:
      return DType::kFloat8E5M2;
      break;
    case xla::ffi::DataType::F8E4M3FN:
      return DType::kFloat8E4M3;
      break;
    default:
      auto type_num = static_cast<XLA_FFI_DataType>(type);
      NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
                 static_cast<int>(type_num));
      break;
  }
}

Error_Type ffi_with_cuda_error_check() {
  cudaError_t last_error = cudaGetLastError();
  if (last_error != cudaSuccess) {
    return Error_Type(XLA_FFI_Error_Code_INTERNAL,
                      std::string("CUDA error: ") + cudaGetErrorString(last_error));
  }
  return Error_Type::Success();
}

}  // namespace jax
}  // namespace transformer_engine