misc.h 4.15 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

namespace transformer_engine {
namespace jax {

// Capacity of the fixed-size Shape::dims array declared below.
constexpr int kMaxNumDim = 8;

// Fixed-capacity shape record: a dimension count plus inline storage for up to
// kMaxNumDim extents. Both member functions are only declared here.
struct Shape {
  // NOTE(review): presumably the number of leading entries of `dims` that are valid -
  // confirm against the out-of-line definitions.
  int num_dim;
  // Inline storage for the extents; never holds more than kMaxNumDim values.
  size_t dims[kMaxNumDim];

  // NOTE(review): presumably fills num_dim/dims from `shape`; how a vector longer than
  // kMaxNumDim is handled is not visible in this header - confirm.
  void from_vector(const std::vector<size_t> &shape);

  // NOTE(review): presumably returns the stored extents as a vector - confirm.
  std::vector<size_t> to_vector() const;
};

// Declared here, defined elsewhere.
// NOTE(review): presumably copies the extents of an NVTEShape into a std::vector - confirm.
std::vector<size_t> MakeShapeVector(NVTEShape shape);

29
30
31
32
33
34
35
36
// Returns the product of all extents in `shape`.
// An empty vector yields 1 (the empty product).
inline size_t product(const std::vector<size_t> &shape) {
  size_t total = 1;
  for (size_t idx = 0; idx < shape.size(); ++idx) {
    total = total * shape[idx];
  }
  return total;
}

37
// Quantization layout selector.
//
// ROWWISE_COLWISE is the combined setting: the is_quantize_rowwise() and
// is_quantize_colwise() helpers in this header both return true for it.
//
// Enumerator values are written out explicitly, as in the other JAXX_* enums of this
// header; they equal the values the implicit numbering produced (0, 1, 2).
enum class JAXX_Quantize_Layout : int64_t {
  ROWWISE = 0,
  COLWISE = 1,
  ROWWISE_COLWISE = 2,
};

43
44
45
46
47
48
49
50
51
52
53
54
// True when `layout` includes the row-wise layout, i.e. it is ROWWISE or ROWWISE_COLWISE.
inline bool is_quantize_rowwise(const JAXX_Quantize_Layout &layout) {
  switch (layout) {
    case JAXX_Quantize_Layout::ROWWISE:
    case JAXX_Quantize_Layout::ROWWISE_COLWISE:
      return true;
    default:
      return false;
  }
}

// True when `layout` includes the column-wise layout, i.e. it is COLWISE or ROWWISE_COLWISE.
inline bool is_quantize_colwise(const JAXX_Quantize_Layout &layout) {
  switch (layout) {
    case JAXX_Quantize_Layout::COLWISE:
    case JAXX_Quantize_Layout::ROWWISE_COLWISE:
      return true;
    default:
      return false;
  }
}

// True only for the combined ROWWISE_COLWISE layout.
inline bool is_quantize_2x2x(const JAXX_Quantize_Layout &layout) {
  switch (layout) {
    case JAXX_Quantize_Layout::ROWWISE_COLWISE:
      return true;
    default:
      return false;
  }
}

55
56
57
58
// Scaling modes as numbered on the JAX side.
//
// These are not the NVTEScalingMode values: get_nvte_scaling_mode() in this header
// performs the translation, and several entries here collapse onto one NVTE mode.
// The numeric values are fixed explicitly.
enum class JAXX_Scaling_Mode : int64_t {
  NO_SCALING = 0,
  DELAYED_TENSOR_SCALING = 1,
  MXFP8_1D_SCALING = 2,
  CURRENT_TENSOR_SCALING = 3,
  NVFP4_1D_SCALING = 4,
  NVFP4_2D_SCALING = 5,
};

Alp Dener's avatar
Alp Dener committed
64
65
66
67
68
69
70
71
72
// True for the two tensor-scaling modes: DELAYED_TENSOR_SCALING and CURRENT_TENSOR_SCALING.
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
    case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
      return true;
    default:
      return false;
  }
}

// True only for MXFP8_1D_SCALING; the NVFP4 modes are reported by is_nvfp4_scaling().
inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
      return true;
    default:
      return false;
  }
}

73
74
75
76
77
// True for either NVFP4 mode: NVFP4_1D_SCALING or NVFP4_2D_SCALING.
inline bool is_nvfp4_scaling(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::NVFP4_1D_SCALING:
    case JAXX_Scaling_Mode::NVFP4_2D_SCALING:
      return true;
    default:
      return false;
  }
}

78
79
80
81
82
83
84
85
86
87
88
// Translates a JAX-side scaling mode into the NVTEScalingMode used by TE common.
//
// The mapping is many-to-one:
//   NO_SCALING, DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING -> NVTE_DELAYED_TENSOR_SCALING
//   MXFP8_1D_SCALING                                           -> NVTE_MXFP8_1D_SCALING
//   NVFP4_1D_SCALING, NVFP4_2D_SCALING                         -> NVTE_NVFP4_1D_SCALING
// Any other value is reported through NVTE_ERROR.
//
// Declared `inline`, like the predicate helpers in this header, so that including
// translation units share one definition instead of each getting a private `static`
// copy (which also triggers unused-function warnings where it is not called).
inline NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::NO_SCALING:
      return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
    case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
      return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
    case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
      return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
    case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
      return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
    case JAXX_Scaling_Mode::NVFP4_1D_SCALING:
      return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
    case JAXX_Scaling_Mode::NVFP4_2D_SCALING:
      // TE common uses the same enum value for 1D and 2D fp4 scaling and instead
      // differentiates them via quant_config.nvfp4_2d_quantization
      return NVTEScalingMode::NVTE_NVFP4_1D_SCALING;
    default:
      NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
  }
}

105
// A pair of extents (x, y) stored as size_t and usable in constant expressions.
struct BLOCK_SIZE {
  size_t x;
  size_t y;

  // The parameters stay `int` so existing call sites keep compiling unchanged; the
  // conversion to size_t is spelled out so the sign change is visible at the point
  // where it happens. Negative arguments wrap, exactly as with the implicit conversion.
  constexpr BLOCK_SIZE(int _x, int _y) : x(static_cast<size_t>(_x)), y(static_cast<size_t>(_y)) {}
};

// Block extents as (x, y) for the MXFP8 and NVFP4 formats.
// NOTE(review): presumably the number of elements along each axis that share one scale
// factor - confirm against get_block_scale_shape()'s definition.
constexpr BLOCK_SIZE MXFP8_BLOCK_SIZE{1, 32};
constexpr BLOCK_SIZE NVFP4_BLOCK_SIZE{1, 16};

// NOTE(review): presumably the (x, y) multiples that block-scale tensor dimensions are
// padded to - confirm against get_block_scale_shape()'s definition.
constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{128, 4};
115

116
117
// Declared here, defined elsewhere.
// NOTE(review): presumably returns the shape of the scale tensor for an M x N operand
// under `scaling_mode`, for the column-wise layout when `is_colwise` is true - confirm
// against the definition.
std::vector<size_t> get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N,
                                          bool is_colwise);
118

Phuong Nguyen's avatar
Phuong Nguyen committed
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// Folds std::hash<T>{}(v), and then each element of `rest` in left-to-right order,
// into `seed`.
//
// Calling hash_combine(seed, a, b, c) gives the same result as three single-value
// calls with a, b and c in turn.
//
// \param seed  Running hash value, updated in place.
// \param v     First value to mix in; std::hash<T> must be usable for its type.
// \param rest  Further values, mixed in one at a time after `v`.
template <typename T, typename... Rest>
void hash_combine(int64_t &seed, const T &v, Rest... rest) {
  // Shift left on the unsigned bit pattern: `seed << 6` on a signed 64-bit value is
  // undefined behaviour before C++20 whenever seed is negative or the result overflows,
  // which is the normal case after the first round of mixing.
  const uint64_t seed_bits = static_cast<uint64_t>(seed);
  // The right shift is still taken on the signed value so results stay bit-identical
  // to the previous implementation on compilers that sign-extend.
  const uint64_t mixed =
      std::hash<T>{}(v) + 0x9e3779b9 + (seed_bits << 6) + static_cast<uint64_t>(seed >> 2);
  seed = static_cast<int64_t>(seed_bits ^ mixed);
  (hash_combine(seed, rest), ...);
}

// Collective operation selector on the JAX side.
// get_nvte_collective_op() in this header maps ALL_GATHER and REDUCE_SCATTER to a
// CommOverlapType; NONE has no mapping there and is rejected as invalid.
enum class JAXX_Collective_Op : int64_t {
  NONE = 0,
  ALL_GATHER = 1,
  REDUCE_SCATTER = 2,
};

// Translates a JAX-side collective op into the matching CommOverlapType.
//
//   ALL_GATHER     -> CommOverlapType::AG
//   REDUCE_SCATTER -> CommOverlapType::RS
// Every other value, including NONE, is reported through NVTE_ERROR.
//
// Declared `inline`, like the predicate helpers in this header, so that including
// translation units share one definition instead of each getting a private `static`
// copy.
inline CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) {
  switch (op) {
    case JAXX_Collective_Op::ALL_GATHER:
      return CommOverlapType::AG;
    case JAXX_Collective_Op::REDUCE_SCATTER:
      return CommOverlapType::RS;
    default:
      NVTE_ERROR("Invalid Collective Op ", static_cast<int>(op));
  }
}

145
146
}  // namespace jax
}  // namespace transformer_engine