/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "common.h"

namespace transformer_engine {
namespace paddle_ext {

Shijie's avatar
Shijie committed
12
13
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
                             const DType type) {
14
  return TensorWrapper(const_cast<void *>(data_ptr), shape, type);
15
16
}

17
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) {
18
  return TensorWrapper(data_ptr, shape, type);
19
20
}

21
22
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
                             void *amax_ptr, void *scale_ptr, void *scale_inv_ptr) {
23
24
25
  return TensorWrapper(data_ptr, shape, type, reinterpret_cast<float *>(amax_ptr),
                       reinterpret_cast<float *>(scale_ptr),
                       reinterpret_cast<float *>(scale_inv_ptr));
26
27
}

28
TensorWrapper MakeNvteTensor(paddle::Tensor &tensor) {  // NOLINT
29
  return MakeNvteTensor(tensor.data(), GetShapeArray(tensor), Paddle2NvteDType(tensor.dtype()));
30
31
}

32
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor) {
33
34
  return MakeNvteTensor(const_cast<void *>(tensor.data()), GetShapeArray(tensor),
                        Paddle2NvteDType(tensor.dtype()));
35
36
}

37
38
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
                             bool init_to_zeros) {
39
40
41
42
43
44
45
46
47
48
49
50
51
  auto size = shape.ndim;
  if (size == 2 && init_to_zeros) {
    return paddle::zeros({static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
                         Nvte2PaddleDType(type), place);
  } else if (size == 2) {
    return paddle::empty({static_cast<int64_t>(shape.data[0]), static_cast<int64_t>(shape.data[1])},
                         Nvte2PaddleDType(type), place);
  } else if (size == 1 && init_to_zeros) {
    return paddle::zeros({static_cast<int64_t>(shape.data[0])}, Nvte2PaddleDType(type), place);
  } else if (size == 1) {
    return paddle::empty({static_cast<int64_t>(shape.data[0])}, Nvte2PaddleDType(type), place);
  }
  NVTE_CHECK(false, "Should never reach here! func: AllocateSpace");
52
53
}

54
55
56
// MHA utils
// convert QKV layout to enum
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout) {
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
  static const std::unordered_map<std::string, NVTE_QKV_Layout> layout_map = {
      {"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD},
      {"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D},
      {"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD},
      {"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D},
      {"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD},
      {"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD},
      {"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D},
      {"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD},
      {"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D},
      {"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD},
      {"t3hd", NVTE_QKV_Layout::NVTE_T3HD},
      {"th3d", NVTE_QKV_Layout::NVTE_TH3D},
      {"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD},
      {"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D},
      {"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD},
  };
74

75
76
77
78
79
80
  auto it = layout_map.find(qkv_layout);
  if (it != layout_map.end()) {
    return it->second;
  } else {
    NVTE_ERROR("Invalid QKV layout string: " + qkv_layout);
  }
81
82
}

}  // namespace paddle_ext
}  // namespace transformer_engine