/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <string>
#include <vector>

namespace transformer_engine {
namespace jax {

// Compile-time capacity used to size the inline `dims` array of `Shape`.
constexpr int kMaxNumDim{8};

// Plain aggregate holding a dimension count together with fixed-size inline
// storage for up to kMaxNumDim extents.
//
// The two member functions are only declared here; their definitions live
// outside this header, so their exact behavior (e.g. bounds handling) must be
// checked there.
struct Shape {
  // Dimension count. NOTE(review): presumably the number of leading `dims`
  // entries that are meaningful - confirm against from_vector's definition.
  int num_dim;
  // Inline storage for the extents; capacity is kMaxNumDim entries.
  size_t dims[kMaxNumDim];

  // Produces a std::vector<size_t> from this object without modifying it.
  std::vector<size_t> to_vector() const;

  // Takes a vector of extents as input; expected to populate this object.
  void from_vector(const std::vector<size_t> &shape);
};

// Builds a std::vector<size_t> from an NVTEShape passed by value.
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably copies the shape's extents in order - confirm
// against the definition.
std::vector<size_t> MakeShapeVector(NVTEShape shape);

// Multiplies together every entry of `shape` and returns the result.
// An empty vector yields 1 (the multiplicative identity). No overflow
// checking is performed; the arithmetic is plain size_t multiplication.
inline size_t product(const std::vector<size_t> &shape) {
  size_t total = 1;
  for (size_t idx = 0; idx < shape.size(); ++idx) {
    total *= shape[idx];
  }
  return total;
}

// Enumerates the layout choices for quantization: row-wise only, column-wise
// only, or both.
// NOTE(review): the precise meaning of each layout is defined by the code that
// consumes this enum, which is not visible in this header - confirm there.
//
// The numeric values are spelled out so they remain stable if enumerators are
// ever added; they match the implicit values (0, 1, 2) of the declaration order.
enum class QuantizeLayout {
  ROWWISE = 0,
  COLWISE = 1,
  ROWWISE_COLWISE = 2,
};

// Scaling-mode identifiers with a fixed 64-bit underlying type and explicit
// numeric values.
// NOTE(review): the explicit values and int64_t width suggest these are
// exchanged with code outside C++ (e.g. the Python side) - confirm before
// renumbering or changing the underlying type.
enum class JAXX_Scaling_Mode : int64_t {
  NO_SCALING = 0,
  DELAYED_TENSOR_SCALING = 1,
  MXFP8_1D_SCALING = 2,
  CURRENT_TENSOR_SCALING = 3,
};

// Returns true when `mode` is one of the two tensor-scaling modes
// (DELAYED_TENSOR_SCALING or CURRENT_TENSOR_SCALING); false for every other
// value.
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
    case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
      return true;
    default:
      return false;
  }
}

// Returns true only when `mode` is MXFP8_1D_SCALING, the single block-scaling
// mode declared in JAXX_Scaling_Mode.
inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
  const bool uses_block_scaling = (JAXX_Scaling_Mode::MXFP8_1D_SCALING == mode);
  return uses_block_scaling;
}

// Translates a JAXX_Scaling_Mode into the corresponding NVTEScalingMode.
//
// Mapping:
//   NO_SCALING, DELAYED_TENSOR_SCALING, CURRENT_TENSOR_SCALING
//       -> NVTE_DELAYED_TENSOR_SCALING
//   MXFP8_1D_SCALING
//       -> NVTE_MXFP8_1D_SCALING
// Any other value is reported through NVTE_ERROR with the message
// "Invalid Scaling Mode " followed by the numeric value.
//
// NOTE(review): three JAXX modes share NVTE_DELAYED_TENSOR_SCALING; presumably
// NVTE uses a single mode for all per-tensor scaling variants - confirm.
//
// Declared `inline` (like the other helpers in this header) rather than
// `static`, so including the header does not give every translation unit its
// own private copy or trigger unused-function warnings.
inline NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::NO_SCALING:
    case JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING:
    case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
      return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
    case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
      return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
    default:
      // Reached only for values outside the declared enumerators.
      NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
  }
}

// Pair of block extents.
// NOTE(review): which of x / y corresponds to rows vs. columns is not visible
// in this header - confirm against the code that reads these fields.
struct BlockSize {
  size_t x;
  size_t y;
};

// Block extents for MXFP8: x = 1, y = 32.
constexpr BlockSize MXFP8_BLOCK_SIZE{1, 32};
// Pair of alignment values.
// NOTE(review): the axis each field applies to is not visible in this header -
// confirm against the code that reads these fields.
struct Alignment {
  size_t x;
  size_t y;
};

// Alignment values for MXFP8: x = 128, y = 4.
constexpr Alignment MXFP8_ALIGNMENT{128, 4};

// Returns a shape (as a std::vector<size_t>) computed from the extents M and N
// and the `is_colwise` flag. Declaration only; the definition is not part of
// this header.
// NOTE(review): presumably this is the shape of the MXFP8 scale tensor for an
// M x N input, derived from MXFP8_BLOCK_SIZE and MXFP8_ALIGNMENT - confirm
// against the definition.
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);

}  // namespace jax
}  // namespace transformer_engine