/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include "extensions/ffi.h"

#include <iostream>

namespace transformer_engine {
namespace jax {

// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
/// @brief Map an XLA FFI element type onto the corresponding Transformer Engine DType.
///
/// @param type  XLA FFI element type to convert.
/// @return      The matching TE DType.
///
/// Any element type not listed in the switch is rejected through NVTE_ERROR.
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
  switch (type) {
    // U8 is deliberately mapped to E8M0 (original note: "Using this for E8M0"),
    // in addition to the native F8E8M0FNU mapping below.
    // NOTE(review): every U8 buffer is therefore treated as E8M0 data — confirm
    // that no caller passes genuine uint8 tensors through this path.
    case xla::ffi::DataType::U8:
      return DType::kFloat8E8M0;
    case xla::ffi::DataType::S32:
      return DType::kInt32;
    case xla::ffi::DataType::S64:
      return DType::kInt64;
    case xla::ffi::DataType::F32:
      return DType::kFloat32;
    case xla::ffi::DataType::F16:
      return DType::kFloat16;
    case xla::ffi::DataType::BF16:
      return DType::kBFloat16;
    case xla::ffi::DataType::F8E5M2:
      return DType::kFloat8E5M2;
    case xla::ffi::DataType::F8E4M3FN:
      return DType::kFloat8E4M3;
    case xla::ffi::DataType::F8E8M0FNU:
      return DType::kFloat8E8M0;
    case xla::ffi::DataType::F4E2M1FN:
      return DType::kFloat4E2M1;
    default: {
      // Report the raw C-API enum value so it can be looked up in XLA's c_api.h.
      const auto type_num = static_cast<XLA_FFI_DataType>(type);
      // NOTE(review): NVTE_ERROR is assumed to build its message by concatenating
      // its arguments (not printf-style), so the value is passed as a separate
      // argument instead of through a "%d" placeholder. It is also assumed not to
      // return (presumably it throws). Otherwise control would fall off the end of
      // this non-void function.
      NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType ",
                 static_cast<int>(type_num));
    }
  }
}

/// @brief Convert the CUDA runtime's last recorded error into an XLA FFI status.
///
/// @return Error_Type::Success() when cudaGetLastError() reports cudaSuccess.
///         Otherwise returns an XLA_FFI_Error_Code_INTERNAL error whose message is
///         "CUDA error: " followed by cudaGetErrorString() for the reported code.
///
/// Per the CUDA Runtime API documentation, cudaGetLastError() also resets the
/// (non-sticky) last-error state, so calling this consumes the pending error.
Error_Type ffi_with_cuda_error_check() {
  const cudaError_t status = cudaGetLastError();
  if (status == cudaSuccess) {
    return Error_Type::Success();
  }
  return Error_Type(XLA_FFI_Error_Code_INTERNAL,
                    std::string{"CUDA error: "} + cudaGetErrorString(status));
}

}  // namespace jax
}  // namespace transformer_engine