// Sample program: builds a hipDNN frontend graph with a single SDPA node
// (graph->sdpa), binds a buffer pointer for each graph tensor, and executes
// the graph once.
//
// NOTE(review): in this text the bare #include directives have no target, and
// several template argument lists appear to be missing (after
// std::make_shared, getDataTypeEnumFromType, Tensor, std::vector,
// std::unordered_map and static_cast). Restore them from the original file
// before compiling -- TODO confirm against version control.
#include
#include "utils.hpp"
#include
#include
#include

int main()
{
    // Element type alias for the sample. It is not referenced by any visible
    // token below; presumably it fed the missing template arguments -- confirm.
    using InputType = hipdnn_data_sdk::types::half;

    const int64_t b = 2; // first (outermost) dimension of every tensor: batch size
    // NOTE(review): the three values below are named and originally commented
    // as "head dim", but they are used as the SECOND tensor dimension (right
    // after batch), which looks like a head count -- confirm and rename.
    const int64_t headDimQ = 4; // second dimension of q
    const int64_t headDimK = 4; // second dimension of k
    const int64_t headDimV = 4; // second dimension of v
    const int64_t seqLenQ = 64; // third dimension of q; q tensor is padded to this seq length
    const int64_t seqLenKV = 64; // third dimension of k and v; both are padded to this seq length
    // NOTE(review): originally commented "hidden dim"; these are the innermost,
    // unit-stride dimensions of the tensors -- confirm intended terminology.
    const int64_t dQK = 32; // last dimension of q and k
    const int64_t dV = 32; // last dimension of v
    const float attnScale = 1.0f; // passed to set_attn_scale_value()

    // Feature switches. All are off in this sample; each one gates the
    // corresponding branch inside the graph builder and/or the variant pack.
    const bool generateStats = false;
    const bool causalMask = false;
    const bool paddingMask = false;
    const bool alibiMask = false;
    const bool hasAttnBias = false;

    // Builds the graph on the given handle and returns the graph together with
    // every tensor handle main() needs afterwards. Captures the constants above
    // by copy. The order of the returned tuple must match the structured
    // binding in main().
    auto buildSdpaInferenceGraph = [=](hipdnnHandle_t handle) {
        auto graph = std::make_shared();

        // I/O data type comes from getDataTypeEnumFromType(); intermediate and
        // compute types are set to FLOAT.
        graph->set_name("sdpa_inference_graph")
            .set_io_data_type(hipdnn_frontend::getDataTypeEnumFromType())
            .set_intermediate_data_type(hipdnn_frontend::DataType::FLOAT)
            .set_compute_data_type(hipdnn_frontend::DataType::FLOAT);

        // q, k and v are 4-D with densely packed strides: each stride is the
        // product of the dimensions to its right, and the last one is 1.
        auto q = std::make_shared(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("q")
                .set_dim({b, headDimQ, seqLenQ, dQK})
                .set_stride({headDimQ * seqLenQ * dQK, seqLenQ * dQK, dQK, 1}));

        auto k = std::make_shared(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("k")
                .set_dim({b, headDimK, seqLenKV, dQK})
                .set_stride({headDimK * seqLenKV * dQK, seqLenKV * dQK, dQK, 1}));

        auto v = std::make_shared(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("v")
                .set_dim({b, headDimV, seqLenKV, dV})
                .set_stride({headDimV * seqLenKV * dV, seqLenKV * dV, dV, 1}));

        // Optional tensors. They are always created here but only attached to
        // the SDPA node when the matching flag is set.
        // bias has a size-1 second dimension and packed strides.
        auto bias = std::make_shared(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("bias")
                .set_dim({b, 1, seqLenQ, seqLenKV})
                .set_stride({seqLenQ * seqLenKV, seqLenQ * seqLenKV, seqLenKV, 1}));

        // INT32 tensors of shape {b, 1, 1, 1}, i.e. one value per batch entry.
        auto seqLengthQ = std::make_shared(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("seq_length_q")
                .set_data_type(hipdnn_frontend::DataType::INT32)
                .set_dim({b, 1, 1, 1})
                .set_stride({1, 1, 1, 1}));

        auto seqLengthKV = std::make_shared(
            hipdnn_frontend::graph::Tensor_attributes()
                .set_name("seq_length_kv")
                .set_data_type(hipdnn_frontend::DataType::INT32)
                .set_dim({b, 1, 1, 1})
                .set_stride({1, 1, 1, 1}));

        // Base node attributes that are set regardless of the feature flags.
        auto sdpaAttributes = hipdnn_frontend::graph::SdpaAttributes()
                                  .set_name("sdpa_inference_node")
                                  .set_generate_stats(generateStats)
                                  .set_alibi_mask(alibiMask)
                                  .set_attn_scale_value(attnScale);

        // Top-left diagonal alignment with a right band bound of 0
        // (presumably the usual causal mask -- confirm against hipDNN docs).
        if(causalMask)
        {
            sdpaAttributes.set_diagonal_alignment(hipdnn_frontend::DiagonalAlignment_t::TOP_LEFT)
                .set_diagonal_band_right_bound(0);
        }

        if(hasAttnBias)
        {
            sdpaAttributes.set_bias(bias);
        }

        // Padding mask needs both per-batch sequence-length tensors.
        if(paddingMask)
        {
            sdpaAttributes.set_padding_mask(paddingMask)
                .set_seq_len_q(seqLengthQ)
                .set_seq_len_kv(seqLengthKV);
        }

        // sdpa() yields the attention output and a stats tensor handle.
        auto [outO, outStats] = graph->sdpa(q, k, v, sdpaAttributes);

        outO->set_output(true);

        // outStats is only dereferenced when stats generation was requested.
        if(generateStats)
        {
            outStats->set_output(true).set_data_type(hipdnn_frontend::DataType_t::FLOAT);
        }

        // build graph
        HIPDNN_FE_CHECK(graph->build(handle));

        return std::make_tuple(graph, q, k, v, bias, outStats, seqLengthQ, seqLengthKV, outO);
    };

    // Obtain the backend and bail out with exit code 1 if it is unavailable.
    auto backend = hipdnn_frontend::detail::hipdnnBackend();
    if(!backend)
    {
        // NOTE(review): "Creat" in the message below looks like a typo for
        // "Create"; the message also goes to std::cout rather than std::cerr.
        std::cout << "Creat backend failed. \n";
        return 1;
    }

    // NOTE(review): HIPDNN_CHECK / HIPDNN_FE_CHECK are defined outside this
    // file. If they return or throw on failure, the handle created here is not
    // destroyed on those paths -- confirm and consider a scope guard.
    hipdnnHandle_t handle;
    HIPDNN_CHECK(backend->create(&handle));

    auto [graph, q, k, v, bias, outStats, seqLengthQ, seqLengthKV, outO]
        = buildSdpaInferenceGraph(handle);

    // Tensor objects are constructed for every graph tensor, including the
    // optional ones that are only bound below when their flag is set.
    // q/k/v/o pass both dims and strides; the others pass dims only.
    // NOTE(review): nothing visible here writes values into qTensor, kTensor
    // or vTensor before execution; whether the Tensor constructor initializes
    // its contents is not visible from this file.
    hipdnn_data_sdk::utilities::Tensor qTensor(q->get_dim(), q->get_stride());
    hipdnn_data_sdk::utilities::Tensor kTensor(k->get_dim(), k->get_stride());
    hipdnn_data_sdk::utilities::Tensor vTensor(v->get_dim(), v->get_stride());
    hipdnn_data_sdk::utilities::Tensor oTensor(outO->get_dim(), outO->get_stride());
    hipdnn_data_sdk::utilities::Tensor biasTensor(bias->get_dim());
    // Falls back to an empty dim list so outStats is not dereferenced when
    // stats generation is disabled.
    hipdnn_data_sdk::utilities::Tensor outStatsTensor(
        generateStats ? outStats->get_dim() : std::vector{});
    hipdnn_data_sdk::utilities::Tensor seqLengthQTensor(seqLengthQ->get_dim());
    hipdnn_data_sdk::utilities::Tensor seqLengthKVTensor(seqLengthKV->get_dim());

    // Variant pack: maps each tensor's uid to the pointer returned by
    // memory().deviceData() for the matching Tensor object.
    std::unordered_map variantPack;
    variantPack[q->get_uid()] = qTensor.memory().deviceData();
    variantPack[k->get_uid()] = kTensor.memory().deviceData();
    variantPack[v->get_uid()] = vTensor.memory().deviceData();
    variantPack[outO->get_uid()] = oTensor.memory().deviceData();

    // Optional tensors are bound only when their feature is enabled,
    // mirroring the branches in the graph builder.
    if(hasAttnBias)
    {
        variantPack[bias->get_uid()] = biasTensor.memory().deviceData();
    }

    if(generateStats)
    {
        variantPack[outStats->get_uid()] = outStatsTensor.memory().deviceData();
    }

    if(paddingMask)
    {
        variantPack[seqLengthQ->get_uid()] = seqLengthQTensor.memory().deviceData();
        variantPack[seqLengthKV->get_uid()] = seqLengthKVTensor.memory().deviceData();
    }

    // Ask the graph how much workspace it needs, then create a Workspace of
    // that size.
    int64_t workspaceSize = 0;
    HIPDNN_FE_CHECK(graph->get_workspace_size(workspaceSize));
    const hipdnn_data_sdk::utilities::Workspace workspace(static_cast(workspaceSize));

    // Run the graph once with the bound buffers and workspace.
    HIPDNN_FE_CHECK(graph->execute(handle, variantPack, workspace.get()));

    std::cout << "Sdpa_inference graph execution complete. \n";

    // Release the handle created above on the success path.
    HIPDNN_CHECK(backend->destroy(handle));
    return 0;
}