test_resource.h 4.63 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_TEST_RESOURCE_H
#define MMDEPLOY_TEST_RESOURCE_H
#include <algorithm>
#include <cctype>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <system_error>
#include <vector>

#include "mmdeploy/core/utils/filesystem.h"
#include "test_define.h"

using namespace std;

class MMDeployTestResources {
 public:
  static MMDeployTestResources &Get() {
    static MMDeployTestResources resource;
    return resource;
  }

  const std::vector<std::string> &device_names() const { return devices_; }
  const std::vector<std::string> &device_names(const std::string &backend) const {
    return backend_devices_.at(backend);
  }
  const std::vector<std::string> &backends() const { return backends_; }
  const std::vector<std::string> &codebases() const { return codebases_; }
  const fs::path &resource_root_path() const { return resource_root_path_; }

  bool HasDevice(const std::string &name) const {
    return std::any_of(devices_.begin(), devices_.end(),
                       [&](const std::string &device_name) { return device_name == name; });
  }

  bool IsDir(const fs::path &dir_name) const {
    auto path = resource_root_path_ / dir_name;
    return fs::is_directory(path);
  }

  bool IsFile(const fs::path &file_name) const {
    auto path = resource_root_path_ / file_name;
    return fs::is_regular_file(path);
  }

 public:
  std::vector<std::string> LocateModelResources(const fs::path &sdk_model_zoo_dir) {
    std::vector<std::string> sdk_model_list;
    if (resource_root_path_.empty()) {
      return sdk_model_list;
    }

    auto path = resource_root_path_ / sdk_model_zoo_dir;
    if (!fs::is_directory(path)) {
      return sdk_model_list;
    }
    for (auto const &dir_entry : fs::directory_iterator{path}) {
      fs::directory_entry entry{dir_entry.path()};
      if (auto const &_path = dir_entry.path(); fs::is_directory(_path)) {
        sdk_model_list.push_back(dir_entry.path().string());
      }
    }
    return sdk_model_list;
  }

  std::vector<std::string> LocateImageResources(const fs::path &img_dir) {
    std::vector<std::string> img_list;

    if (resource_root_path_.empty()) {
      return img_list;
    }

    auto path = resource_root_path_ / img_dir;
    if (!fs::is_directory(path)) {
      return img_list;
    }

    set<string> extensions{".png", ".jpg", ".jpeg", ".bmp"};
    for (auto const &dir_entry : fs::directory_iterator{path}) {
      if (!fs::is_regular_file(dir_entry.path())) {
        std::cout << dir_entry.path().string() << std::endl;
        continue;
      }
      auto const &_path = dir_entry.path();
      auto ext = _path.extension().string();
      std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
      if (extensions.find(ext) != extensions.end()) {
        img_list.push_back(_path.string());
      }
    }
    return img_list;
  }

 private:
  MMDeployTestResources() {
    devices_ = Split(kDevices);
    backends_ = Split(kBackends);
    codebases_ = Split(kCodebases);
    backend_devices_["pplnn"] = {"cpu", "cuda"};
    backend_devices_["trt"] = {"cuda"};
    backend_devices_["ort"] = {"cpu"};
    backend_devices_["ncnn"] = {"cpu"};
    backend_devices_["openvino"] = {"cpu"};
    resource_root_path_ = LocateResourceRootPath(fs::current_path(), 8);
  }

  static std::vector<std::string> Split(const std::string &text, char delimiter = ';') {
    std::vector<std::string> result;
    std::istringstream ss(text);
    for (std::string word; std::getline(ss, word, delimiter);) {
      result.emplace_back(word);
    }
    return result;
  }

  fs::path LocateResourceRootPath(const fs::path &cur_path, int max_depth) {
    if (max_depth < 0) {
      return "";
    }
    for (auto const &dir_entry : fs::directory_iterator{cur_path}) {
      fs::directory_entry entry{dir_entry.path()};
      auto const &_path = dir_entry.path();
      // filename must be checked before fs::is_directory, the latter will throw
      // when _path points to a system file on Windows
      if (_path.filename() == "mmdeploy_test_resources" && fs::is_directory(_path)) {
        return _path;
      }
    }
    // Didn't find 'mmdeploy_test_resources' in current directory.
    // Move to its parent directory and keep looking for it
    if (cur_path.has_parent_path()) {
      return LocateResourceRootPath(cur_path.parent_path(), max_depth - 1);
    } else {
      return "";
    }
  }

 private:
  std::vector<std::string> devices_;
  std::vector<std::string> backends_;
  std::vector<std::string> codebases_;
  std::map<std::string, std::vector<std::string>> backend_devices_;
  fs::path resource_root_path_;
  //  std::string resource_root_path_;
};

#endif  // MMDEPLOY_TEST_RESOURCE_H