/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "common.h"

namespace transformer_engine {
namespace paddle_ext {

// Wrap a read-only buffer in an NVTE TensorWrapper. The NVTE tensor API
// takes a mutable void*, so constness is cast away here; the buffer is not
// written through this wrapper unless a mutating NVTE op is applied to it.
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
                             const DType type) {
    return TensorWrapper(const_cast<void *>(data_ptr), shape, type);
}

// Construct an NVTE tensor view over an existing buffer described by an
// NVTEShape. No FP8 scaling metadata (amax/scale/scale_inv) is attached.
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) {
    TensorWrapper wrapped(data_ptr, shape, type);
    return wrapped;
}

// Wrap a buffer together with its FP8 scaling metadata. The amax/scale/
// scale_inv arguments are raw pointers to float buffers (presumably device
// memory — the NVTE API consumes them as float*).
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
                             void *amax_ptr, void *scale_ptr, void *scale_inv_ptr) {
    // static_cast is the correct named cast for void* -> T* conversions;
    // reinterpret_cast is unnecessary here.
    return TensorWrapper(data_ptr, shape, type, static_cast<float *>(amax_ptr),
                         static_cast<float *>(scale_ptr),
                         static_cast<float *>(scale_inv_ptr));
}

// Build an NVTE tensor view over a mutable Paddle tensor, translating the
// Paddle shape and dtype into their NVTE equivalents.
TensorWrapper MakeNvteTensor(paddle::Tensor &tensor) {  // NOLINT
    const auto nvte_dtype = Paddle2NvteDType(tensor.dtype());
    return MakeNvteTensor(tensor.data(), GetShapeArray(tensor), nvte_dtype);
}

// Build an NVTE tensor view over a const Paddle tensor. The const_cast is
// needed because the NVTE tensor carries a mutable data pointer; the data is
// not modified through this wrapper unless a mutating NVTE op is invoked.
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor) {
    void *raw_data = const_cast<void *>(tensor.data());
    return MakeNvteTensor(raw_data, GetShapeArray(tensor), Paddle2NvteDType(tensor.dtype()));
}

// Allocate a Paddle tensor on `place` whose dimensions match an NVTEShape.
// Only 1-D and 2-D shapes are supported (as in the original four-way branch);
// init_to_zeros selects zero-initialized vs uninitialized storage.
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
                             bool init_to_zeros) {
    const size_t ndim = shape.ndim;
    // Fail fast instead of falling off the end of a value-returning function.
    NVTE_CHECK(ndim == 1 || ndim == 2, "AllocateSpace only supports 1-D or 2-D shapes.");
    std::vector<int64_t> dims(ndim);
    for (size_t i = 0; i < ndim; ++i) {
        dims[i] = static_cast<int64_t>(shape.data[i]);
    }
    return init_to_zeros ? paddle::zeros(dims, Nvte2PaddleDType(type), place)
                         : paddle::empty(dims, Nvte2PaddleDType(type), place);
}

// MHA utils
// Map a QKV layout string (e.g. "bs3hd", "sbhd_sbhd_sbhd") to the matching
// NVTE_QKV_Layout enum value. Raises an NVTE error for unknown strings.
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout) {
    static const std::unordered_map<std::string, NVTE_QKV_Layout> kLayoutTable = {
        {"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD},
        {"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D},
        {"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD},
        {"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D},
        {"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD},
        {"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD},
        {"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D},
        {"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD},
        {"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D},
        {"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD},
        {"t3hd", NVTE_QKV_Layout::NVTE_T3HD},
        {"th3d", NVTE_QKV_Layout::NVTE_TH3D},
        {"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD},
        {"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D},
        {"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD},
    };

    const auto entry = kLayoutTable.find(qkv_layout);
    if (entry == kLayoutTable.end()) {
        NVTE_ERROR("Invalid QKV layout string: " + qkv_layout);
    }
    return entry->second;
}

}  // namespace paddle_ext
}  // namespace transformer_engine