SDPAInference.cpp 6.45 KB
Newer Older
yanjl1's avatar
Initial  
yanjl1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <iostream>

#include "utils.hpp"

#include <hipdnn_data_sdk/utilities/Tensor.hpp>
#include <hipdnn_data_sdk/utilities/Workspace.hpp>
#include <hipdnn_frontend.hpp>

/// Sample program: builds a single-node scaled-dot-product-attention (SDPA)
/// inference graph with the hipDNN frontend, creates tensors for its inputs and
/// outputs, and executes the graph once.
///
/// Every optional feature (stats output, causal mask, padding mask, ALiBi mask,
/// attention bias) is switched off by the constants at the top of the function;
/// the code paths for them are kept so they can be enabled by flipping a flag.
///
/// Failures of hipDNN calls are handled by the HIPDNN_CHECK / HIPDNN_FE_CHECK
/// macros, which are defined outside this file.
///
/// @return 0 after the graph has been executed, 1 when no backend is available.
int main()
{
    using InputType = hipdnn_data_sdk::types::half;

    // Problem sizes. q/k/v are declared below as 4-D tensors
    // {b, headDim*, seqLen*, d*} with the last dimension contiguous.
    const int64_t b = 2; // batch size (dim 0 of q, k, v and bias)
    // NOTE(review): the three values below are used as dim 1 of q/k/v. In the
    // usual SDPA (batch, heads, seq, dim) layout that position is the number of
    // heads rather than a per-head dimension, despite the names — confirm.
    const int64_t headDimQ = 4; // dim 1 of q
    const int64_t headDimK = 4; // dim 1 of k
    const int64_t headDimV = 4; // dim 1 of v
    const int64_t seqLenQ = 64; // q tensor is padded to this seq length
    const int64_t seqLenKV = 64; // k and v tensor is padded to this seq length
    const int64_t dQK = 32; // innermost (stride-1) dim of q and k
    const int64_t dV = 32; // innermost (stride-1) dim of v
    const float attnScale = 1.0f; // passed to set_attn_scale_value()

    // Feature switches; all disabled in this sample.
    const bool generateStats = false;
    const bool causalMask = false;
    const bool paddingMask = false;
    const bool alibiMask = false;
    const bool hasAttnBias = false;

    // Creates the graph, declares all tensor attributes, adds the SDPA node and
    // builds the graph for `handle`. Returns the graph together with every
    // tensor attribute object so the caller can size buffers and look up uids.
    // The constants above are captured by copy.
    auto buildSdpaInferenceGraph = [=](hipdnnHandle_t handle) {
        auto graph = std::make_shared<hipdnn_frontend::graph::Graph>();
        // I/O tensors use InputType (half); intermediate and compute types are FLOAT.
        graph->set_name("sdpa_inference_graph")
            .set_io_data_type(hipdnn_frontend::getDataTypeEnumFromType<InputType>())
            .set_intermediate_data_type(hipdnn_frontend::DataType::FLOAT)
            .set_compute_data_type(hipdnn_frontend::DataType::FLOAT);

        // q, k and v are fully packed: each stride is the product of the
        // dimensions to its right.
        auto q = std::make_shared<hipdnn_frontend::graph::TensorAttributes>(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("q")
                .set_dim({b, headDimQ, seqLenQ, dQK})
                .set_stride({headDimQ * seqLenQ * dQK, seqLenQ * dQK, dQK, 1}));

        auto k = std::make_shared<hipdnn_frontend::graph::TensorAttributes>(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("k")
                .set_dim({b, headDimK, seqLenKV, dQK})
                .set_stride({headDimK * seqLenKV * dQK, seqLenKV * dQK, dQK, 1}));

        auto v = std::make_shared<hipdnn_frontend::graph::TensorAttributes>(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("v")
                .set_dim({b, headDimV, seqLenKV, dV})
                .set_stride({headDimV * seqLenKV * dV, seqLenKV * dV, dV, 1}));

        // Optional additive bias of shape {b, 1, seqLenQ, seqLenKV}; it is only
        // attached to the node when hasAttnBias is set.
        auto bias = std::make_shared<hipdnn_frontend::graph::TensorAttributes>(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("bias")
                .set_dim({b, 1, seqLenQ, seqLenKV})
                .set_stride({seqLenQ * seqLenKV, seqLenQ * seqLenKV, seqLenKV, 1}));

        // INT32 tensors with one element per batch entry; only attached to the
        // node when paddingMask is set.
        auto seqLengthQ = std::make_shared<hipdnn_frontend::graph::TensorAttributes>(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("seq_length_q")
                .set_data_type(hipdnn_frontend::DataType::INT32)
                .set_dim({b, 1, 1, 1})
                .set_stride({1, 1, 1, 1}));

        auto seqLengthKV = std::make_shared<hipdnn_frontend::graph::TensorAttributes>(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("seq_length_kv")
                .set_data_type(hipdnn_frontend::DataType::INT32)
                .set_dim({b, 1, 1, 1})
                .set_stride({1, 1, 1, 1}));

        auto sdpaAttributes = hipdnn_frontend::graph::SdpaAttributes()
                                  .set_name("sdpa_inference_node")
                                  .set_generate_stats(generateStats)
                                  .set_alibi_mask(alibiMask)
                                  .set_attn_scale_value(attnScale);
        if(causalMask)
        {
            // Causal masking is requested as a top-left aligned diagonal band
            // whose right bound is 0.
            sdpaAttributes.set_diagonal_alignment(hipdnn_frontend::DiagonalAlignment_t::TOP_LEFT)
                .set_diagonal_band_right_bound(0);
        }
        if(hasAttnBias)
        {
            sdpaAttributes.set_bias(bias);
        }
        if(paddingMask)
        {
            sdpaAttributes.set_padding_mask(paddingMask)
                .set_seq_len_q(seqLengthQ)
                .set_seq_len_kv(seqLengthKV);
        }

        auto [outO, outStats] = graph->sdpa(q, k, v, sdpaAttributes);
        outO->set_output(true);
        // outStats is only dereferenced when stats were requested.
        if(generateStats)
        {
            outStats->set_output(true).set_data_type(hipdnn_frontend::DataType_t::FLOAT);
        }

        // Build the graph for this handle before returning it.
        HIPDNN_FE_CHECK(graph->build(handle));

        return std::make_tuple(graph, q, k, v, bias, outStats, seqLengthQ, seqLengthKV, outO);
    };

    auto backend = hipdnn_frontend::detail::hipdnnBackend();
    if(!backend)
    {
        // Failure diagnostics go to stderr so they are not mixed into normal output.
        std::cerr << "Create backend failed. \n";
        return 1;
    }

    hipdnnHandle_t handle;
    HIPDNN_CHECK(backend->create(&handle));

    auto [graph, q, k, v, bias, outStats, seqLengthQ, seqLengthKV, outO]
        = buildSdpaInferenceGraph(handle);

    // Buffers sized from the tensor attributes. The output tensor takes its
    // dims/strides from outO after the graph has been built.
    hipdnn_data_sdk::utilities::Tensor<InputType> qTensor(q->get_dim(), q->get_stride());
    hipdnn_data_sdk::utilities::Tensor<InputType> kTensor(k->get_dim(), k->get_stride());
    hipdnn_data_sdk::utilities::Tensor<InputType> vTensor(v->get_dim(), v->get_stride());
    hipdnn_data_sdk::utilities::Tensor<InputType> oTensor(outO->get_dim(), outO->get_stride());
    hipdnn_data_sdk::utilities::Tensor<InputType> biasTensor(bias->get_dim());
    // When stats are disabled the stats tensor is constructed from an empty
    // dimension list and never bound below.
    hipdnn_data_sdk::utilities::Tensor<float> outStatsTensor(
        generateStats ? outStats->get_dim() : std::vector<int64_t>{});
    hipdnn_data_sdk::utilities::Tensor<int32_t> seqLengthQTensor(seqLengthQ->get_dim());
    hipdnn_data_sdk::utilities::Tensor<int32_t> seqLengthKVTensor(seqLengthKV->get_dim());

    // Variant pack: maps each tensor uid used by the graph to the pointer
    // returned by the matching buffer's deviceData(). Optional tensors are only
    // bound when their feature is enabled.
    std::unordered_map<int64_t, void*> variantPack;
    variantPack[q->get_uid()] = qTensor.memory().deviceData();
    variantPack[k->get_uid()] = kTensor.memory().deviceData();
    variantPack[v->get_uid()] = vTensor.memory().deviceData();
    variantPack[outO->get_uid()] = oTensor.memory().deviceData();
    if(hasAttnBias)
    {
        variantPack[bias->get_uid()] = biasTensor.memory().deviceData();
    }
    if(generateStats)
    {
        variantPack[outStats->get_uid()] = outStatsTensor.memory().deviceData();
    }
    if(paddingMask)
    {
        variantPack[seqLengthQ->get_uid()] = seqLengthQTensor.memory().deviceData();
        variantPack[seqLengthKV->get_uid()] = seqLengthKVTensor.memory().deviceData();
    }

    // Ask the built graph how much workspace it needs and create a workspace
    // of that size for the execute call.
    int64_t workspaceSize = 0;
    HIPDNN_FE_CHECK(graph->get_workspace_size(workspaceSize));
    const hipdnn_data_sdk::utilities::Workspace workspace(static_cast<size_t>(workspaceSize));

    HIPDNN_FE_CHECK(graph->execute(handle, variantPack, workspace.get()));

    std::cout << "Sdpa_inference graph execution complete. \n";

    HIPDNN_CHECK(backend->destroy(handle));
    return 0;
}