// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_ROCM

#include <memory>
#include <stdexcept>

#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT

#include "core/providers/rocm/rocm_context.h"
#include "onnxruntime_lite_custom_op.h"

// Concat
void rocm_concat(int axis,
                 int M1, int N1, const float* X1,
                 int M2, int N2, const float* X2,
                 float* Z, hipStream_t stream);

// Gemm
void rocm_gemm(bool transA, bool transB,
               int M, int N, int K,
               float alpha, const float* A, const float* B,
               float beta, float* C, hipStream_t stream);

extern "C" {

// LeakyRelu
void rocm_leaky_relu(int64_t size, const float* d_X, float* d_Y, float alpha, hipStream_t stream);

// Attention
void rocm_attention(int B, int S, int H,
                    const float* Q, const float* K, const float* V,
                    float* Out, hipStream_t stream);

// BatchNormalization
void rocm_batch_norm(int64_t N, int64_t C, int64_t H, int64_t W,
                     const float* X, const float* gamma, const float* beta,
                     const float* mean, const float* var,
                     float epsilon, float* Y, hipStream_t stream);

// Cast
void rocm_cast(int64_t N,        // batch size
               int64_t C,        // channels (or other first dim)
               int64_t H,        // height (or second dim)
               int64_t W,        // width (or third dim)
               const float* X,   // input pointer
               int32_t* Y,       // output pointer
               hipStream_t stream);

// Softmax
void rocm_softmax(int64_t M, int64_t N, const float* X, float* Y, hipStream_t compute_stream);

// Celu
void rocm_celu(int64_t, const float*, float*, float, hipStream_t);

// Relu
void rocm_relu(int64_t size, const float* X, float* Y, hipStream_t stream);

// Conv
void rocm_conv2d(const float* input, const float* weight, const float* bias, float* output,
                 int N, int C_in, int H, int W,
                 int C_out, int K_h, int K_w,
                 int out_H, int out_W, hipStream_t stream);

}  // extern "C"

using namespace Ort::Custom;

#define CUSTOM_ENFORCE(cond, msg)      \
  if (!(cond)) {                       \
    throw std::runtime_error(msg);     \
  }

namespace Rocm {

// LeakyRelu
void rocm_leaky_relu_forward(const RocmContext& ctx,
                             const Tensor<float>& X,
                             Tensor<float>& Y) {
  CUSTOM_ENFORCE(ctx.hip_stream, "No HIP stream available");
  int64_t size = X.NumberOfElement();
  const float alpha = 0.01f;
  auto* y_ptr = Y.Allocate(X.Shape());
  rocm_leaky_relu(size, X.Data(), y_ptr, alpha, ctx.hip_stream);
}

// Relu
void rocm_relu_forward(const Ort::Custom::RocmContext& rocm_ctx,
                       const Ort::Custom::Tensor<float>& X,
                       Ort::Custom::Tensor<float>& Y) {
  CUSTOM_ENFORCE(rocm_ctx.hip_stream, "failed to fetch hip stream");
  auto input_shape = X.Shape();
  int64_t size = X.NumberOfElement();
  auto* y_data = Y.Allocate(input_shape);
  rocm_relu(size, X.Data(), y_data, rocm_ctx.hip_stream);
}

// Celu
void rocm_celu_forward(const Ort::Custom::RocmContext& ctx,
                       const Ort::Custom::Tensor<float>& X,
                       Ort::Custom::Tensor<float>& Y) {
  CUSTOM_ENFORCE(ctx.hip_stream, "failed to fetch hip stream");
  auto shape = X.Shape();
  int64_t size = X.NumberOfElement();
  float alpha = 1.0f;  // or fetch from attribute
  auto* y_ptr = Y.Allocate(shape);
  rocm_celu(size, X.Data(), y_ptr, alpha, ctx.hip_stream);
}

/* softmax */
void KernelSoftmax(const Ort::Custom::RocmContext& rocm_ctx,
                   const Ort::Custom::Tensor<float>& X,
                   Ort::Custom::Tensor<float>& Z) {
  auto input_shape = X.Shape();
  CUSTOM_ENFORCE(rocm_ctx.hip_stream, "failed to fetch hip stream");

  // Expecting 2D input: [M, N]
  CUSTOM_ENFORCE(input_shape.size() == 2, "Softmax only supports 2D input");

  int64_t M = static_cast<int64_t>(input_shape[0]);
  int64_t N = static_cast<int64_t>(input_shape[1]);

  auto z_raw = Z.Allocate(input_shape);

  // Call ROCm implementation
  rocm_softmax(M, N, X.Data(), z_raw, rocm_ctx.hip_stream);
}

// Cast (float -> int32)
void rocm_cast_forward(const Ort::Custom::RocmContext& rocm_ctx,
                       const Ort::Custom::Tensor<float>& X,
                       Ort::Custom::Tensor<int32_t>& Y) {
  CUSTOM_ENFORCE(rocm_ctx.hip_stream, "failed to fetch hip stream");

  // Only 4D tensors [N, C, H, W] are supported
  auto shape = X.Shape();
  CUSTOM_ENFORCE(shape.size() == 4, "Cast only supports 4D tensor [N,C,H,W]");

  int64_t N = shape[0];
  int64_t C = shape[1];
  int64_t H = shape[2];
  int64_t W = shape[3];

  // Allocate output
  auto* y_ptr = Y.Allocate({N, C, H, W});

  // Call with all seven parameters
  rocm_cast(N, C, H, W, X.Data(), y_ptr, rocm_ctx.hip_stream);
}

// BatchNormalization
void rocm_batchnorm_forward(const Ort::Custom::RocmContext& rocm_ctx,
                            const Ort::Custom::Tensor<float>& X,
                            const Ort::Custom::Tensor<float>& scale,
                            const Ort::Custom::Tensor<float>& B,
                            const Ort::Custom::Tensor<float>& mean,
                            const Ort::Custom::Tensor<float>& var,
                            Ort::Custom::Tensor<float>& Y) {
  CUSTOM_ENFORCE(rocm_ctx.hip_stream, "failed to fetch hip stream");

  auto shape = X.Shape();  // expects [N, C, H, W]
  CUSTOM_ENFORCE(shape.size() == 4, "Input must be 4D tensor [N, C, H, W]");

  int64_t N = shape[0];
  int64_t C = shape[1];
  int64_t H = shape[2];
  int64_t W = shape[3];

  // Allocate output
  auto* y_ptr = Y.Allocate({N, C, H, W});

  // Epsilon attribute: retrieve via custom API or hardcode default
  float epsilon = 1e-5f;  // If epsilon comes from an attribute, fetch it here.

  rocm_batch_norm(N, C, H, W,
                  X.Data(), scale.Data(), B.Data(), mean.Data(), var.Data(),
                  epsilon, y_ptr, rocm_ctx.hip_stream);
}

// Attention
void rocm_attention_forward(const Ort::Custom::RocmContext& rocm_ctx,
                            const Ort::Custom::Tensor<float>& Q,
                            const Ort::Custom::Tensor<float>& K,
                            const Ort::Custom::Tensor<float>& V,
                            Ort::Custom::Tensor<float>& Out) {
  CUSTOM_ENFORCE(rocm_ctx.hip_stream, "failed to fetch hip stream");

  auto shape = Q.Shape();  // expected shape [B, S, H]
  CUSTOM_ENFORCE(shape.size() == 3, "Input must be 3D tensor [B, S, H]");

  int B = shape[0];
  int S = shape[1];
  int H = shape[2];

  auto* out_ptr = Out.Allocate({B, S, H});

  rocm_attention(B, S, H, Q.Data(), K.Data(), V.Data(), out_ptr, rocm_ctx.hip_stream);
}

// -------------------------------
// Concat
// -------------------------------
void rocm_concat_forward(const Ort::Custom::RocmContext& rocm_ctx,
                         const Ort::Custom::Tensor<float>& X1,
                         const Ort::Custom::Tensor<float>& X2,
                         Ort::Custom::Tensor<float>& Y) {
  CUSTOM_ENFORCE(rocm_ctx.hip_stream, "failed to fetch hip stream");

  auto shape1 = X1.Shape();
  auto shape2 = X2.Shape();

  // Only 2D tensors concatenated along columns (axis = 1) are supported
  CUSTOM_ENFORCE(shape1.size() == 2 && shape2.size() == 2, "Inputs must be 2D tensors.");
  CUSTOM_ENFORCE(shape1[0] == shape2[0], "Row dimensions must match for concat on axis 1.");

  int M1 = shape1[0], N1 = shape1[1];
  int M2 = shape2[0], N2 = shape2[1];

  auto* y_data = Y.Allocate({M1, N1 + N2});  // output is the concatenated matrix

  rocm_concat(1, M1, N1, X1.Data(), M2, N2, X2.Data(), y_data, rocm_ctx.hip_stream);
}

/****** Conv ******/
void rocm_conv_forward(const RocmContext& ctx,
                       const Tensor<float>& input,
                       const Tensor<float>& weight,
                       const Tensor<float>& bias,
                       Tensor<float>& output) {
  CUSTOM_ENFORCE(ctx.hip_stream, "no HIP stream");

  const auto& input_shape = input.Shape();    // [N, C_in, H, W]
  const auto& weight_shape = weight.Shape();  // [C_out, C_in, K_h, K_w]

  const int64_t N = input_shape[0];
  const int64_t C_in = input_shape[1];
  const int64_t H = input_shape[2];
  const int64_t W = input_shape[3];

  const int64_t C_out = weight_shape[0];
  const int64_t K_h = weight_shape[2];
  const int64_t K_w = weight_shape[3];

  // Valid convolution: stride 1, no padding
  const int64_t out_H = (H - K_h) / 1 + 1;
  const int64_t out_W = (W - K_w) / 1 + 1;

  auto* y_ptr = output.Allocate({N, C_out, out_H, out_W});

  rocm_conv2d(input.Data(), weight.Data(), bias.Data(), y_ptr,
              N, C_in, H, W, C_out, K_h, K_w, out_H, out_W, ctx.hip_stream);
}

// -------------------------------
// Gemm
// -------------------------------
void rocm_gemm_forward(const Ort::Custom::RocmContext& rocm_ctx,
                       const Ort::Custom::Tensor<float>& A,
                       const Ort::Custom::Tensor<float>& B,
                       const Ort::Custom::Tensor<float>& C,
                       Ort::Custom::Tensor<float>& Y) {
  CUSTOM_ENFORCE(rocm_ctx.hip_stream, "failed to fetch hip stream");

  auto shapeA = A.Shape();
  auto shapeB = B.Shape();
  auto shapeC = C.Shape();

  CUSTOM_ENFORCE(shapeA.size() == 2 && shapeB.size() == 2 && shapeC.size() == 2,
                 "Inputs must be 2D tensors.");

  int M = shapeA[0];
  int K = shapeA[1];
  int N = shapeB[1];

  CUSTOM_ENFORCE(shapeB[0] == K, "Inner dimension mismatch between A and B in Gemm.");
  CUSTOM_ENFORCE(shapeC[0] == M && shapeC[1] == N, "Output tensor shape mismatch in Gemm.");

  auto* y_data = Y.Allocate({M, N});

  rocm_gemm(false, false, M, N, K,
            1.0f, A.Data(), B.Data(),
            1.0f, y_data, rocm_ctx.hip_stream);
}

void RegisterOps(Ort::CustomOpDomain& domain) {
  // Register the Attention op
  static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpAttention{
      Ort::Custom::CreateLiteCustomOp("Attention", "ROCMExecutionProvider", rocm_attention_forward)};
  domain.Add(c_CustomOpAttention.get());

  // Register the BatchNormalization op
  static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpBatchNorm{
      Ort::Custom::CreateLiteCustomOp("BatchNormalization", "ROCMExecutionProvider", rocm_batchnorm_forward)};
  domain.Add(c_CustomOpBatchNorm.get());

  // Register the Concat op
  static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpConcat{
      Ort::Custom::CreateLiteCustomOp("Concat", "ROCMExecutionProvider", rocm_concat_forward)};
  domain.Add(c_CustomOpConcat.get());

  // Register the Gemm op
  static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpGemm{
      Ort::Custom::CreateLiteCustomOp("Gemm", "ROCMExecutionProvider", rocm_gemm_forward)};
  domain.Add(c_CustomOpGemm.get());

  // Register the Cast op
  static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpCast{
      Ort::Custom::CreateLiteCustomOp("Cast", "ROCMExecutionProvider", rocm_cast_forward)};
  domain.Add(c_CustomOpCast.get());

  // Register the Softmax op
  static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpSoftmax{
      Ort::Custom::CreateLiteCustomOp("Softmax", "ROCMExecutionProvider", KernelSoftmax)};
  domain.Add(c_CustomOpSoftmax.get());

  // Register the Celu op
  static const std::unique_ptr<OrtLiteCustomOp> c_CeluOp{
      Ort::Custom::CreateLiteCustomOp("Celu", "ROCMExecutionProvider", rocm_celu_forward)};
  domain.Add(c_CeluOp.get());

  // Register the Relu op
  static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpRelu{
      Ort::Custom::CreateLiteCustomOp("Relu", "ROCMExecutionProvider", rocm_relu_forward)};
  domain.Add(c_CustomOpRelu.get());

  // Register the LeakyRelu op
  static const std::unique_ptr<OrtLiteCustomOp> c_LeakyReLU{
      Ort::Custom::CreateLiteCustomOp("LeakyRelu", "ROCMExecutionProvider", rocm_leaky_relu_forward)};
  domain.Add(c_LeakyReLU.get());

  // Register the Conv op
  static const std::unique_ptr<OrtLiteCustomOp> c_Conv{
      Ort::Custom::CreateLiteCustomOp("Conv", "ROCMExecutionProvider", rocm_conv_forward)};
  domain.Add(c_Conv.get());
}

}  // namespace Rocm

#endif
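
// ---------------------------------------------------------------------------
// Usage sketch (commented out, not part of the build): one way an application
// that links this translation unit could register the domain above and create
// a ROCm-backed session. The domain name "rocm.custom", the model path, and
// the helper name CreateSessionWithRocmCustomOps are assumptions made for
// illustration only; the Ort::CustomOpDomain / Ort::SessionOptions calls come
// from the public C++ API, and the sketch assumes Ort::InitApi() has already
// been called, since this file is compiled with ORT_API_MANUAL_INIT.
//
// #include "onnxruntime_cxx_api.h"
//
// Ort::Session CreateSessionWithRocmCustomOps(Ort::Env& env) {
//   // Keep the domain alive for the lifetime of any session that uses it.
//   static Ort::CustomOpDomain domain{"rocm.custom"};  // assumed domain name
//   Rocm::RegisterOps(domain);                         // populate with the ops above
//
//   Ort::SessionOptions session_options;
//   session_options.Add(domain);                       // hand the domain to the session
//
//   OrtROCMProviderOptions rocm_options{};             // default ROCm EP options
//   session_options.AppendExecutionProvider_ROCM(rocm_options);
//
//   return Ort::Session(env, ORT_TSTR("model.onnx"), session_options);  // assumed path
// }
// ---------------------------------------------------------------------------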