misc.h 2.93 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

namespace transformer_engine {
namespace jax {

// Capacity of the fixed-size `dims` array inside `Shape`.
constexpr int kMaxNumDim = 8;

/*! \brief Plain aggregate holding a dimension count and up to kMaxNumDim extents.
 *
 *  The two member functions are only declared here; their definitions are not
 *  part of this header.
 */
struct Shape {
  // NOTE(review): presumably the number of entries of `dims` that are in use — confirm
  // against the definition of from_vector().
  int num_dim;
  // Storage for the extents; holds at most kMaxNumDim values.
  size_t dims[kMaxNumDim];

  // NOTE(review): presumably fills num_dim/dims from `shape`; how a vector longer than
  // kMaxNumDim is handled cannot be seen from this header — verify in the definition.
  void from_vector(const std::vector<size_t> &shape);

  // NOTE(review): presumably the inverse of from_vector() — confirm in the definition.
  std::vector<size_t> to_vector() const;
};

// Builds a std::vector<size_t> from an NVTEShape (taken by value). Declared here only;
// NOTE(review): presumably returns the shape's extents in order — confirm in the definition.
std::vector<size_t> MakeShapeVector(NVTEShape shape);

29
30
31
32
33
34
35
36
/*! \brief Multiply together every entry of a shape vector.
 *
 *  \param[in] shape Values to multiply.
 *  \return The product of all entries; 1 when `shape` is empty.
 *          The multiplication is plain size_t arithmetic, so it wraps on overflow.
 */
inline size_t product(const std::vector<size_t> &shape) {
  size_t total = 1;
  for (auto it = shape.cbegin(); it != shape.cend(); ++it) {
    total *= *it;
  }
  return total;
}

37
/*! \brief Layout selector used by the quantization code.
 *
 *  The enumerators rely on implicit numbering: ROWWISE = 0, COLWISE = 1,
 *  ROWWISE_COLWISE = 2.
 *  NOTE(review): presumably chooses whether row-wise data, column-wise data, or both
 *  are produced — confirm against the callers.
 */
enum class QuantizeLayout {
  ROWWISE,
  COLWISE,
  ROWWISE_COLWISE,
};

43
44
45
46
/*! \brief Scaling modes understood by the JAX extension, stored as int64_t.
 *
 *  Every enumerator has an explicit value.
 *  NOTE(review): the explicit values presumably have to match an enum on the
 *  Python side — confirm before renumbering.
 */
enum class JAXX_Scaling_Mode : int64_t {
  NO_SCALING = 0,
  DELAYED_TENSOR_SCALING = 1,
  MXFP8_1D_SCALING = 2,
  CURRENT_TENSOR_SCALING = 3,
};

Alp Dener's avatar
Alp Dener committed
50
51
52
53
54
55
56
57
58
/*! \brief Tell whether a scaling mode is one of the two tensor-scaling modes.
 *
 *  \param[in] mode Scaling mode to classify.
 *  \return true for DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING,
 *          false for every other value.
 */
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
    case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
      return true;
    default:
      return false;
  }
}

/*! \brief Tell whether a scaling mode is the block-scaling mode.
 *
 *  \param[in] mode Scaling mode to classify.
 *  \return true only for MXFP8_1D_SCALING.
 */
inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
  const bool uses_mxfp8_blocks = (JAXX_Scaling_Mode::MXFP8_1D_SCALING == mode);
  return uses_mxfp8_blocks;
}

59
60
61
62
63
64
65
66
67
68
69
/*! \brief Translate a JAXX_Scaling_Mode into the corresponding NVTEScalingMode.
 *
 *  Mapping:
 *    - NO_SCALING, DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING
 *        -> NVTE_DELAYED_TENSOR_SCALING
 *    - MXFP8_1D_SCALING -> NVTE_MXFP8_1D_SCALING
 *
 *  Declared `inline` rather than `static` because it is defined in a header: a
 *  `static` definition gives every including translation unit its own private
 *  copy and triggers unused-function warnings where it is not called.
 *
 *  \param[in] mode Scaling mode coming from the JAX side.
 *  \return The NVTE scaling mode listed above.
 *
 *  Any other value is reported through NVTE_ERROR together with its integer value.
 */
inline NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    // NOTE(review): three JAX-side modes share one NVTE mode; presumably NVTE has no
    // separate value for "no scaling" / "current scaling" — confirm against NVTEScalingMode.
    case JAXX_Scaling_Mode::NO_SCALING:
    case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
    case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
      return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
    case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
      return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
    default:
      NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
  }
}

79
80
81
82
83
84
85
86
87
88
89
/*! \brief Pair of block extents named x and y. */
struct BlockSize {
  size_t x;
  size_t y;
};
// MXFP8 block size: x = 1, y = 32.
constexpr BlockSize MXFP8_BLOCK_SIZE{1, 32};
/*! \brief Pair of alignment values named x and y. */
struct Alignment {
  size_t x;
  size_t y;
};
// MXFP8 alignment: x = 128, y = 4.
constexpr Alignment MXFP8_ALIGNMENT{128, 4};

// Returns a shape vector computed from M, N and the is_colwise flag. Declared here only;
// NOTE(review): presumably the shape of the MXFP8 scale tensor for an M x N input, derived
// from MXFP8_BLOCK_SIZE and MXFP8_ALIGNMENT — confirm against the definition.
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);

Phuong Nguyen's avatar
Phuong Nguyen committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/*! \brief Fold the hashes of one or more values into a running seed.
 *
 *  Each value is hashed with std::hash and mixed into `seed` with the formula
 *  `seed ^= hash + 0x9e3779b9 + (seed << 6) + (seed >> 2)`, the same mixing step
 *  used by boost::hash_combine. Values are processed left to right, so
 *  `hash_combine(s, a, b)` gives the same seed as `hash_combine(s, a)` followed
 *  by `hash_combine(s, b)`.
 *
 *  \tparam T    Type of the first value; std::hash<T> must be usable.
 *  \tparam Rest Types of the remaining values (taken by value, as before).
 *  \param[in,out] seed Running hash, updated in place.
 *  \param[in]     v    First value to mix in.
 *  \param[in]     rest Further values to mix in, in order.
 */
template <typename T, typename... Rest>
void hash_combine(int64_t &seed, const T &v, Rest... rest) {
  // Left-shifting a negative signed value is undefined behaviour before C++20, and the
  // seed is negative whenever an earlier step set its top bit. Do the left shift and the
  // additions on the unsigned bit pattern instead.
  const auto seed_bits = static_cast<uint64_t>(seed);
  // The right shift stays on the signed seed, as before, so the resulting bit pattern is
  // unchanged on two's-complement targets.
  const uint64_t mix = std::hash<T>{}(v) + 0x9e3779b9 + (seed_bits << 6) +
                       static_cast<uint64_t>(seed >> 2);
  seed = static_cast<int64_t>(seed_bits ^ mix);
  (hash_combine(seed, rest), ...);
}

// Collective operations selectable from the JAX side, stored as int64_t with explicit values.
// NOTE(review): the explicit values presumably have to match an enum on the Python side —
// confirm before renumbering.
enum class JAXX_Collective_Op : int64_t {
  NONE = 0,
  ALL_GATHER = 1,
  REDUCE_SCATTER = 2,
};

/*! \brief Translate a JAXX_Collective_Op into the corresponding CommOverlapType.
 *
 *  ALL_GATHER maps to CommOverlapType::AG and REDUCE_SCATTER to CommOverlapType::RS.
 *
 *  Declared `inline` rather than `static` because it is defined in a header: a
 *  `static` definition gives every including translation unit its own private
 *  copy and triggers unused-function warnings where it is not called.
 *
 *  \param[in] op Collective operation coming from the JAX side.
 *  \return The matching CommOverlapType.
 *
 *  Any other value, including NONE, is reported through NVTE_ERROR together with
 *  its integer value.
 */
inline CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) {
  switch (op) {
    case JAXX_Collective_Op::ALL_GATHER:
      return CommOverlapType::AG;
    case JAXX_Collective_Op::REDUCE_SCATTER:
      return CommOverlapType::RS;
    default:
      // NONE has no case of its own, so it is rejected here along with unknown values.
      NVTE_ERROR("Invalid Collective Op ", static_cast<int>(op));
  }
}

116
117
}  // namespace jax
}  // namespace transformer_engine