// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//  * Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
//  * Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
//  * Neither the name of NVIDIA CORPORATION nor the names of its
//    contributors may be used to endorse or promote products derived
//    from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once

#include <functional>
#include <map>
#include <mutex>
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "repo_agent.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "triton/common/thread_pool.h"

namespace triton { namespace core {

struct ModelLifeCycleOptions {
  explicit ModelLifeCycleOptions(
      const double min_compute_capability,
      const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map,
      const triton::common::HostPolicyCmdlineConfigMap& host_policy_map,
      const unsigned int model_load_thread_count)
      : min_compute_capability_(min_compute_capability),
        backend_cmdline_config_map_(backend_cmdline_config_map),
        host_policy_map_(host_policy_map),
        model_load_thread_count_(model_load_thread_count)
  {
  }
  // The minimum supported CUDA compute capability.
  const double min_compute_capability_;
  // The backend configuration settings specified on the command line.
  const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map_;
  // The host policy settings used when loading models.
  const triton::common::HostPolicyCmdlineConfigMap& host_policy_map_;
  // Number of threads to use for loading models concurrently.
  const unsigned int model_load_thread_count_;
};
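
// A minimal construction sketch (illustrative only, not part of this header).
// Note that the options object holds references to the command-line maps, so
// those maps must outlive it. The literal values below are assumptions for
// the example, not server defaults.
//
//   triton::common::BackendCmdlineConfigMap backend_config_map;
//   triton::common::HostPolicyCmdlineConfigMap host_policy_map;
//   ModelLifeCycleOptions options(
//       6.0 /* min_compute_capability */, backend_config_map, host_policy_map,
//       4 /* model_load_thread_count */);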


/// Readiness status for models.
enum class ModelReadyState {
  // The model is in an unknown state. The model is not available for
  // inferencing.
  UNKNOWN,

  // The model is ready and available for inferencing.
  READY,

  // The model is unavailable, indicating that the model failed to
  // load or has been implicitly or explicitly unloaded. The model is
  // not available for inferencing.
  UNAVAILABLE,

  // The model is being loaded by the inference server. The model is
  // not available for inferencing.
  LOADING,

  // The model is being unloaded by the inference server. The model is
  // not available for inferencing.
  UNLOADING
};

/// Get the string representation for a ModelReadyState
const std::string& ModelReadyStateString(ModelReadyState state);

using VersionStateMap =
    std::map<int64_t, std::pair<ModelReadyState, std::string>>;
using ModelStateMap = std::map<std::string, VersionStateMap>;
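
// Illustrative sketch only: walking a ModelStateMap (for example, the result
// of ModelLifeCycle::ModelStates() declared below) and printing each
// version's ready state string together with its reason. Assumes the caller
// includes <iostream>.
//
//   for (const auto& model : states) {
//     for (const auto& version : model.second) {
//       std::cout << model.first << " v" << version.first << ": "
//                 << ModelReadyStateString(version.second.first) << " ("
//                 << version.second.second << ")" << std::endl;
//     }
//   }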

// Helper class to manage the lifecycle of a list of associated agent models
class TritonRepoAgentModelList {
 public:
  TritonRepoAgentModelList()
      : last_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE) {}
  ~TritonRepoAgentModelList()
  {
    // Use the destructor to finish the unload lifecycle so that the last
    // step does not have to be managed explicitly in ModelLifeCycle.
    if (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD) {
      InvokeAgentModels(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE);
    }
  }
  Status AddAgentModel(std::unique_ptr<TritonRepoAgentModel>&& agent_model)
  {
    agent_models_.emplace_back(std::move(agent_model));
    return Status::Success;
  }

  size_t Size() { return agent_models_.size(); }

  TritonRepoAgentModel* Back() { return agent_models_.back().get(); }

  Status InvokeAgentModels(const TRITONREPOAGENT_ActionType action_type)
  {
    // Special handling for the current model lifecycle implementation: the
    // repo agent may be asked to perform the UNLOAD action multiple times,
    // and any request after the first one should be ignored.
    const bool duplicate_unload =
        (action_type == TRITONREPOAGENT_ACTION_UNLOAD) &&
        (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD);
    if (duplicate_unload) {
      return Status::Success;
    }

    last_action_type_ = action_type;
    switch (action_type) {
      case TRITONREPOAGENT_ACTION_LOAD:
      case TRITONREPOAGENT_ACTION_UNLOAD: {
        for (size_t idx = 0; idx < agent_models_.size(); ++idx) {
          RETURN_IF_ERROR(agent_models_[idx]->InvokeAgent(action_type));
        }
        break;
      }
      case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
      case TRITONREPOAGENT_ACTION_LOAD_FAIL:
      case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: {
        // Deliver completion actions to the agents in reverse order.
        for (size_t one_pass_idx = agent_models_.size(); one_pass_idx > 0;
             --one_pass_idx) {
          RETURN_IF_ERROR(
              agent_models_[one_pass_idx - 1]->InvokeAgent(action_type));
        }
        break;
      }
    }
    return Status::Success;
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModelList);

  std::vector<std::unique_ptr<TritonRepoAgentModel>> agent_models_;
  TRITONREPOAGENT_ActionType last_action_type_;
};
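
// Usage sketch, assuming 'agent_model' was created elsewhere. Agents receive
// the LOAD / UNLOAD actions in the order they were added, while the
// LOAD_COMPLETE / LOAD_FAIL / UNLOAD_COMPLETE actions are delivered in
// reverse order, per InvokeAgentModels() above.
//
//   TritonRepoAgentModelList agent_list;
//   agent_list.AddAgentModel(std::move(agent_model));
//   agent_list.InvokeAgentModels(TRITONREPOAGENT_ACTION_LOAD);
//   // ... the model load succeeds ...
//   agent_list.InvokeAgentModels(TRITONREPOAGENT_ACTION_LOAD_COMPLETE);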

class InferenceServer;
class Model;

class ModelLifeCycle {
 public:
  static Status Create(
      InferenceServer* server, const ModelLifeCycleOptions& options,
      std::unique_ptr<ModelLifeCycle>* life_cycle);

  ~ModelLifeCycle()
  {
    // Destroy the thread pool first so that any pending callbacks that may
    // modify model lifecycle members have completed before the map is cleared.
    load_pool_.reset();
    map_.clear();
  }

  // Start loading the model with the specified versions asynchronously.
  // Versions that are currently being served will be unloaded only after
  // the new load finishes successfully.
  Status AsyncLoad(
      const std::string& model_name, const std::string& model_path,
      const inference::ModelConfig& model_config, const bool is_config_provided,
      const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
      std::function<void(Status)>&& OnComplete);
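
  // A hedged example of an AsyncLoad() call, for illustration only. The model
  // name and path are placeholders, 'config' and 'agent_list' are assumed to
  // exist in the caller, and the Status accessors used in the callback
  // (IsOk(), Message()) are assumptions about status.h.
  //
  //   life_cycle->AsyncLoad(
  //       "my_model", "/models/my_model", config,
  //       true /* is_config_provided */, agent_list,
  //       [](Status load_status) {
  //         if (!load_status.IsOk()) {
  //           std::cerr << "load failed: " << load_status.Message()
  //                     << std::endl;
  //         }
  //       });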

  // Unload model asynchronously.
  Status AsyncUnload(const std::string& model_name);

  // Get the specified version of the model. The latest ready version is
  // retrieved if 'version' is -1. Return an error if the specified version
  // is not found or is not ready.
  Status GetModel(
      const std::string& model_name, const int64_t version,
      std::shared_ptr<Model>* model);
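
  // Illustrative only: retrieving the latest ready version by passing -1 for
  // 'version', as documented above. 'life_cycle' is assumed to be a valid
  // ModelLifeCycle instance.
  //
  //   std::shared_ptr<Model> model;
  //   Status status = life_cycle->GetModel("my_model", -1 /* latest */, &model);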

  // Get the ModelStateMap representation of the live models. A model is
  // live if at least one of its versions is neither unknown nor unavailable.
  // If 'strict_readiness' is true, a model is live only if at least one of
  // its versions is ready.
  const ModelStateMap LiveModelStates(bool strict_readiness = false);

  // Get the ModelStateMap representation of the models.
  const ModelStateMap ModelStates();

  // Get the VersionStateMap representation of the specified model.
  const VersionStateMap VersionStates(const std::string& model_name);

  // Get the state of a specific model version.
  Status ModelState(
      const std::string& model_name, const int64_t model_version,
      ModelReadyState* state);

  // Instruct all models to stop accepting new inference requests.
  Status StopAllModels();

  // Return the number of in-flight inferences for each model version;
  // model versions without in-flight inferences are not included.
  const std::set<std::tuple<std::string, int64_t, size_t>> InflightStatus();
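
  // Consumption sketch: based on the comment above, each tuple is assumed to
  // hold (model name, version, in-flight inference count).
  //
  //   for (const auto& entry : life_cycle->InflightStatus()) {
  //     std::cout << std::get<0>(entry) << " v" << std::get<1>(entry) << ": "
  //               << std::get<2>(entry) << " in-flight" << std::endl;
  //   }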

 private:
  struct ModelInfo {
    ModelInfo(
        const std::string& model_path,
        const inference::ModelConfig& model_config,
        const uint64_t last_update_ns)
        : model_config_(model_config), model_path_(model_path),
#ifdef TRITON_ENABLE_ENSEMBLE
          is_ensemble_(model_config.platform() == kEnsemblePlatform),
#else
          is_ensemble_(false),
#endif  // TRITON_ENABLE_ENSEMBLE
          last_update_ns_(last_update_ns), state_(ModelReadyState::UNKNOWN)
    {
    }

    // Release the flyweight held by this ModelInfo object and reflect the
    // change as 'UNLOADING' in the model state. Note that 'mtx_' must be
    // acquired before invoking this function to prevent a possible data race.
    void Release()
    {
      state_ = ModelReadyState::UNLOADING;
      state_reason_.clear();
      agent_model_list_.reset();
      model_.reset();
    }

    const inference::ModelConfig model_config_;
    const std::string model_path_;
    const bool is_ensemble_;

    std::mutex mtx_;

    uint64_t last_update_ns_;

    ModelReadyState state_;
    std::string state_reason_;

    // flyweight
    std::shared_ptr<TritonRepoAgentModelList> agent_model_list_;
    std::shared_ptr<Model> model_;
  };

  struct LoadTracker {
    LoadTracker(
        const size_t affected_version_cnt, const uint64_t last_update_ns)
        : last_update_ns_(last_update_ns),
          affected_version_cnt_(affected_version_cnt), load_failed_(false),
          completed_version_cnt_(0)
    {
    }

    const uint64_t last_update_ns_;
    const size_t affected_version_cnt_;

    std::mutex mtx_;

    bool load_failed_;
    std::string reason_;
    size_t completed_version_cnt_;
    std::map<int64_t, ModelInfo*> load_set_;
  };

  ModelLifeCycle(InferenceServer* server, const ModelLifeCycleOptions& options)
      : server_(server),
        min_compute_capability_(options.min_compute_capability_),
        cmdline_config_map_(options.backend_cmdline_config_map_),
        host_policy_map_(options.host_policy_map_)
  {
    load_pool_.reset(new triton::common::ThreadPool(
        std::max(1u, options.model_load_thread_count_)));
  }

  void CreateModel(
      const std::string& model_name, const int64_t version,
      ModelInfo* model_info, const bool is_config_provided);
  // Callback function for model load. 'OnComplete' needs to be passed by
  // value for now because multiple versions may be loaded and each holds
  // its own copy of the 'OnComplete' callback.
  void OnLoadComplete(
      const std::string& model_name, const int64_t version,
      ModelInfo* model_info, std::function<void(Status)> OnComplete,
      std::shared_ptr<LoadTracker> load_tracker);


  // Mutex for 'map_' and 'background_models_'
  std::mutex map_mtx_;

  using VersionMap = std::map<int64_t, std::unique_ptr<ModelInfo>>;
  using ModelMap = std::map<std::string, VersionMap>;
  ModelMap map_;
  // Models that are being loaded / unloaded in the background.
  std::map<uintptr_t, std::unique_ptr<ModelInfo>> background_models_;

  InferenceServer* server_;
  const double min_compute_capability_;
  const triton::common::BackendCmdlineConfigMap cmdline_config_map_;
  const triton::common::HostPolicyCmdlineConfigMap host_policy_map_;

  // Fixed-size thread pool to load models at specified concurrency
  std::unique_ptr<triton::common::ThreadPool> load_pool_;
};

}}  // namespace triton::core