Commit 546b4279 authored by limm's avatar limm
Browse files

add csrc and mmdeploy module

parent 502f4fb9
Pipeline #2810 canceled with stages
// Copyright (c) OpenMMLab. All rights reserved.
#include "pose_tracker/tracking_filter.h"
namespace mmdeploy::mmpose::_pose_tracker {
float get_mean_scale(float scale_w, float scale_h) { return std::sqrt(scale_w * scale_h); }
TrackingFilter::TrackingFilter(const Bbox& bbox, const vector<Point>& kpts,
float std_weight_position, float std_weight_velocity)
: std_weight_position_(std_weight_position), std_weight_velocity_(std_weight_velocity) {
auto center = get_center(bbox);
auto scale = get_scale(bbox);
auto mean_scale = get_mean_scale(scale[0], scale[1]);
const auto n = kpts.size();
pt_filters_.resize(n);
for (int i = 0; i < n; ++i) {
auto& f = pt_filters_[i];
f.init(4, 2);
SetKeyPointTransitionMat(i);
SetKeyPointMeasurementMat(i);
ResetKeyPoint(i, kpts[i], mean_scale);
}
{
// [x, y, w, h, dx, dy, dw, dh]
auto& f = bbox_filter_;
f.init(8, 4);
SetBboxTransitionMat();
SetBboxMeasurementMat();
SetBboxErrorCov(2 * std_weight_position * mean_scale, //
10 * std_weight_velocity * mean_scale);
f.statePost.at<float>(0) = center.x;
f.statePost.at<float>(1) = center.y;
f.statePost.at<float>(2) = scale[0];
f.statePost.at<float>(3) = scale[1];
}
}
std::pair<Bbox, Points> TrackingFilter::Predict() {
auto mean_scale = get_mean_scale(bbox_filter_.statePost.at<float>(2), //
bbox_filter_.statePost.at<float>(3));
const auto n = pt_filters_.size();
Points pts(n);
for (int i = 0; i < n; ++i) {
SetKeyPointProcessCov(i, std_weight_position_ * mean_scale, std_weight_velocity_ * mean_scale);
auto mat = pt_filters_[i].predict();
pts[i].x = mat.at<float>(0);
pts[i].y = mat.at<float>(1);
}
Bbox bbox;
{
SetBboxProcessCov(std_weight_position_ * mean_scale, std_weight_velocity_ * mean_scale);
auto mat = bbox_filter_.predict();
auto x = mat.ptr<float>();
bbox = get_bbox({x[0], x[1]}, {x[2], x[3]});
}
return {bbox, pts};
}
std::pair<Bbox, Points> TrackingFilter::Correct(const Bbox& bbox, const Points& kpts,
const vector<bool>& tracked) {
auto mean_scale = get_mean_scale(bbox_filter_.statePre.at<float>(2), //
bbox_filter_.statePre.at<float>(3));
const auto n = pt_filters_.size();
Points corr_kpts(n);
for (int i = 0; i < n; ++i) {
if (!tracked.empty() && tracked[i]) {
SetKeyPointMeasurementCov(i, std_weight_position_ * mean_scale);
std::array<float, 2> m{kpts[i].x, kpts[i].y};
auto mat = pt_filters_[i].correct(as_mat(m));
corr_kpts[i].x = mat.at<float>(0);
corr_kpts[i].y = mat.at<float>(1);
} else {
ResetKeyPoint(i, kpts[i], mean_scale);
corr_kpts[i] = kpts[i];
}
}
Bbox corr_bbox;
{
SetBboxMeasurementCov(std_weight_position_ * mean_scale);
auto c = get_center(bbox);
auto s = get_scale(bbox);
std::array<float, 4> m{c.x, c.y, s[0], s[1]};
auto mat = bbox_filter_.correct(as_mat(m));
auto x = mat.ptr<float>();
corr_bbox = get_bbox({x[0], x[1]}, {x[2], x[3]});
}
return {corr_bbox, corr_kpts};
}
float TrackingFilter::BboxDistance(const Bbox& bbox) {
auto mean_scale = get_mean_scale(bbox_filter_.statePre.at<float>(2), //
bbox_filter_.statePre.at<float>(3));
SetBboxMeasurementCov(std_weight_position_ * mean_scale);
auto c = get_center(bbox);
auto s = get_scale(bbox);
std::array<float, 4> m{c.x, c.y, s[0], s[1]};
cv::Mat z = as_mat(m);
auto& f = bbox_filter_;
cv::Mat sigma;
cv::gemm(f.measurementMatrix * f.errorCovPre, f.measurementMatrix, 1, f.measurementNoiseCov, 1,
sigma, cv::GEMM_2_T);
cv::Mat r = z - f.measurementMatrix * f.statePre;
// ignore contribution of scales as it is unstable when inferred from key-points
r.at<float>(2) = 0;
r.at<float>(3) = 0;
cv::Mat d = r.t() * sigma.inv() * r;
return d.at<float>();
}
vector<float> TrackingFilter::KeyPointDistance(const Points& kpts) {
auto mean_scale = get_mean_scale(bbox_filter_.statePre.at<float>(2), //
bbox_filter_.statePre.at<float>(3));
const auto n = pt_filters_.size();
vector<float> dists(n);
for (int i = 0; i < n; ++i) {
SetKeyPointMeasurementCov(i, std_weight_position_ * mean_scale);
std::array<float, 2> m{kpts[i].x, kpts[i].y};
cv::Mat z = as_mat(m);
auto& f = pt_filters_[i];
cv::Mat sigma;
cv::gemm(f.measurementMatrix * f.errorCovPre, f.measurementMatrix, 1, f.measurementNoiseCov, 1,
sigma, cv::GEMM_2_T);
cv::Mat r = z - f.measurementMatrix * f.statePre;
cv::Mat d = r.t() * sigma.inv() * r;
dists[i] = d.at<float>();
}
return dists;
}
void TrackingFilter::SetBboxProcessCov(float sigma_p, float sigma_v) {
auto& m = bbox_filter_.processNoiseCov;
cv::setIdentity(m(cv::Rect(0, 0, 4, 4)), sigma_p * sigma_p);
cv::setIdentity(m(cv::Rect(4, 4, 4, 4)), sigma_v * sigma_v);
}
void TrackingFilter::SetBboxMeasurementCov(float sigma_p) {
auto& m = bbox_filter_.measurementNoiseCov;
cv::setIdentity(m, sigma_p * sigma_p);
}
void TrackingFilter::SetBboxErrorCov(float sigma_p, float sigma_v) {
auto& m = bbox_filter_.errorCovPost;
cv::setIdentity(m(cv::Rect(0, 0, 4, 4)), sigma_p * sigma_p);
cv::setIdentity(m(cv::Rect(4, 4, 4, 4)), sigma_v * sigma_v);
}
void TrackingFilter::SetBboxTransitionMat() {
auto& m = bbox_filter_.transitionMatrix;
cv::setIdentity(m(cv::Rect(4, 0, 4, 4))); // with scale velocity
// cv::setIdentity(m(cv::Rect(4, 0, 2, 2))); // w/o scale velocity
}
void TrackingFilter::SetBboxMeasurementMat() {
auto& m = bbox_filter_.measurementMatrix;
cv::setIdentity(m(cv::Rect(0, 0, 4, 4)));
}
void TrackingFilter::SetKeyPointProcessCov(int index, float sigma_p, float sigma_v) {
auto& m = pt_filters_[index].processNoiseCov;
m.at<float>(0, 0) = sigma_p * sigma_p;
m.at<float>(1, 1) = sigma_p * sigma_p;
m.at<float>(2, 2) = sigma_v * sigma_v;
m.at<float>(3, 3) = sigma_v * sigma_v;
}
void TrackingFilter::SetKeyPointMeasurementCov(int index, float sigma_p) {
auto& m = pt_filters_[index].measurementNoiseCov;
m.at<float>(0, 0) = sigma_p * sigma_p;
m.at<float>(1, 1) = sigma_p * sigma_p;
}
void TrackingFilter::SetKeyPointErrorCov(int index, float sigma_p, float sigma_v) {
auto& m = pt_filters_[index].errorCovPost;
m.at<float>(0, 0) = sigma_p * sigma_p;
m.at<float>(1, 1) = sigma_p * sigma_p;
m.at<float>(2, 2) = sigma_v * sigma_v;
m.at<float>(3, 3) = sigma_v * sigma_v;
}
void TrackingFilter::SetKeyPointTransitionMat(int index) {
auto& m = pt_filters_[index].transitionMatrix;
cv::setIdentity(m(cv::Rect(2, 0, 2, 2)));
}
void TrackingFilter::SetKeyPointMeasurementMat(int index) {
auto& m = pt_filters_[index].measurementMatrix;
cv::setIdentity(m(cv::Rect(0, 0, 2, 2)));
}
void TrackingFilter::ResetKeyPoint(int index, const Point& kpt, float scale) {
auto& f = pt_filters_[index];
SetKeyPointErrorCov(index, 2 * std_weight_position_ * scale, 10 * std_weight_velocity_ * scale);
f.statePost.at<float>(0) = kpt.x;
f.statePost.at<float>(1) = kpt.y;
}
} // namespace mmdeploy::mmpose::_pose_tracker
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_CODEBASE_MMPOSE_POSE_TRACKER_TRACKING_FILTER_H
#define MMDEPLOY_CODEBASE_MMPOSE_POSE_TRACKER_TRACKING_FILTER_H
#include "opencv2/video/video.hpp"
#include "pose_tracker/utils.h"
namespace mmdeploy::mmpose::_pose_tracker {
// use Kalman filter to estimate and predict true states
class TrackingFilter {
public:
TrackingFilter(const Bbox& bbox, const vector<Point>& kpts, float std_weight_position,
float std_weight_velocity);
std::pair<Bbox, Points> Predict();
vector<float> KeyPointDistance(const Points& kpts);
float BboxDistance(const Bbox& bbox);
std::pair<Bbox, Points> Correct(const Bbox& bbox, const Points& kpts,
const vector<bool>& tracked);
private:
void SetBboxProcessCov(float sigma_p, float sigma_v);
void SetBboxMeasurementCov(float sigma_p);
void SetBboxErrorCov(float sigma_p, float sigma_v);
void SetBboxTransitionMat();
void SetBboxMeasurementMat();
void SetKeyPointProcessCov(int index, float sigma_p, float sigma_v);
void SetKeyPointMeasurementCov(int index, float sigma_p);
void SetKeyPointErrorCov(int index, float sigma_p, float sigma_v);
void SetKeyPointTransitionMat(int index);
void SetKeyPointMeasurementMat(int index);
void ResetKeyPoint(int index, const Point& kpt, float scale);
private:
std::vector<cv::KalmanFilter> pt_filters_;
cv::KalmanFilter bbox_filter_;
float std_weight_position_;
float std_weight_velocity_;
};
} // namespace mmdeploy::mmpose::_pose_tracker
#endif // MMDEPLOY_TRACKING_FILTER_H
// Copyright (c) OpenMMLab. All rights reserved.
#include "pose_tracker/utils.h"
namespace mmdeploy::mmpose::_pose_tracker {
vector<std::tuple<int, int, float>> greedy_assignment(const vector<float>& scores,
vector<int>& is_valid_row,
vector<int>& is_valid_col, float thr) {
const auto n_rows = is_valid_row.size();
const auto n_cols = is_valid_col.size();
vector<std::tuple<int, int, float>> assignment;
assignment.reserve(std::max(n_rows, n_cols));
while (true) {
auto max_score = std::numeric_limits<float>::lowest();
int max_row = -1;
int max_col = -1;
for (int i = 0; i < n_rows; ++i) {
if (is_valid_row[i]) {
for (int j = 0; j < n_cols; ++j) {
if (is_valid_col[j]) {
if (scores[i * n_cols + j] > max_score) {
max_score = scores[i * n_cols + j];
max_row = i;
max_col = j;
}
}
}
}
}
if (max_score < thr) {
break;
}
is_valid_row[max_row] = 0;
is_valid_col[max_col] = 0;
assignment.emplace_back(max_row, max_col, max_score);
}
return assignment;
}
float intersection_over_union(const Bbox& a, const Bbox& b) {
auto x1 = std::max(a[0], b[0]);
auto y1 = std::max(a[1], b[1]);
auto x2 = std::min(a[2], b[2]);
auto y2 = std::min(a[3], b[3]);
auto inter_area = std::max(0.f, x2 - x1) * std::max(0.f, y2 - y1);
auto a_area = get_area(a);
auto b_area = get_area(b);
auto union_area = a_area + b_area - inter_area;
if (union_area == 0.f) {
return 0;
}
return inter_area / union_area;
}
float object_keypoint_similarity(const Points& pts_a, const Bbox& box_a, const Points& pts_b,
const Bbox& box_b, const vector<float>& sigmas) {
assert(pts_a.size() == sigmas.size());
assert(pts_b.size() == sigmas.size());
auto scale = [](const Bbox& bbox) -> float {
auto a = bbox[2] - bbox[0];
auto b = bbox[3] - bbox[1];
return std::sqrt(a * a + b * b);
};
auto oks = [](const Point& pa, const Point& pb, float s, float k) {
return std::exp(-(pa - pb).dot(pa - pb) / (2.f * s * s * k * k));
};
auto sum = 0.f;
const auto s = .5f * (scale(box_a) + scale(box_b));
for (int i = 0; i < sigmas.size(); ++i) {
sum += oks(pts_a[i], pts_b[i], s, sigmas[i]);
}
sum /= static_cast<float>(sigmas.size());
return sum;
}
std::optional<Bbox> keypoints_to_bbox(const Points& keypoints, const Scores& scores, float img_h,
float img_w, float scale, float kpt_thr, int min_keypoints) {
int valid = 0;
auto x1 = static_cast<float>(img_w);
auto y1 = static_cast<float>(img_h);
auto x2 = 0.f;
auto y2 = 0.f;
for (size_t i = 0; i < keypoints.size(); ++i) {
auto& kpt = keypoints[i];
if (scores[i] >= kpt_thr) {
x1 = std::min(x1, kpt.x);
y1 = std::min(y1, kpt.y);
x2 = std::max(x2, kpt.x);
y2 = std::max(y2, kpt.y);
++valid;
}
}
if (min_keypoints < 0) {
min_keypoints = (static_cast<int>(scores.size()) + 1) / 2;
}
if (valid < min_keypoints) {
return std::nullopt;
}
auto xc = .5f * (x1 + x2);
auto yc = .5f * (y1 + y2);
auto w = (x2 - x1) * scale;
auto h = (y2 - y1) * scale;
return std::array<float, 4>{
std::max(0.f, std::min(img_w, xc - .5f * w)),
std::max(0.f, std::min(img_h, yc - .5f * h)),
std::max(0.f, std::min(img_w, xc + .5f * w)),
std::max(0.f, std::min(img_h, yc + .5f * h)),
};
}
Bbox map_bbox(const Bbox& box) {
Point p0(box[0], box[1]);
Point p1(box[2], box[3]);
auto c = .5f * (p0 + p1);
auto s = p1 - p0;
static constexpr std::array image_size{192.f, 256.f};
float aspect_ratio = image_size[0] * 1.0 / image_size[1];
if (s.x > aspect_ratio * s.y) {
s.y = s.x / aspect_ratio;
} else if (s.x < aspect_ratio * s.y) {
s.x = s.y * aspect_ratio;
}
s.x *= 1.25f;
s.y *= 1.25f;
p0 = c - .5f * s;
p1 = c + .5f * s;
return {p0.x, p0.y, p1.x, p1.y};
}
} // namespace mmdeploy::mmpose::_pose_tracker
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_CODEBASE_MMPOSE_POSE_TRACKER_UTILS_H
#define MMDEPLOY_CODEBASE_MMPOSE_POSE_TRACKER_UTILS_H
#include <array>
#include <numeric>
#include <optional>
#include <vector>
#include "mmdeploy/core/utils/formatter.h"
#include "opencv2/core/core.hpp"
#include "pose_tracker/common.h"
namespace mmdeploy::mmpose::_pose_tracker {
using std::vector;
using Bbox = std::array<float, 4>;
using Bboxes = vector<Bbox>;
using Point = cv::Point2f;
using Points = vector<cv::Point2f>;
using Score = float;
using Scores = vector<float>;
#define POSE_TRACKER_DEBUG(...) MMDEPLOY_DEBUG(__VA_ARGS__)
// opencv3 can't construct cv::Mat from std::array
template <size_t N>
cv::Mat as_mat(const std::array<float, N>& a) {
return cv::Mat_<float>(a.size(), 1, const_cast<float*>(a.data()));
}
// scale = 1.5, kpt_thr = 0.3
std::optional<Bbox> keypoints_to_bbox(const Points& keypoints, const Scores& scores, float img_h,
float img_w, float scale, float kpt_thr, int min_keypoints);
// xyxy format
float intersection_over_union(const Bbox& a, const Bbox& b);
float object_keypoint_similarity(const Points& pts_a, const Bbox& box_a, const Points& pts_b,
const Bbox& box_b, const vector<float>& sigmas);
template <typename T>
void suppress_non_maximum(const vector<T>& scores, const vector<float>& similarities,
vector<int>& is_valid, float thresh);
inline float get_area(const Bbox& bbox) { return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]); }
inline Point get_center(const Bbox& bbox) {
return {.5f * (bbox[0] + bbox[2]), .5f * (bbox[1] + bbox[3])};
}
inline std::array<float, 2> get_scale(const Bbox& bbox) {
return {bbox[2] - bbox[0], bbox[3] - bbox[1]};
}
inline Bbox get_bbox(const Point& center, const std::array<float, 2>& scale) {
return {
center.x - .5f * scale[0],
center.y - .5f * scale[1],
center.x + .5f * scale[0],
center.y + .5f * scale[1],
};
}
vector<std::tuple<int, int, float>> greedy_assignment(const vector<float>& scores,
vector<int>& is_valid_row,
vector<int>& is_valid_col, float thr);
template <typename T>
inline void suppress_non_maximum(const vector<T>& scores, const vector<float>& similarities,
vector<int>& is_valid, float thresh) {
assert(is_valid.size() == scores.size());
vector<int> indices(scores.size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](int i, int j) { return scores[i] > scores[j]; });
// suppress similar samples
for (int i = 0; i < indices.size(); ++i) {
if (auto u = indices[i]; is_valid[u]) {
for (int j = i + 1; j < indices.size(); ++j) {
if (auto v = indices[j]; is_valid[v]) {
if (similarities[u * scores.size() + v] >= thresh) {
is_valid[v] = false;
}
}
}
}
}
}
// TopDownAffine's internal logic for mapping pose model inputs
Bbox map_bbox(const Bbox& box);
} // namespace mmdeploy::mmpose::_pose_tracker
#endif // MMDEPLOY_UTILS_H
// Copyright (c) OpenMMLab. All rights reserved.
#include <cctype>
#include <opencv2/imgproc.hpp>
#include "mmdeploy/core/device.h"
#include "mmdeploy/core/registry.h"
#include "mmdeploy/core/serialization.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/core/value.h"
#include "mmdeploy/experimental/module_adapter.h"
#include "mmpose.h"
#include "opencv_utils.h"
namespace mmdeploy::mmpose {
using std::string;
using std::vector;
class SimCCLabelDecode : public MMPose {
public:
explicit SimCCLabelDecode(const Value& config) : MMPose(config) {
if (config.contains("params")) {
auto& params = config["params"];
flip_test_ = params.value("flip_test", flip_test_);
simcc_split_ratio_ = params.value("simcc_split_ratio", simcc_split_ratio_);
export_postprocess_ = params.value("export_postprocess", export_postprocess_);
if (export_postprocess_) {
simcc_split_ratio_ = 1.0;
}
if (params.contains("input_size")) {
from_value(params["input_size"], input_size_);
}
}
}
Result<Value> operator()(const Value& _data, const Value& _prob) {
MMDEPLOY_DEBUG("preprocess_result: {}", _data);
MMDEPLOY_DEBUG("inference_result: {}", _prob);
Device cpu_device{"cpu"};
OUTCOME_TRY(auto simcc_x,
MakeAvailableOnDevice(_prob["simcc_x"].get<Tensor>(), cpu_device, stream()));
OUTCOME_TRY(auto simcc_y,
MakeAvailableOnDevice(_prob["simcc_y"].get<Tensor>(), cpu_device, stream()));
OUTCOME_TRY(stream().Wait());
if (!(simcc_x.shape().size() == 3 && simcc_x.data_type() == DataType::kFLOAT)) {
MMDEPLOY_ERROR("unsupported `simcc_x` tensor, shape: {}, dtype: {}", simcc_x.shape(),
(int)simcc_x.data_type());
return Status(eNotSupported);
}
auto& img_metas = _data["img_metas"];
Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}});
Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}});
float *keypoints_data = nullptr, *scores_data = nullptr;
if (!export_postprocess_) {
get_simcc_maximum(simcc_x, simcc_y, keypoints, scores);
keypoints_data = keypoints.data<float>();
scores_data = scores.data<float>();
} else {
keypoints_data = simcc_x.data<float>();
scores_data = simcc_y.data<float>();
}
std::vector<float> center;
std::vector<float> scale;
from_value(img_metas["center"], center);
from_value(img_metas["scale"], scale);
PoseDetectorOutput output;
float scale_value = 200, x = -1, y = -1, s = 0;
for (int i = 0; i < simcc_x.shape(1); i++) {
x = *(keypoints_data++) / simcc_split_ratio_;
y = *(keypoints_data++) / simcc_split_ratio_;
s = *(scores_data++);
x = x * scale[0] * scale_value / input_size_[0] + center[0] - scale[0] * scale_value * 0.5;
y = y * scale[1] * scale_value / input_size_[1] + center[1] - scale[1] * scale_value * 0.5;
output.key_points.push_back({{x, y}, s});
}
return to_value(output);
}
void get_simcc_maximum(const Tensor& simcc_x, const Tensor& simcc_y, Tensor& keypoints,
Tensor& scores) {
int K = simcc_x.shape(1);
int N_x = simcc_x.shape(2);
int N_y = simcc_y.shape(2);
for (int i = 0; i < K; i++) {
float* data_x = const_cast<float*>(simcc_x.data<float>()) + i * N_x;
float* data_y = const_cast<float*>(simcc_y.data<float>()) + i * N_y;
cv::Mat mat_x = cv::Mat(N_x, 1, CV_32FC1, data_x);
cv::Mat mat_y = cv::Mat(N_y, 1, CV_32FC1, data_y);
double min_val_x, max_val_x, min_val_y, max_val_y;
cv::Point min_loc_x, max_loc_x, min_loc_y, max_loc_y;
cv::minMaxLoc(mat_x, &min_val_x, &max_val_x, &min_loc_x, &max_loc_x);
cv::minMaxLoc(mat_y, &min_val_y, &max_val_y, &min_loc_y, &max_loc_y);
float s = max_val_x > max_val_y ? max_val_y : max_val_x;
float x = s > 0 ? max_loc_x.y : -1.0;
float y = s > 0 ? max_loc_y.y : -1.0;
float* keypoints_data = keypoints.data<float>() + i * 2;
float* scores_data = scores.data<float>() + i;
*(scores_data) = s;
*(keypoints_data + 0) = x;
*(keypoints_data + 1) = y;
}
}
private:
bool flip_test_{false};
bool export_postprocess_{false};
bool shift_heatmap_{false};
float simcc_split_ratio_{2.0};
std::vector<int> input_size_{192, 256};
};
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMPose, SimCCLabelDecode);
} // namespace mmdeploy::mmpose
// Copyright (c) OpenMMLab. All rights reserved.
#include <array>
#include <set>
#include "mmdeploy/archive/value_archive.h"
#include "mmdeploy/core/registry.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/operation/managed.h"
#include "mmdeploy/operation/vision.h"
#include "mmdeploy/preprocess/transform/transform.h"
using namespace std;
namespace mmdeploy::mmpose {
class TopDownAffine : public transform::Transform {
public:
explicit TopDownAffine(const Value& args) noexcept {
assert(args.contains("image_size"));
from_value(args["image_size"], image_size_);
crop_resize_pad_ =
::mmdeploy::operation::Managed<::mmdeploy::operation::CropResizePad>::Create();
}
~TopDownAffine() override = default;
Result<void> Apply(Value& data) override {
MMDEPLOY_DEBUG("top_down_affine input: {}", data);
auto img = data["img"].get<Tensor>();
// prepare data
vector<float> bbox;
vector<float> c; // center
vector<float> s; // scale
if (data.contains("center") && data.contains("scale")) {
// after mmpose v0.26.0
from_value(data["center"], c);
from_value(data["scale"], s);
from_value(data["bbox"], bbox);
} else {
// before mmpose v0.26.0
from_value(data["bbox"], bbox);
Box2cs(bbox, c, s);
}
// end prepare data
Tensor dst;
{
s[0] *= 200;
s[1] *= 200;
const std::array img_roi{0, 0, (int)img.shape(2), (int)img.shape(1)};
const std::array tmp_roi{0, 0, (int)image_size_[0], (int)image_size_[1]};
auto roi = round({c[0] - s[0] / 2.f, c[1] - s[1] / 2.f, s[0], s[1]});
auto src_roi = intersect(roi, img_roi);
// prior scale factor
auto factor = (float)image_size_[0] / s[0];
// rounded dst roi
auto dst_roi = round({(src_roi[0] - roi[0]) * factor, //
(src_roi[1] - roi[1]) * factor, //
src_roi[2] * factor, //
src_roi[3] * factor});
dst_roi = intersect(dst_roi, tmp_roi);
// exact scale factors
auto factor_x = (float)dst_roi[2] / src_roi[2];
auto factor_y = (float)dst_roi[3] / src_roi[3];
// center of src roi
auto c_src_x = src_roi[0] + (src_roi[2] - 1) / 2.f;
auto c_src_y = src_roi[1] + (src_roi[3] - 1) / 2.f;
// center of dst roi
auto c_dst_x = dst_roi[0] + (dst_roi[2] - 1) / 2.f;
auto c_dst_y = dst_roi[1] + (dst_roi[3] - 1) / 2.f;
// vector from c_dst to (w/2, h/2)
auto v_dst_x = image_size_[0] / 2.f - c_dst_x;
auto v_dst_y = image_size_[1] / 2.f - c_dst_y;
// vector from c_src to corrected center
auto v_src_x = v_dst_x / factor_x;
auto v_src_y = v_dst_y / factor_y;
// corrected center
c[0] = c_src_x + v_src_x;
c[1] = c_src_y + v_src_y;
// corrected scale
s[0] = image_size_[0] / factor_x / 200.f;
s[1] = image_size_[1] / factor_y / 200.f;
vector<int> crop_rect = {src_roi[1], src_roi[0], src_roi[1] + src_roi[3] - 1,
src_roi[0] + src_roi[2] - 1};
vector<int> target_size = {dst_roi[2], dst_roi[3]};
vector<int> pad_rect = {dst_roi[1], dst_roi[0], image_size_[1] - dst_roi[3] - dst_roi[1],
image_size_[0] - dst_roi[2] - dst_roi[0]};
crop_resize_pad_.Apply(img, crop_rect, target_size, pad_rect, dst);
}
data["img"] = std::move(dst);
data["img_shape"] = {1, image_size_[1], image_size_[0], img.shape(3)};
data["center"] = to_value(c);
data["scale"] = to_value(s);
MMDEPLOY_DEBUG("output: {}", data);
return success();
}
static std::array<int, 4> round(const std::array<float, 4>& a) {
return {
static_cast<int>(std::round(a[0])),
static_cast<int>(std::round(a[1])),
static_cast<int>(std::round(a[2])),
static_cast<int>(std::round(a[3])),
};
}
// xywh
template <typename T>
static std::array<T, 4> intersect(std::array<T, 4> a, std::array<T, 4> b) {
auto x1 = std::max(a[0], b[0]);
auto y1 = std::max(a[1], b[1]);
a[2] = std::min(a[0] + a[2], b[0] + b[2]) - x1;
a[3] = std::min(a[1] + a[3], b[1] + b[3]) - y1;
a[0] = x1;
a[1] = y1;
if (a[2] <= 0 || a[3] <= 0) {
a = {};
}
return a;
}
void Box2cs(vector<float>& box, vector<float>& center, vector<float>& scale) {
// bbox_xywh2cs
float x = box[0];
float y = box[1];
float w = box[2];
float h = box[3];
float aspect_ratio = image_size_[0] * 1.0 / image_size_[1];
center.push_back(x + w * 0.5);
center.push_back(y + h * 0.5);
if (w > aspect_ratio * h) {
h = w * 1.0 / aspect_ratio;
} else if (w < aspect_ratio * h) {
w = h * aspect_ratio;
}
scale.push_back(w / 200 * 1.25);
scale.push_back(h / 200 * 1.25);
}
protected:
vector<int> image_size_;
::mmdeploy::operation::Managed<::mmdeploy::operation::CropResizePad> crop_resize_pad_;
};
MMDEPLOY_REGISTER_TRANSFORM(TopDownAffine);
} // namespace mmdeploy::mmpose
// Copyright (c) OpenMMLab. All rights reserved.
#include <vector>
#include "mmdeploy/archive/value_archive.h"
#include "mmdeploy/core/registry.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/preprocess/transform/transform.h"
using namespace std;
namespace mmdeploy::mmpose {
class TopDownGetBboxCenterScale : public transform::Transform {
public:
explicit TopDownGetBboxCenterScale(const Value& args) {
padding_ = args.value("padding", 1.25);
assert(args.contains("image_size"));
from_value(args["image_size"], image_size_);
}
~TopDownGetBboxCenterScale() override = default;
Result<void> Apply(Value& data) override {
vector<float> bbox;
from_value(data["bbox"], bbox);
vector<float> c; // center
vector<float> s; // scale
Box2cs(bbox, c, s, padding_, pixel_std_);
data["center"] = to_value(c);
data["scale"] = to_value(s);
return success();
}
void Box2cs(vector<float>& box, vector<float>& center, vector<float>& scale, float padding,
float pixel_std) {
// bbox_xywh2cs
float x = box[0];
float y = box[1];
float w = box[2];
float h = box[3];
float aspect_ratio = image_size_[0] * 1.0 / image_size_[1];
center.push_back(x + w * 0.5);
center.push_back(y + h * 0.5);
if (w > aspect_ratio * h) {
h = w * 1.0 / aspect_ratio;
} else if (w < aspect_ratio * h) {
w = h * aspect_ratio;
}
scale.push_back(w / pixel_std * padding);
scale.push_back(h / pixel_std * padding);
}
protected:
float padding_{1.25f};
float pixel_std_{200.f};
vector<int> image_size_;
};
MMDEPLOY_REGISTER_TRANSFORM(TopDownGetBboxCenterScale);
} // namespace mmdeploy::mmpose
# Copyright (c) OpenMMLab. All rights reserved.
project(mmdeploy_mmrotate)
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils)
add_library(mmdeploy::mmrotate ALIAS ${PROJECT_NAME})
set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} rotated_detector CACHE INTERNAL "")
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/codebase/mmrotate/mmrotate.h"
namespace mmdeploy::mmrotate {
MMDEPLOY_REGISTER_CODEBASE(MMRotate);
} // namespace mmdeploy::mmrotate
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_MMROTATE_H
#define MMDEPLOY_MMROTATE_H
#include <array>
#include "mmdeploy/codebase/common.h"
#include "mmdeploy/core/device.h"
#include "mmdeploy/core/module.h"
namespace mmdeploy::mmrotate {
struct RotatedDetectorOutput {
struct Detection {
int label_id;
float score;
std::array<float, 5> rbbox; // cx,cy,w,h,ag
MMDEPLOY_ARCHIVE_MEMBERS(label_id, score, rbbox);
};
std::vector<Detection> detections;
MMDEPLOY_ARCHIVE_MEMBERS(detections);
};
MMDEPLOY_DECLARE_CODEBASE(MMRotate, mmrotate);
} // namespace mmdeploy::mmrotate
#endif // MMDEPLOY_MMROTATE_H
// Copyright (c) OpenMMLab. All rights reserved.
#include <opencv2/imgproc.hpp>
#include "mmdeploy/core/device.h"
#include "mmdeploy/core/registry.h"
#include "mmdeploy/core/serialization.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/core/value.h"
#include "mmrotate.h"
#include "opencv_utils.h"
namespace mmdeploy::mmrotate {
using std::vector;
class ResizeRBBox : public MMRotate {
public:
explicit ResizeRBBox(const Value& cfg) : MMRotate(cfg) {
if (cfg.contains("params")) {
score_thr_ = cfg["params"].value("score_thr", 0.05f);
}
}
Result<Value> operator()(const Value& prep_res, const Value& infer_res) {
MMDEPLOY_DEBUG("prep_res: {}", prep_res);
MMDEPLOY_DEBUG("infer_res: {}", infer_res);
Device cpu_device{"cpu"};
OUTCOME_TRY(auto dets,
MakeAvailableOnDevice(infer_res["dets"].get<Tensor>(), cpu_device, stream_));
OUTCOME_TRY(auto labels,
MakeAvailableOnDevice(infer_res["labels"].get<Tensor>(), cpu_device, stream_));
OUTCOME_TRY(stream_.Wait());
if (!(dets.shape().size() == 3 && dets.shape(2) == 6 && dets.data_type() == DataType::kFLOAT)) {
MMDEPLOY_ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(),
(int)dets.data_type());
return Status(eNotSupported);
}
if (labels.shape().size() != 2) {
MMDEPLOY_ERROR("unsupported `labels`, tensor, shape: {}, dtype: {}", labels.shape(),
(int)labels.data_type());
return Status(eNotSupported);
}
OUTCOME_TRY(auto result, DispatchGetBBoxes(prep_res["img_metas"], dets, labels));
return to_value(result);
}
Result<RotatedDetectorOutput> DispatchGetBBoxes(const Value& prep_res, const Tensor& dets,
const Tensor& labels) {
auto data_type = labels.data_type();
switch (data_type) {
case DataType::kFLOAT:
return GetRBBoxes<float>(prep_res, dets, labels);
case DataType::kINT32:
return GetRBBoxes<int32_t>(prep_res, dets, labels);
case DataType::kINT64:
return GetRBBoxes<int64_t>(prep_res, dets, labels);
default:
return Status(eNotSupported);
}
}
template <typename T>
Result<RotatedDetectorOutput> GetRBBoxes(const Value& prep_res, const Tensor& dets,
const Tensor& labels) {
RotatedDetectorOutput objs;
auto* dets_ptr = dets.data<float>();
auto* labels_ptr = labels.data<T>();
vector<float> scale_factor;
if (prep_res.contains("scale_factor")) {
from_value(prep_res["scale_factor"], scale_factor);
} else {
scale_factor = {1.f, 1.f, 1.f, 1.f};
}
int ori_width = prep_res["ori_shape"][2].get<int>();
int ori_height = prep_res["ori_shape"][1].get<int>();
auto bboxes_number = dets.shape()[1];
auto channels = dets.shape()[2];
for (auto i = 0; i < bboxes_number; ++i, dets_ptr += channels, ++labels_ptr) {
float score = dets_ptr[channels - 1];
if (score <= score_thr_) {
continue;
}
auto cx = dets_ptr[0] / scale_factor[0];
auto cy = dets_ptr[1] / scale_factor[1];
auto width = dets_ptr[2] / scale_factor[0];
auto height = dets_ptr[3] / scale_factor[1];
auto angle = dets_ptr[4];
RotatedDetectorOutput::Detection det{};
det.label_id = static_cast<int>(*labels_ptr);
det.score = score;
det.rbbox = {cx, cy, width, height, angle};
objs.detections.push_back(std::move(det));
}
return objs;
}
private:
float score_thr_;
};
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMRotate, ResizeRBBox);
} // namespace mmdeploy::mmrotate
# Copyright (c) OpenMMLab. All rights reserved.
project(mmdeploy_mmseg)
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
target_link_libraries(${PROJECT_NAME} PRIVATE
mmdeploy_opencv_utils
mmdeploy_operation)
add_library(mmdeploy::mmseg ALIAS ${PROJECT_NAME})
set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} segmentor CACHE INTERNAL "")
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/codebase/mmseg/mmseg.h"
namespace mmdeploy::mmseg {
MMDEPLOY_REGISTER_CODEBASE(MMSegmentation);
} // namespace mmdeploy::mmseg
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_MMSEG_H
#define MMDEPLOY_MMSEG_H
#include "mmdeploy/codebase/common.h"
#include "mmdeploy/core/device.h"
#include "mmdeploy/core/module.h"
#include "mmdeploy/core/tensor.h"
namespace mmdeploy::mmseg {
struct SegmentorOutput {
Tensor mask;
Tensor score;
int height;
int width;
int classes;
MMDEPLOY_ARCHIVE_MEMBERS(mask, score, height, width, classes);
};
MMDEPLOY_DECLARE_CODEBASE(MMSegmentation, mmseg);
} // namespace mmdeploy::mmseg
#endif // MMDEPLOY_MMSEG_H
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/codebase/mmseg/mmseg.h"
#include "mmdeploy/core/logger.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/operation/managed.h"
#include "mmdeploy/operation/vision.h"
#include "mmdeploy/preprocess/transform/transform.h"
#include "opencv_utils.h"
namespace mmdeploy::mmseg {
// TODO: resize masks on device
// TODO: when network output is on device, cast it to a smaller type (e.g. int16_t or int8_t
// according to num classes) to reduce DtoH footprint
class ResizeMask : public MMSegmentation {
public:
explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) {
try {
classes_ = cfg["params"]["num_classes"].get<int>();
with_argmax_ = cfg["params"].value("with_argmax", true);
little_endian_ = IsLittleEndian();
::mmdeploy::operation::Context ctx(Device("cpu"), stream_);
permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create();
} catch (const std::exception &e) {
MMDEPLOY_ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg);
throw_exception(eInvalidArgument);
}
}
Result<Value> operator()(const Value &preprocess_result, const Value &inference_result) {
MMDEPLOY_DEBUG("preprocess: {}\ninference: {}", preprocess_result, inference_result);
auto mask = inference_result["output"].get<Tensor>();
MMDEPLOY_DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(),
mask.shape(), mask.data_type());
if (!(mask.shape().size() == 4 && mask.shape(0) == 1)) {
MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}", mask.shape());
return Status(eNotSupported);
}
if ((mask.shape(1) != 1) && with_argmax_) {
MMDEPLOY_ERROR("probability feat map with shape: {} requires `with_argmax_=false`",
mask.shape());
return Status(eNotSupported);
}
if ((mask.data_type() != DataType::kFLOAT) && !with_argmax_) {
MMDEPLOY_ERROR("probability feat map only support float32 output");
return Status(eNotSupported);
}
auto channel = (int)mask.shape(1);
auto height = (int)mask.shape(2);
auto width = (int)mask.shape(3);
auto input_height = preprocess_result["img_metas"]["ori_shape"][1].get<int>();
auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get<int>();
Device host{"cpu"};
OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_));
OUTCOME_TRY(stream().Wait()); // should sync even mask is on cpu
if (!with_argmax_) {
// (C, H, W) -> (H, W, C)
::mmdeploy::operation::Context ctx(host, stream_);
std::vector<int> axes = {0, 2, 3, 1};
OUTCOME_TRY(permute_.Apply(host_tensor, host_tensor, axes));
}
OUTCOME_TRY(auto cv_type, GetCvType(mask.data_type(), channel));
cv::Mat mask_mat(height, width, cv_type, host_tensor.data());
cv::Mat resized_mask;
cv::Mat resized_score;
Tensor tensor_mask{};
Tensor tensor_score{};
if (with_argmax_) {
// mask
if (mask_mat.channels() > 1) {
cv::extractChannel(mask_mat, mask_mat, little_endian_ ? 0 : mask_mat.channels() - 1);
}
if (mask_mat.type() != CV_32S) {
mask_mat.convertTo(mask_mat, CV_32S);
}
resized_mask = cpu::Resize(mask_mat, input_height, input_width, "nearest");
tensor_mask = cpu::CVMat2Tensor(resized_mask);
} else {
// score
resized_score = cpu::Resize(mask_mat, input_height, input_width, "bilinear");
tensor_score = cpu::CVMat2Tensor(resized_score);
std::vector<int> axes = {0, 3, 1, 2};
::mmdeploy::operation::Context ctx(host, stream_);
OUTCOME_TRY(permute_.Apply(tensor_score, tensor_score, axes));
}
SegmentorOutput output{tensor_mask, tensor_score, input_height, input_width, classes_};
return to_value(output);
}
private:
static Result<int> GetCvType(DataType type, int channel) {
switch (type) {
case DataType::kFLOAT:
return CV_32FC(channel);
case DataType::kINT64:
return CV_32SC2;
case DataType::kINT32:
return CV_32S;
default:
return Status(eNotSupported);
}
}
static bool IsLittleEndian() {
union Un {
char a;
int b;
} un;
un.b = 1;
return (int)un.a == 1;
}
protected:
::mmdeploy::operation::Managed<::mmdeploy::operation::Permute> permute_;
int classes_{};
bool with_argmax_{true};
bool little_endian_;
};
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMSegmentation, ResizeMask);
} // namespace mmdeploy::mmseg
# Copyright (c) OpenMMLab. All rights reserved.
project(mmdeploy_core)
# this is used to keep compatibility with legacy spdlog where CMake package is not available
set(SPDLOG_LIB)
if (MMDEPLOY_SPDLOG_EXTERNAL)
find_package(spdlog QUIET)
if (spdlog_FOUND)
set(SPDLOG_LIB spdlog::spdlog)
endif ()
else ()
set(MMDEPLOY_SPDLOG_DIR ${CMAKE_SOURCE_DIR}/third_party/spdlog)
add_subdirectory(${MMDEPLOY_SPDLOG_DIR} ${CMAKE_CURRENT_BINARY_DIR}/spdlog EXCLUDE_FROM_ALL)
set_target_properties(spdlog PROPERTIES POSITION_INDEPENDENT_CODE ON)
if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC))
target_compile_options(spdlog PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
endif ()
set(SPDLOG_LIB spdlog::spdlog)
mmdeploy_export(spdlog)
install(DIRECTORY ${MMDEPLOY_SPDLOG_DIR}/include/spdlog
DESTINATION include/mmdeploy/third_party)
endif ()
set(SRCS
device_impl.cpp
logger.cpp
mat.cpp
model.cpp
module.cpp
net.cpp
operator.cpp
status_code.cpp
tensor.cpp
registry.cpp
graph.cpp
utils/device_utils.cpp
utils/formatter.cpp
utils/stacktrace.cpp
profiler.cpp
)
mmdeploy_add_library(${PROJECT_NAME} ${SRCS})
target_include_directories(${PROJECT_NAME}
PUBLIC
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/csrc>
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/third_party/outcome>
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/third_party/concurrentqueue>
# TODO: remove dependency of `json`
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/third_party/json>
)
if (MSVC)
target_compile_options(${PROJECT_NAME} PUBLIC
$<$<COMPILE_LANGUAGE:CXX>:/Zc:preprocessor;/Zc:__cplusplus>)
endif ()
if (MMDEPLOY_STATUS_USE_STACKTRACE)
include(${CMAKE_SOURCE_DIR}/cmake/stacktrace.cmake)
else ()
target_compile_definitions(${PROJECT_NAME} PUBLIC -DMMDEPLOY_STATUS_USE_SOURCE_LOCATION=1)
endif ()
target_include_directories(${PROJECT_NAME} PUBLIC
$<INSTALL_INTERFACE:include>
$<INSTALL_INTERFACE:include/mmdeploy/third_party/outcome>
$<INSTALL_INTERFACE:include/mmdeploy/third_party/json>)
if (NOT MMDEPLOY_SPDLOG_EXTERNAL)
target_include_directories(spdlog INTERFACE
$<INSTALL_INTERFACE:include/mmdeploy/third_party>)
endif ()
target_link_libraries(${PROJECT_NAME} PUBLIC ${SPDLOG_LIB})
include(${CMAKE_SOURCE_DIR}/cmake/filesystem.cmake)
if (STD_FS_LIB)
target_link_libraries(${PROJECT_NAME} PUBLIC ${STD_FS_LIB})
endif ()
add_library(mmdeploy::core ALIAS ${PROJECT_NAME})
install(DIRECTORY ${CMAKE_SOURCE_DIR}/csrc/mmdeploy/core
DESTINATION include/mmdeploy
FILES_MATCHING PATTERN "*.h")
install(FILES ${CMAKE_SOURCE_DIR}/third_party/outcome/outcome-experimental.hpp
DESTINATION include/mmdeploy/third_party/outcome)
install(DIRECTORY ${CMAKE_SOURCE_DIR}/csrc/mmdeploy/experimental
DESTINATION include/mmdeploy
FILES_MATCHING PATTERN "*.h")
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SRC_CORE_ARCHIVE_H_
#define MMDEPLOY_SRC_CORE_ARCHIVE_H_
#include "mmdeploy/core/logger.h"
#include "mmdeploy/core/serialization.h"
namespace mmdeploy {
template <typename T, typename A>
using member_load_t = decltype(std::declval<T&>().load(std::declval<A&>()));
template <typename T, typename A>
using member_save_t = decltype(std::declval<T&>().save(std::declval<A&>()));
template <typename T, typename A>
using member_serialize_t = decltype(std::declval<T&>().serialize(std::declval<A&>()));
template <typename T, typename A>
using has_member_load = detail::is_detected<member_load_t, T, A>;
template <typename T, typename A>
using has_member_save = detail::is_detected<member_save_t, T, A>;
template <typename T, typename A>
using has_member_serialize = detail::is_detected<member_serialize_t, T, A>;
template <typename T, typename A>
using adl_load_t = decltype(adl_serializer<T>::load(std::declval<A&>(), std::declval<T&>()));
template <typename T, typename A>
using has_adl_load = detail::is_detected<adl_load_t, T, A>;
template <typename T, typename A>
using adl_save_t = decltype(adl_serializer<T>::save(std::declval<A&>(), std::declval<T&>()));
template <typename T, typename A>
using has_adl_save = detail::is_detected<adl_save_t, T, A>;
template <typename T, typename A>
using adl_serialize_t =
decltype(adl_serializer<T>::serialize(std::declval<A&>(), std::declval<T&>()));
template <typename T, typename A>
using has_adl_serialize = detail::is_detected<adl_serialize_t, T, A>;
namespace detail {
// ADL bridge for archives
class ArchiveBase {};
} // namespace detail
template <typename Archive>
class OutputArchive : public detail::ArchiveBase {
public:
template <typename... Args>
void operator()(Args&&... args) {
(dispatch(std::forward<Args>(args)), ...);
}
private:
template <typename T>
void dispatch(T&& v) {
auto& archive = static_cast<Archive&>(*this);
if constexpr (has_member_save<T, Archive>::value) {
std::forward<T>(v).save(archive);
} else if constexpr (has_member_serialize<T, Archive>::value) {
std::forward<T>(v).serialize(archive);
} else if constexpr (has_adl_save<T, Archive>::value) {
adl_serializer<T>::save(archive, std::forward<T>(v));
} else if constexpr (has_adl_serialize<T, Archive>::value) {
adl_serializer<T>::serialize(archive, std::forward<T>(v));
} else {
archive.native(std::forward<T>(v));
}
}
};
template <typename Archive>
class InputArchive : public detail::ArchiveBase {
public:
template <typename... Args>
void operator()(Args&&... args) {
(dispatch(std::forward<Args>(args)), ...);
}
private:
template <typename T>
void dispatch(T&& v) {
auto& archive = static_cast<Archive&>(*this);
if constexpr (has_member_load<T, Archive>::value) {
std::forward<T>(v).load(archive);
} else if constexpr (has_member_serialize<T, Archive>::value) {
std::forward<T>(v).serialize(archive);
} else if constexpr (has_adl_load<T, Archive>::value) {
adl_serializer<T>::load(archive, std::forward<T>(v));
} else if constexpr (has_adl_serialize<T, Archive>::value) {
adl_serializer<T>::serialize(archive, std::forward<T>(v));
} else {
archive.native(std::forward<T>(v));
}
}
};
} // namespace mmdeploy
#endif // MMDEPLOY_SRC_CORE_ARCHIVE_H_
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <vector>
#include "mmdeploy/core/macro.h"
#include "mmdeploy/core/mpl/type_traits.h"
#include "mmdeploy/core/status_code.h"
#include "mmdeploy/core/utils/formatter.h"
namespace mmdeploy {
namespace framework {
class Platform;
class Device;
class Stream;
class Event;
class Allocator;
class Buffer;
class Kernel;
class PlatformImpl;
class StreamImpl;
class EventImpl;
class AllocatorImpl;
class BufferImpl;
class KernelImpl;
template <typename T>
using optional = std::optional<T>;
class DeviceId {
public:
using ValueType = int32_t;
constexpr explicit DeviceId(ValueType value) : value_(value) {}
constexpr operator ValueType() const { return value_; } // NOLINT
constexpr ValueType get() const { return value_; }
private:
ValueType value_;
};
class PlatformId {
public:
using ValueType = int32_t;
constexpr explicit PlatformId(ValueType value) : value_(value) {}
constexpr operator ValueType() const { return value_; } // NOLINT
constexpr ValueType get() const { return value_; }
private:
ValueType value_;
};
class Device {
public:
constexpr Device() : platform_id_(-1), device_id_(-1) {}
constexpr explicit Device(DeviceId device_id, PlatformId platform_id = PlatformId(-1))
: Device(platform_id.get(), device_id.get()) {}
constexpr explicit Device(PlatformId platform_id, DeviceId device_id = DeviceId(-1))
: Device(platform_id.get(), device_id.get()) {}
constexpr explicit Device(int platform_id, int device_id = 0)
: platform_id_(platform_id), device_id_(device_id) {}
MMDEPLOY_API explicit Device(const char* platform_name, int device_id = 0);
constexpr int device_id() const noexcept { return device_id_; }
constexpr int platform_id() const noexcept { return platform_id_; }
constexpr bool is_host() const noexcept { return platform_id() == 0; }
constexpr bool is_device() const noexcept { return platform_id() > 0; }
constexpr bool operator==(const Device& other) const noexcept {
return platform_id_ == other.platform_id_ && device_id_ == other.device_id_;
}
constexpr bool operator!=(const Device& other) const noexcept { return !(*this == other); }
constexpr explicit operator bool() const noexcept { return platform_id_ >= 0 && device_id_ >= 0; }
constexpr operator DeviceId() const noexcept { // NOLINT
return DeviceId(device_id_);
}
constexpr operator PlatformId() const noexcept { // NOLINT
return PlatformId(platform_id_);
}
friend std::ostream& operator<<(std::ostream& os, const Device& device) {
os << "(" << device.platform_id_ << ", " << device.device_id_ << ")";
return os;
}
private:
int platform_id_{0};
int device_id_{0};
};
enum class MemcpyKind : int { HtoD, DtoH, DtoD };
class MMDEPLOY_API Platform {
public:
// throws if not found
explicit Platform(const char* platform_name);
// throws if not found
explicit Platform(int platform_id);
// bind device with the current thread
Result<void> Bind(Device device, Device* prev);
// -1 if invalid
int GetPlatformId() const;
// "" if invalid
const char* GetPlatformName() const;
bool operator==(const Platform& other) { return impl_ == other.impl_; }
bool operator!=(const Platform& other) { return !(*this == other); }
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
private:
explicit Platform(std::shared_ptr<PlatformImpl> impl) : impl_(std::move(impl)) {}
private:
friend class PlatformRegistry;
friend class Access;
std::shared_ptr<PlatformImpl> impl_;
};
MMDEPLOY_API const char* GetPlatformName(PlatformId id);
class DeviceGuard {
public:
explicit DeviceGuard(Device device) : platform_(device.platform_id()) {
auto r = platform_.Bind(device, &prev_);
if (!r) {
MMDEPLOY_ERROR("failed to bind device {}: {}", device, r.error().message().c_str());
}
}
~DeviceGuard() {
auto r = platform_.Bind(prev_, nullptr);
if (!r) {
MMDEPLOY_ERROR("failed to unbind device {}: {}", prev_, r.error().message().c_str());
}
}
private:
Platform platform_;
Device prev_;
};
class MMDEPLOY_API Stream {
public:
Stream() = default;
explicit Stream(Device device, uint64_t flags = 0);
explicit Stream(Device device, void* native, uint64_t flags = 0);
explicit Stream(Device device, std::shared_ptr<void> native, uint64_t flags = 0);
Device GetDevice() const;
Result<void> Query();
Result<void> Wait();
Result<void> DependsOn(Event& event);
Result<void> Submit(Kernel& kernel);
void* GetNative(ErrorCode* ec = nullptr);
Result<void> Copy(const Buffer& src, Buffer& dst, size_t size = -1, size_t src_offset = 0,
size_t dst_offset = 0);
Result<void> Copy(const void* host_ptr, Buffer& dst, size_t size = -1, size_t dst_offset = 0);
Result<void> Copy(const Buffer& src, void* host_ptr, size_t size = -1, size_t src_offset = 0);
Result<void> Fill(const Buffer& dst, void* pattern, size_t pattern_size, size_t size = -1,
size_t offset = 0);
bool operator==(const Stream& other) const { return impl_ == other.impl_; }
bool operator!=(const Stream& other) const { return !(*this == other); }
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
static Stream GetDefault(Device device);
private:
explicit Stream(std::shared_ptr<StreamImpl> impl) : impl_(std::move(impl)) {}
private:
friend class Access;
std::shared_ptr<StreamImpl> impl_;
};
template <typename T>
T GetNative(Stream& stream, ErrorCode* ec = nullptr) {
return reinterpret_cast<T>(stream.GetNative(ec));
}
class MMDEPLOY_API Event {
public:
Event() = default;
explicit Event(Device device, uint64_t flags = 0);
explicit Event(Device device, void* native, uint64_t flags = 0);
explicit Event(Device device, std::shared_ptr<void> native, uint64_t flags = 0);
Device GetDevice();
Result<void> Query();
Result<void> Wait();
Result<void> Record(Stream& stream);
void* GetNative(ErrorCode* ec = nullptr);
bool operator==(const Event& other) const { return impl_ == other.impl_; }
bool operator!=(const Event& other) const { return !(*this == other); }
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
private:
explicit Event(std::shared_ptr<EventImpl> impl) : impl_(std::move(impl)) {}
private:
friend class Access;
std::shared_ptr<EventImpl> impl_;
};
template <typename T>
T GetNative(Event& event, ErrorCode* ec = nullptr) {
return reinterpret_cast<T>(event.GetNative(ec));
}
class MMDEPLOY_API Kernel {
public:
Kernel() = default;
explicit Kernel(std::shared_ptr<KernelImpl> impl) : impl_(std::move(impl)) {}
Device GetDevice() const;
void* GetNative(ErrorCode* ec = nullptr);
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
private:
std::shared_ptr<KernelImpl> impl_;
};
template <typename T>
T GetNative(Kernel& kernel, ErrorCode* ec = nullptr) {
return reinterpret_cast<T>(kernel.GetNative(ec));
}
class MMDEPLOY_API Allocator {
friend class Access;
public:
Allocator() = default;
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
private:
explicit Allocator(std::shared_ptr<AllocatorImpl> impl) : impl_(std::move(impl)) {}
std::shared_ptr<AllocatorImpl> impl_;
};
class MMDEPLOY_API Buffer {
public:
Buffer() = default;
Buffer(Device device, size_t size, size_t alignment = 1, uint64_t flags = 0)
: Buffer(device, size, Allocator{}, alignment, flags) {}
Buffer(Device device, size_t size, Allocator allocator, size_t alignment = 1, uint64_t flags = 0);
Buffer(Device device, size_t size, void* native, uint64_t flags = 0);
Buffer(Device device, size_t size, std::shared_ptr<void> native, uint64_t flags = 0);
// create sub-buffer
Buffer(Buffer& buffer, size_t offset, size_t size, uint64_t flags = 0);
size_t GetSize(ErrorCode* ec = nullptr) const;
// bool IsSubBuffer(ErrorCode* ec = nullptr);
void* GetNative(ErrorCode* ec = nullptr) const;
Device GetDevice() const;
Allocator GetAllocator() const;
bool operator==(const Buffer& other) const { return impl_ == other.impl_; }
bool operator!=(const Buffer& other) const { return !(*this == other); }
explicit operator bool() const noexcept { return static_cast<bool>(impl_); }
private:
explicit Buffer(std::shared_ptr<BufferImpl> impl) : impl_(std::move(impl)) {}
private:
friend class Access;
std::shared_ptr<BufferImpl> impl_;
};
template <typename T>
T GetNative(Buffer& buffer, ErrorCode* ec = nullptr) {
return reinterpret_cast<T>(buffer.GetNative(ec));
}
template <typename T>
T GetNative(const Buffer& buffer, ErrorCode* ec = nullptr) {
return reinterpret_cast<T>(buffer.GetNative(ec));
}
class MMDEPLOY_API PlatformRegistry {
public:
using Creator = std::function<std::shared_ptr<PlatformImpl>()>;
int Register(Creator creator);
int AddAlias(const char* name, const char* target);
int GetPlatform(const char* name, Platform* platform);
int GetPlatform(int id, Platform* platform);
int GetPlatformId(const char* name);
PlatformImpl* GetPlatformImpl(PlatformId id);
private:
int GetNextId();
bool IsAvailable(int id);
private:
struct Entry {
std::string name;
int id;
Platform platform;
};
std::vector<Entry> entries_;
std::vector<std::pair<std::string, std::string>> aliases_;
};
MMDEPLOY_API PlatformRegistry& gPlatformRegistry();
} // namespace framework
MMDEPLOY_REGISTER_TYPE_ID(framework::Device, 1);
MMDEPLOY_REGISTER_TYPE_ID(framework::Buffer, 2);
MMDEPLOY_REGISTER_TYPE_ID(framework::Stream, 3);
MMDEPLOY_REGISTER_TYPE_ID(framework::Event, 4);
} // namespace mmdeploy
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/core/device_impl.h"
#include <cassert>
#include "mmdeploy/core/device.h"
#include "mmdeploy/core/logger.h"
namespace mmdeploy::framework {
template <typename T>
T SetError(ErrorCode* ec, ErrorCode code, T ret) {
if (ec) {
*ec = code;
}
return ret;
}
////////////////////////////////////////////////////////////////////////////////
/// Device
Device::Device(const char* platform_name, int device_id) {
platform_id_ = gPlatformRegistry().GetPlatformId(platform_name);
device_id_ = device_id;
}
//////////////////////////////////////////////////
/// Platform
int Platform::GetPlatformId() const {
if (impl_) {
return impl_->GetPlatformId();
}
return -1;
}
const char* Platform::GetPlatformName() const {
if (impl_) {
return impl_->GetPlatformName();
}
return "";
}
Platform::Platform(const char* platform_name) {
if (-1 == gPlatformRegistry().GetPlatform(platform_name, this)) {
throw_exception(eInvalidArgument);
}
}
Platform::Platform(int platform_id) {
if (-1 == gPlatformRegistry().GetPlatform(platform_id, this)) {
throw_exception(eInvalidArgument);
}
}
Result<void> Platform::Bind(Device device, Device* prev) { return impl_->BindDevice(device, prev); }
const char* GetPlatformName(PlatformId id) {
if (auto impl = gPlatformRegistry().GetPlatformImpl(id); impl) {
return impl->GetPlatformName();
}
return nullptr;
}
////////////////////////////////////////////////////////////////////////////////
/// Buffer
Buffer::Buffer(Device device, size_t size, Allocator allocator, size_t alignment, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
impl_ = p->CreateBuffer(device);
if (auto r = impl_->Init(size, std::move(allocator), alignment, flags); r.has_error()) {
impl_.reset();
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Buffer::Buffer(Device device, size_t size, void* native, uint64_t flags)
: Buffer(device, size, std::shared_ptr<void>(native, [](void*) {}), flags) {}
Buffer::Buffer(Device device, size_t size, std::shared_ptr<void> native, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
impl_ = p->CreateBuffer(device);
if (auto r = impl_->Init(size, std::move(native), flags); r.has_error()) {
impl_.reset();
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Device Buffer::GetDevice() const { return impl_ ? impl_->GetDevice() : Device{}; }
Allocator Buffer::GetAllocator() const { return impl_ ? impl_->GetAllocator() : Allocator{}; }
void* Buffer::GetNative(ErrorCode* ec) const {
return impl_ ? impl_->GetNative(ec) : SetError(ec, eInvalidArgument, nullptr);
}
size_t Buffer::GetSize(ErrorCode* ec) const {
return impl_ ? impl_->GetSize(ec) : SetError(ec, eInvalidArgument, 0);
}
Buffer::Buffer(Buffer& buffer, size_t offset, size_t size, uint64_t flags) {
auto impl = buffer.impl_->SubBuffer(offset, size, flags);
if (!impl) {
impl.error().throw_exception();
}
impl_ = std::move(impl).value();
}
#if 0
int Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
Stream stream;
GetDefaultStream(dst.GetDevice(), &stream);
if (!stream) {
return Status(eFail);
}
return stream.Copy(host_ptr, dst, size, dst_offset);
}
int Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
Stream stream;
GetDefaultStream(src.GetDevice(), &stream);
if (!stream) {
return Status(eFail);
}
return stream.Copy(src, host_ptr, size, src_offset);
}
int Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
size_t dst_offset) {
Stream stream;
GetDefaultStream(src.GetDevice(), &stream);
if (!stream) {
return Status(eFail);
}
return stream.Copy(src, dst, size, src_offset, dst_offset);
}
#endif
//////////////////////////////////////////////////
/// Stream
Stream::Stream(Device device, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
auto impl = p->CreateStream(device);
if (auto r = impl->Init(flags)) {
impl_ = std::move(impl);
} else {
r.error().throw_exception();
}
} else {
MMDEPLOY_ERROR("{}, {}", device.device_id(), device.platform_id());
throw_exception(eInvalidArgument);
}
}
Stream::Stream(Device device, void* native, uint64_t flags)
: Stream(device, std::shared_ptr<void>(native, [](void*) {}), flags) {}
Stream::Stream(Device device, std::shared_ptr<void> native, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
auto impl = p->CreateStream(device);
if (auto r = impl->Init(std::move(native), flags)) {
impl_ = std::move(impl);
} else {
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Result<void> Stream::Query() {
if (impl_) {
return impl_->Query();
}
return Status(eInvalidArgument);
}
Result<void> Stream::Wait() {
if (impl_) {
return impl_->Wait();
}
return Status(eInvalidArgument);
}
Result<void> Stream::DependsOn(Event& event) {
return impl_ ? impl_->DependsOn(event) : Status(eInvalidArgument);
}
void* Stream::GetNative(ErrorCode* ec) {
return impl_ ? impl_->GetNative(ec) : SetError(ec, eInvalidArgument, nullptr);
}
Result<void> Stream::Submit(Kernel& kernel) {
return impl_ ? impl_->Submit(kernel) : Status(eInvalidArgument);
}
Result<void> Stream::Copy(const Buffer& src, Buffer& dst, size_t size, size_t src_offset,
size_t dst_offset) {
if (!impl_) {
return Status(eInvalidArgument);
}
if (size == static_cast<size_t>(-1)) {
size = src.GetSize();
}
if (auto p = GetPlatformImpl(GetDevice())) {
return p->Copy(src, dst, size, src_offset, dst_offset, *this);
}
return Status(eInvalidArgument);
}
Result<void> Stream::Copy(const void* host_ptr, Buffer& dst, size_t size, size_t dst_offset) {
if (!impl_) {
return Status(eInvalidArgument);
}
if (size == static_cast<size_t>(-1)) {
size = dst.GetSize();
}
auto device = GetDevice();
if (auto p = GetPlatformImpl(device)) {
return p->Copy(host_ptr, dst, size, dst_offset, *this);
}
return Status(eInvalidArgument);
}
Result<void> Stream::Copy(const Buffer& src, void* host_ptr, size_t size, size_t src_offset) {
if (!impl_) {
return Status(eInvalidArgument);
}
if (size == static_cast<size_t>(-1)) {
size = src.GetSize();
}
if (auto p = GetPlatformImpl(GetDevice())) {
return p->Copy(src, host_ptr, size, src_offset, *this);
}
return Status(eInvalidArgument);
}
Result<void> Stream::Fill(const Buffer& dst, void* pattern, size_t pattern_size, size_t size,
size_t offset) {
if (!impl_) {
return Status(eInvalidArgument);
}
return Status(eNotSupported);
}
Device Stream::GetDevice() const { return impl_ ? impl_->GetDevice() : Device{}; }
Stream Stream::GetDefault(Device device) {
Platform platform(device.platform_id());
assert(platform);
Stream stream = Access::get<PlatformImpl>(platform).GetDefaultStream(device.device_id()).value();
return stream;
}
/////////////////////////////////////////////////
/// Event
Event::Event(Device device, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
auto impl = p->CreateEvent(device);
if (auto r = impl->Init(flags)) {
impl_ = std::move(impl);
} else {
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Event::Event(Device device, void* native, uint64_t flags)
: Event(device, std::shared_ptr<void>(native, [](void*) {}), flags) {}
Event::Event(Device device, std::shared_ptr<void> native, uint64_t flags) {
if (auto p = GetPlatformImpl(device)) {
auto impl = p->CreateEvent(device);
if (auto r = impl->Init(std::move(native), flags)) {
impl_ = std::move(impl);
} else {
r.error().throw_exception();
}
} else {
throw_exception(eInvalidArgument);
}
}
Result<void> Event::Query() { return impl_ ? impl_->Query() : Status(eInvalidArgument); }
Result<void> Event::Wait() { return impl_ ? impl_->Wait() : Status(eInvalidArgument); }
Result<void> Event::Record(Stream& stream) {
return impl_ ? impl_->Record(stream) : Status(eInvalidArgument);
}
void* Event::GetNative(ErrorCode* ec) {
return impl_ ? impl_->GetNative(ec) : SetError(ec, eInvalidArgument, nullptr);
}
Device Event::GetDevice() { return impl_ ? impl_->GetDevice() : Device{}; }
/////////////////////////////////////////////////
/// Kernel
Device Kernel::GetDevice() const { return impl_ ? impl_->GetDevice() : Device{}; }
void* Kernel::GetNative(ErrorCode* ec) {
return impl_ ? impl_->GetNative(ec) : SetError(ec, eInvalidArgument, nullptr);
}
/////////////////////////////////////////////////
/// PlatformRegistry
int PlatformRegistry::Register(Creator creator) {
Platform platform(creator());
auto proposed_id = platform.GetPlatformId();
std::string name = platform.GetPlatformName();
if (proposed_id == -1) {
proposed_id = GetNextId();
platform.impl_->SetPlatformId(proposed_id);
} else if (!IsAvailable(proposed_id)) {
return -1;
}
entries_.push_back({name, proposed_id, platform});
return 0;
}
int PlatformRegistry::AddAlias(const char* name, const char* target) {
aliases_.emplace_back(name, target);
return 0;
}
int PlatformRegistry::GetNextId() {
for (int i = 1;; ++i) {
if (IsAvailable(i)) {
return i;
}
}
}
bool PlatformRegistry::IsAvailable(int id) {
for (const auto& entry : entries_) {
if (entry.id == id) {
return false;
}
}
return true;
}
int PlatformRegistry::GetPlatform(const char* name, Platform* platform) {
for (const auto& alias : aliases_) {
if (name == alias.first) {
name = alias.second.c_str();
break;
}
}
for (const auto& entry : entries_) {
if (entry.name == name) {
*platform = entry.platform;
return 0;
}
}
return -1;
}
int PlatformRegistry::GetPlatform(int id, Platform* platform) {
for (const auto& entry : entries_) {
if (entry.id == id) {
*platform = entry.platform;
return 0;
}
}
return -1;
}
int PlatformRegistry::GetPlatformId(const char* name) {
for (const auto& alias : aliases_) {
if (name == alias.first) {
name = alias.second.c_str();
break;
}
}
for (const auto& entry : entries_) {
if (entry.name == name) {
return entry.id;
}
}
return -1;
}
PlatformImpl* PlatformRegistry::GetPlatformImpl(PlatformId id) {
for (const auto& entry : entries_) {
if (entry.id == id) {
return entry.platform.impl_.get();
}
}
return nullptr;
}
PlatformRegistry& gPlatformRegistry() {
static PlatformRegistry instance;
return instance;
}
} // namespace mmdeploy::framework
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SRC_CORE_DEVICE_IMPL_H_
#define MMDEPLOY_SRC_CORE_DEVICE_IMPL_H_
#include "mmdeploy/core/device.h"
namespace mmdeploy::framework {
using std::shared_ptr;
using PlatformImplPtr = shared_ptr<PlatformImpl>;
using AllocatorImplPtr = shared_ptr<AllocatorImpl>;
using BufferImplPtr = shared_ptr<BufferImpl>;
using StreamImplPtr = shared_ptr<StreamImpl>;
using EventImplPtr = shared_ptr<EventImpl>;
class PlatformImpl {
public:
PlatformImpl() : platform_id_(-1) {}
virtual ~PlatformImpl() = default;
virtual const char* GetPlatformName() const noexcept = 0;
virtual int GetPlatformId() const noexcept { return platform_id_; }
virtual void SetPlatformId(int id) { platform_id_ = id; }
virtual Result<void> BindDevice(Device device, Device* prev) = 0;
virtual shared_ptr<BufferImpl> CreateBuffer(Device device) = 0;
virtual shared_ptr<StreamImpl> CreateStream(Device device) = 0;
virtual shared_ptr<EventImpl> CreateEvent(Device device) = 0;
virtual Result<void> Copy(const void* host_ptr, Buffer dst, size_t size, size_t dst_offset,
Stream stream) = 0;
virtual Result<void> Copy(Buffer src, void* host_ptr, size_t size, size_t src_offset,
Stream stream) = 0;
virtual Result<void> Copy(Buffer src, Buffer dst, size_t size, size_t src_offset,
size_t dst_offset, Stream stream) = 0;
virtual Result<Stream> GetDefaultStream(int32_t device_id) = 0;
protected:
int platform_id_;
};
class AllocatorImpl {
public:
struct Block {
explicit Block(void* _handle = nullptr, size_t _size = 0) : handle(_handle), size(_size) {}
void* handle;
size_t size;
};
virtual ~AllocatorImpl() = default;
virtual Block Allocate(size_t size) noexcept = 0;
virtual void Deallocate(Block& block) noexcept = 0;
virtual bool Owns(const Block& block) const noexcept = 0;
virtual const char* Name() const noexcept { return ""; }
// virtual Device device() const noexcept = 0;
};
// create, destroy, sub, MakeAvailableOnDevice, FromHost, fill, copy, map, unmap
class BufferImpl {
public:
explicit BufferImpl(Device device) : device_(device) {}
virtual ~BufferImpl() = default;
virtual Result<void> Init(size_t size, Allocator allocator, size_t alignment, uint64_t flags) = 0;
virtual Result<void> Init(size_t size, std::shared_ptr<void> native, uint64_t flags) = 0;
virtual Result<shared_ptr<BufferImpl>> SubBuffer(size_t offset, size_t size, uint64_t flags) = 0;
virtual size_t GetSize(ErrorCode* ec) = 0;
virtual Allocator GetAllocator() const = 0;
virtual void* GetNative(ErrorCode* ec) = 0;
Device GetDevice() const noexcept { return device_; }
protected:
Device device_;
};
class StreamImpl {
public:
explicit StreamImpl(Device device) : device_(device) {}
virtual ~StreamImpl() = default;
virtual Result<void> Init(uint64_t flags) = 0;
virtual Result<void> Init(std::shared_ptr<void> native, uint64_t flags) = 0;
virtual Result<void> Query() = 0;
virtual Result<void> Wait() = 0;
virtual Result<void> Submit(Kernel& kernel) = 0;
virtual Result<void> DependsOn(Event& event) = 0;
virtual void* GetNative(ErrorCode* ec) = 0;
Device GetDevice() const noexcept { return device_; }
protected:
Device device_;
};
class EventImpl {
public:
explicit EventImpl(Device device) : device_(device) {}
virtual ~EventImpl() = default;
virtual Result<void> Init(uint64_t flags) = 0;
virtual Result<void> Init(std::shared_ptr<void> native, uint64_t flags) = 0;
virtual Result<void> Query() = 0;
virtual Result<void> Record(Stream& st) = 0;
virtual Result<void> Wait() = 0;
virtual void* GetNative(ErrorCode* ec) = 0;
Device GetDevice() const noexcept { return device_; }
protected:
Device device_;
};
class KernelWrapper {
public:
virtual ~KernelWrapper() = default;
virtual int Invoke(const std::vector<void*>& args) = 0;
};
class KernelImpl {
public:
explicit KernelImpl(Device device) : device_(device) {}
virtual ~KernelImpl() = default;
Device GetDevice() const noexcept { return device_; }
virtual void* GetNative(ErrorCode* ec) = 0;
protected:
Device device_;
};
struct Access {
template <typename T, typename Obj>
static T& get(const Obj& obj) {
return static_cast<T&>(*obj.impl_);
}
template <typename Obj>
static auto& get_impl(const Obj& obj) {
return obj.impl_;
}
template <typename T, typename... Args>
static T create(Args&&... args) {
return T(std::forward<Args>(args)...);
}
};
inline PlatformImpl* GetPlatformImpl(const Device& device) {
return gPlatformRegistry().GetPlatformImpl(device);
}
} // namespace mmdeploy::framework
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment