test_utils.cpp 1.64 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Copyright (c) OpenMMLab. All rights reserved.

#include "test_utils.h"
using namespace std;

namespace mmdeploy::test {
/// Builds a test-side Transform wrapper from a transform config.
///
/// Looks up the creator registered for cfg["type"] / cfg["version"], copies the
/// config and injects `device` and `stream` under "context", then creates the
/// transform while an operation::Context for (device, stream) is alive.
///
/// @param cfg    transform config; "type" (default "") and "version"
///               (default -1) select the registered creator.
/// @param device device stored as cfg["context"]["device"] in the copied config.
/// @param stream stream stored as cfg["context"]["stream"] in the copied config.
/// @return the wrapped transform, or nullptr when no creator is registered,
///         when the creator returns an empty transform, or when creation throws
///         (the exception message is printed to stdout in that case).
unique_ptr<Transform> CreateTransform(const Value& cfg, Device device, Stream stream) {
  auto op_type = cfg.value<string>("type", "");
  auto op_version = cfg.value<int>("version", -1);

  try {
    auto creator = gRegistry<transform::Transform>().Get(op_type, op_version);
    if (creator == nullptr) {
      return nullptr;
    }
    auto _cfg = cfg;
    _cfg["context"]["device"] = device;
    _cfg["context"]["stream"] = stream;

    // Keep this guard alive until the wrapper below is constructed: the
    // Transform constructor snapshots operation::gContext().device()/stream(),
    // presumably the values this guard installs.
    operation::Context context(device, stream);
    std::unique_ptr<transform::Transform> impl = creator->Create(_cfg);
    if (!impl) {
      // Wrapping an empty transform would make Transform::Process dereference
      // a null pointer; report failure to the caller instead.
      return nullptr;
    }
    return std::make_unique<Transform>(std::move(impl));
  } catch (const std::exception& e) {
    cout << "exception: " << e.what() << endl;
    return nullptr;
  } catch (...) {
    cout << "unexpected exception" << endl;
    return nullptr;
  }
}

/// Reads the list stored under `shape_key` in `value` as a shape vector.
///
/// Each element is read as int64_t so it matches the returned element type;
/// reading through int would narrow any dimension that does not fit in int.
///
/// @param value     object holding the shape list.
/// @param shape_key key of the list inside `value`.
/// @return the dimensions, in the order they appear in the list.
vector<int64_t> Shape(const Value& value, const string& shape_key) {
  vector<int64_t> shape;
  for (const auto& dim : value[shape_key]) {
    shape.push_back(dim.get<int64_t>());
  }
  return shape;
}

/// Returns the entries of value["img_norm_cfg"][key] converted to float, in
/// order (presumably the normalization mean/std lists — confirm with callers).
///
/// @param value object containing an "img_norm_cfg" entry.
/// @param key   name of the list inside "img_norm_cfg".
vector<float> ImageNormCfg(const Value& value, const std::string& key) {
  const auto& entries = value["img_norm_cfg"][key];
  vector<float> numbers;
  for (const auto& entry : entries) {
    numbers.emplace_back(entry.get<float>());
  }
  return numbers;
}

Transform::Transform(std::unique_ptr<transform::Transform> transform)
    : device_(operation::gContext().device()),
      stream_(operation::gContext().stream()),
      transform_(std::move(transform)) {}

/// Applies the wrapped transform to a copy of `input`.
///
/// The copy is handed to Apply() while an operation::Context built from the
/// recorded device/stream is alive; the guard is released before returning.
/// Any error from Apply() is propagated unchanged via OUTCOME_TRY.
///
/// @param input value to transform; it is not modified.
/// @return the transformed copy, or the error reported by Apply().
Result<Value> Transform::Process(const Value& input) {
  Value data = input;
  {
    // Scoped so the context guard ends before the result is returned.
    operation::Context scope_guard(device_, stream_);
    OUTCOME_TRY(transform_->Apply(data));
  }
  return data;
}

}  // namespace mmdeploy::test