chenzhuo / lmdeployEx · Commits

Commit 374c78ca, authored May 21, 2024 by chenzhuo

    qwen-1.5

Pipeline #1012 canceled with stages · Changes: 156 · Pipelines: 1

Showing 16 changed files with 6475 additions and 0 deletions (+6475 -0)
3rdparty/core-r22.12/src/model_config_utils.h          +282   -0
3rdparty/core-r22.12/src/model_lifecycle.cc            +740   -0
3rdparty/core-r22.12/src/model_lifecycle.h             +324   -0
3rdparty/core-r22.12/src/model_repository_manager.cc   +1602  -0
3rdparty/core-r22.12/src/model_repository_manager.h    +345   -0
3rdparty/core-r22.12/src/numa_utils.cc                 +237   -0
3rdparty/core-r22.12/src/numa_utils.h                  +57    -0
3rdparty/core-r22.12/src/payload.cc                    +215   -0
3rdparty/core-r22.12/src/payload.h                     +102   -0
3rdparty/core-r22.12/src/pinned_memory_manager.cc      +378   -0
3rdparty/core-r22.12/src/pinned_memory_manager.h       +108   -0
3rdparty/core-r22.12/src/rate_limiter.cc               +943   -0
3rdparty/core-r22.12/src/rate_limiter.h                +310   -0
3rdparty/core-r22.12/src/repo_agent.cc                 +573   -0
3rdparty/core-r22.12/src/repo_agent.h                  +182   -0
3rdparty/core-r22.12/src/response_allocator.h          +77    -0
3rdparty/core-r22.12/src/model_config_utils.h  (new file, mode 100644)
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "model_config.pb.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
#include "filesystem.h"
namespace triton { namespace core {

/// Enumeration for the different backend types.
enum BackendType {
  BACKEND_TYPE_UNKNOWN = 0,
  BACKEND_TYPE_TENSORRT = 1,
  BACKEND_TYPE_TENSORFLOW = 2,
  BACKEND_TYPE_ONNXRUNTIME = 3,
  BACKEND_TYPE_PYTORCH = 4
};
/// Get the version of a model from the path containing the model
/// definition file.
/// \param path The path to the model definition file.
/// \param version Returns the version.
/// \return The error status.
Status GetModelVersionFromPath(const std::string& path, int64_t* version);
/// Get the tensor name, false value, and true value for a boolean
/// sequence batcher control kind. If 'required' is true then a tensor
/// must be found for the control. If 'required' is false,
/// 'tensor_name' is returned as the empty string if the control is
/// not mapped to any tensor.
Status GetBooleanSequenceControlProperties(
    const inference::ModelSequenceBatching& batcher,
    const std::string& model_name,
    const inference::ModelSequenceBatching::Control::Kind control_kind,
    const bool required, std::string* tensor_name,
    inference::DataType* tensor_datatype, float* fp32_false_value,
    float* fp32_true_value, int32_t* int32_false_value,
    int32_t* int32_true_value, bool* bool_false_value,
    bool* bool_true_value);
/// Get the tensor name and datatype for a non-boolean sequence
/// batcher control kind. If 'required' is true then a tensor must be
/// found for the control. If 'required' is false, 'tensor_name' is
/// returned as the empty string if the control is not mapped to any
/// tensor. 'tensor_datatype' returns the required datatype for the
/// control.
Status GetTypedSequenceControlProperties(
    const inference::ModelSequenceBatching& batcher,
    const std::string& model_name,
    const inference::ModelSequenceBatching::Control::Kind control_kind,
    const bool required, std::string* tensor_name,
    inference::DataType* tensor_datatype);
/// Read a ModelConfig and normalize it as expected by model backends.
/// \param model_name The name of the model.
/// \param path The full-path to the directory containing the
/// model configuration.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \param config Returns the normalized model configuration.
/// \return The error status.
Status GetNormalizedModelConfig(
    const std::string& model_name, const std::string& path,
    const double min_compute_capability, inference::ModelConfig* config);
/// Auto-complete backend related fields (platform, backend and default model
/// filename) if not set. Note that only Triton-recognized backends will be
/// checked.
/// \param model_name The name of the model.
/// \param model_path The full-path to the directory containing the
/// model configuration.
/// \param config Returns the auto-completed model configuration.
/// \return The error status.
Status AutoCompleteBackendFields(
    const std::string& model_name, const std::string& model_path,
    inference::ModelConfig* config);
/// Detects and adds missing fields in the model configuration.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \param config The model configuration.
/// \return The error status.
Status NormalizeModelConfig(
    const double min_compute_capability, inference::ModelConfig* config);
/// [FIXME] better formalize config normalization / validation
/// Detects and adds missing fields in the instance group settings.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \param preferred_groups The preferred instance group settings to apply
/// when auto-completing.
/// \param config The model configuration.
/// \return The error status.
Status NormalizeInstanceGroup(
    const double min_compute_capability,
    const std::vector<inference::ModelInstanceGroup>& preferred_groups,
    inference::ModelConfig* config);
/// [FIXME] Remove once a more permanent solution is implemented (DLIS-4211)
/// Localize EXECUTION_ENV_PATH in the python backend.
/// \param model_path The full-path to the directory containing the model
/// configuration, before localization.
/// \param config The model configuration.
/// \param localized_model_dir The localized model directory.
/// \return The error status.
Status LocalizePythonBackendExecutionEnvironmentPath(
    const std::string& model_path, inference::ModelConfig* config,
    std::shared_ptr<LocalizedPath>* localized_model_dir);
/// Auto-complete the instance count based on instance kind and backend name.
/// \param group The instance group to set the count for.
/// \param backend The backend name to check against.
/// \return The error status.
Status SetDefaultInstanceCount(
    inference::ModelInstanceGroup* group, const std::string& backend);
/// Validate that a model is specified correctly, except for model inputs
/// and outputs. ValidateModelIOConfig() should be called to
/// validate model inputs and outputs.
/// \param config The model configuration to validate.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateModelConfig(
    const inference::ModelConfig& config, const double min_compute_capability);
/// [FIXME] better formalize config normalization / validation
/// Validate the instance group settings.
/// \param config The model configuration to validate.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateInstanceGroup(
    const inference::ModelConfig& config, const double min_compute_capability);
/// Validate that a model's inputs and outputs are specified correctly.
/// \param config The model configuration to validate.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateModelIOConfig(const inference::ModelConfig& config);
/// Validate that an input is specified correctly in a model
/// configuration.
/// \param io The model input.
/// \param max_batch_size The max batch size specified in the model
/// configuration.
/// \param platform The platform name.
/// \return The error status. A non-OK status indicates the input
/// is not valid.
Status ValidateModelInput(
    const inference::ModelInput& io, int32_t max_batch_size,
    const std::string& platform);
/// Validate that an input matches one of the allowed input names.
/// \param io The model input.
/// \param allowed The set of allowed input names.
/// \return The error status. A non-OK status indicates the input
/// is not valid.
Status CheckAllowedModelInput(
    const inference::ModelInput& io, const std::set<std::string>& allowed);
/// Validate that an output is specified correctly in a model
/// configuration.
/// \param io The model output.
/// \param max_batch_size The max batch size specified in the model
/// configuration.
/// \param platform The platform name.
/// \return The error status. A non-OK status indicates the output
/// is not valid.
Status ValidateModelOutput(
    const inference::ModelOutput& io, int32_t max_batch_size,
    const std::string& platform);
/// Validate that an output matches one of the allowed output names.
/// \param io The model output.
/// \param allowed The set of allowed output names.
/// \return The error status. A non-OK status indicates the output
/// is not valid.
Status CheckAllowedModelOutput(
    const inference::ModelOutput& io, const std::set<std::string>& allowed);
/// Validate that a model's batch inputs and batch outputs are specified
/// correctly.
/// \param config The model configuration to validate.
/// \return The error status. A non-OK status indicates the batch inputs or
/// batch outputs are not valid.
Status ValidateBatchIO(const inference::ModelConfig& config);
/// Parse the 'value' of the parameter 'key' into a boolean value.
/// \param key The name of the parameter.
/// \param value The value of the parameter as a string.
/// \param parsed_value Returns the boolean value of the parameter.
/// \return The error status. A non-OK status indicates failure to parse
/// the value.
Status ParseBoolParameter(
    const std::string& key, std::string value, bool* parsed_value);
/// Parse the 'value' of the parameter 'key' into a long long integer value.
/// \param key The name of the parameter.
/// \param value The value of the parameter as a string.
/// \param parsed_value Returns the numerical value of the parameter.
/// \return The error status. A non-OK status indicates failure to parse
/// the value.
Status ParseLongLongParameter(
    const std::string& key, const std::string& value, int64_t* parsed_value);
/// Obtain the 'profile_index' of the 'profile_name'.
/// \param profile_name The name of the profile.
/// \param profile_index Returns the index of the profile.
/// \return The error status. A non-OK status indicates failure to get the
/// value.
Status GetProfileIndex(const std::string& profile_name, int* profile_index);
/// Convert a model configuration protobuf to the equivalent JSON.
/// \param config The protobuf model configuration.
/// \param config_version The model configuration will be returned in
/// a format matching this version. If the configuration cannot be
/// represented in the requested version's format then an error will
/// be returned.
/// \param json_str Returns the equivalent JSON.
/// \return The error status.
Status ModelConfigToJson(
    const inference::ModelConfig& config, const uint32_t config_version,
    std::string* json_str);
/// Convert a model configuration JSON to the equivalent protobuf.
/// \param json_config The JSON model configuration.
/// \param config_version The model configuration will be returned in
/// a format matching this version. If the configuration cannot be
/// represented in the requested version's format then an error will
/// be returned.
/// \param protobuf_config Returns the equivalent protobuf.
/// \return The error status.
Status JsonToModelConfig(
    const std::string& json_config, const uint32_t config_version,
    inference::ModelConfig* protobuf_config);
/// Get the BackendType value for a platform name.
/// \param platform_name The platform name.
/// \return The BackendType, or BACKEND_TYPE_UNKNOWN if the platform string
/// is not recognized.
BackendType GetBackendTypeFromPlatform(const std::string& platform_name);
/// Get the BackendType value for a backend name.
/// \param backend_name The backend name.
/// \return The BackendType, or BACKEND_TYPE_UNKNOWN if the backend string
/// is not recognized.
BackendType GetBackendType(const std::string& backend_name);
/// Get the Triton server data type corresponding to a data type.
/// \param dtype The data type.
/// \return The Triton server data type.
TRITONSERVER_DataType DataTypeToTriton(const inference::DataType dtype);
/// Get the data type corresponding to a Triton server data type.
/// \param dtype The Triton server data type.
/// \return The data type.
inference::DataType TritonToDataType(const TRITONSERVER_DataType dtype);

}}  // namespace triton::core
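The declarations above follow a single convention: results come back through out-parameters and every call returns a Status. A minimal usage sketch (illustrative only, not part of this commit; the path and parameter key are hypothetical, and RETURN_IF_ERROR comes from the status utilities this header already depends on):

#include "model_config_utils.h"

triton::core::Status
LoadVersionAndFlag()
{
  // Resolve the numeric version from a version directory path,
  // e.g. ".../mymodel/3" -> 3.
  int64_t version = 0;
  RETURN_IF_ERROR(triton::core::GetModelVersionFromPath(
      "/models/mymodel/3" /* hypothetical path */, &version));

  // Parse a string-valued parameter into a bool; a non-OK Status is
  // returned if the value cannot be parsed as a boolean.
  bool enabled = false;
  RETURN_IF_ERROR(triton::core::ParseBoolParameter(
      "enable_cache" /* hypothetical key */, "true", &enabled));

  return triton::core::Status::Success;
}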
3rdparty/core-r22.12/src/model_lifecycle.cc  (new file, mode 100644)
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_lifecycle.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "repo_agent.h"
#include "triton/common/logging.h"
#include "triton/common/thread_pool.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace triton { namespace core {
const std::string&
ModelReadyStateString(ModelReadyState state)
{
  switch (state) {
    case ModelReadyState::UNKNOWN: {
      static std::string m("UNKNOWN");
      return m;
    }
    case ModelReadyState::READY: {
      static std::string m("READY");
      return m;
    }
    case ModelReadyState::UNAVAILABLE: {
      static std::string m("UNAVAILABLE");
      return m;
    }
    case ModelReadyState::LOADING: {
      static std::string m("LOADING");
      return m;
    }
    case ModelReadyState::UNLOADING: {
      static std::string m("UNLOADING");
      return m;
    }
  }

  static std::string m("<unknown>");
  return m;
}
namespace {
Status
VersionsToLoad(
    const std::string model_path, const std::string& name,
    const inference::ModelConfig& model_config, std::set<int64_t>* versions)
{
  versions->clear();

  // Get integral number of the version directory
  std::set<std::string> subdirs;
  RETURN_IF_ERROR(GetDirectorySubdirs(model_path, &subdirs));
  std::set<int64_t, std::greater<int64_t>> existing_versions;
  for (const auto& subdir : subdirs) {
    if ((subdir == kWarmupDataFolder) || (subdir == kInitialStateFolder)) {
      continue;
    }
    if ((subdir.length() > 1) && (subdir.front() == '0')) {
      LOG_WARNING << "ignore version directory '" << subdir
                  << "' which contains leading zeros in its directory name";
      continue;
    }
    try {
      int64_t version = std::stoll(subdir);
      existing_versions.insert(version);
    }
    catch (const std::invalid_argument& ia) {
      LOG_WARNING << "ignore version directory '" << subdir
                  << "' which fails to convert to integral number";
    }
  }

  if (model_config.version_policy().has_specific()) {
    for (const auto& v : model_config.version_policy().specific().versions()) {
      // Only load the specific versions that are presented in model directory
      bool version_not_exist = existing_versions.insert(v).second;
      if (!version_not_exist) {
        versions->emplace(v);
      } else {
        LOG_ERROR << "version " << v << " is specified for model '" << name
                  << "', but the version directory is not present";
      }
    }
  } else {
    if (model_config.version_policy().has_latest()) {
      // std::set is sorted with std::greater
      for (const auto& v : existing_versions) {
        if (versions->size() >=
            model_config.version_policy().latest().num_versions()) {
          break;
        }
        versions->emplace(v);
      }
    } else {
      // all
      versions->insert(existing_versions.begin(), existing_versions.end());
    }
  }

  return Status::Success;
}
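// Added illustration (not part of the original source): given version
// directories /models/mymodel/1, /models/mymodel/2 and /models/mymodel/3,
// the config.pbtxt version_policy selects what VersionsToLoad() returns:
//   version_policy { specific { versions: [1, 3] } }  -> {1, 3}
//   version_policy { latest { num_versions: 2 } }     -> {3, 2}
//   neither "latest" nor "specific" (the "all" case)  -> {1, 2, 3}
// Directories with leading zeros ("003") or non-numeric names are skipped
// with a warning by the loop above.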
// Use smart pointer with custom deleter so that model state will be updated
// to UNAVAILABLE if all smart pointer copies are out of scope
struct ModelDeleter {
  ModelDeleter(std::function<void()> OnDestroyModel)
      : OnDestroyModel_(std::move(OnDestroyModel))
  {
  }

  void operator()(Model* model)
  {
    // The actual model object must be destroyed in a different
    // thread. This thread could have a callstack that includes the
    // model itself because this deleter could be triggered by
    // a request release or response send in the model. Following
    // delete will lead to the model destructor which may wait on this
    // same thread... so deadlock if we don't use a different thread
    // here.
    std::function<void()> destroy_fn = OnDestroyModel_;
    std::thread dthd([model, destroy_fn]() {
      delete model;
      destroy_fn();
    });

    dthd.detach();
  }

  // Use to inform the ModelLifeCycle that the model handle is destroyed
  std::function<void()> OnDestroyModel_;
};

}  // namespace
Status
ModelLifeCycle::Create(
    InferenceServer* server, const ModelLifeCycleOptions& options,
    std::unique_ptr<ModelLifeCycle>* life_cycle)
{
  std::unique_ptr<ModelLifeCycle> local_life_cycle(
      new ModelLifeCycle(server, options));

  *life_cycle = std::move(local_life_cycle);
  return Status::Success;
}
const ModelStateMap
ModelLifeCycle::LiveModelStates(bool strict_readiness)
{
  LOG_VERBOSE(2) << "LiveModelStates()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  ModelStateMap live_model_states;
  for (auto& model_version : map_) {
    bool live = false;
    VersionStateMap version_map;
    for (auto& version_model : model_version.second) {
      std::lock_guard<std::mutex> lock(version_model.second->mtx_);
      if (strict_readiness &&
          version_model.second->state_ != ModelReadyState::READY) {
        continue;
      }

      // At least one version is live (ready / loading / unloading)
      if ((version_model.second->state_ != ModelReadyState::UNKNOWN) &&
          (version_model.second->state_ != ModelReadyState::UNAVAILABLE)) {
        live = true;
        version_map[version_model.first] = std::make_pair(
            version_model.second->state_,
            version_model.second->state_reason_);
      }
    }

    if (live) {
      live_model_states[model_version.first] = std::move(version_map);
    }
  }
  return live_model_states;
}
Status
ModelLifeCycle::StopAllModels()
{
  LOG_VERBOSE(2) << "StopAllModels()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  for (auto& model_version : map_) {
    for (auto& version_model : model_version.second) {
      if (version_model.second != nullptr) {
        std::lock_guard<std::mutex> lock(version_model.second->mtx_);
        if (version_model.second->model_ != nullptr) {
          version_model.second->model_->Stop();
        }
      }
    }
  }
  return Status::Success;
}
const std::set<std::tuple<std::string, int64_t, size_t>>
ModelLifeCycle::InflightStatus()
{
  LOG_VERBOSE(2) << "InflightStatus()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  std::set<std::tuple<std::string, int64_t, size_t>> inflight_status;
  for (auto& model_version : map_) {
    for (auto& version_model : model_version.second) {
      if (version_model.second != nullptr) {
        std::lock_guard<std::mutex> lock(version_model.second->mtx_);
        if (version_model.second->model_ != nullptr) {
          const auto cnt =
              version_model.second->model_->InflightInferenceCount();
          if (cnt != 0) {
            inflight_status.emplace(
                model_version.first, version_model.first, cnt);
          }
        }
      }
    }
  }
  return inflight_status;
}
const ModelStateMap
ModelLifeCycle::ModelStates()
{
  LOG_VERBOSE(2) << "ModelStates()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  ModelStateMap model_states;
  for (auto& model_version : map_) {
    VersionStateMap version_map;
    for (auto& version_model : model_version.second) {
      std::lock_guard<std::mutex> lock(version_model.second->mtx_);
      version_map[version_model.first] = std::make_pair(
          version_model.second->state_, version_model.second->state_reason_);
    }

    model_states[model_version.first] = std::move(version_map);
  }

  return model_states;
}
const VersionStateMap
ModelLifeCycle::VersionStates(const std::string& model_name)
{
  LOG_VERBOSE(2) << "VersionStates() '" << model_name << "'";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  VersionStateMap version_map;
  auto mit = map_.find(model_name);
  if (mit != map_.end()) {
    for (auto& version_model : mit->second) {
      std::lock_guard<std::mutex> lock(version_model.second->mtx_);
      version_map[version_model.first] = std::make_pair(
          version_model.second->state_, version_model.second->state_reason_);
    }
  }

  return version_map;
}
Status
ModelLifeCycle::ModelState(
    const std::string& model_name, const int64_t model_version,
    ModelReadyState* state)
{
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto mit = map_.find(model_name);
  if (mit != map_.end()) {
    auto vit = mit->second.find(model_version);
    if (vit != mit->second.end()) {
      std::lock_guard<std::mutex> lock(vit->second->mtx_);
      *state = vit->second->state_;
      return Status::Success;
    }
  }

  return Status(
      Status::Code::NOT_FOUND, "model '" + model_name + "', version " +
                                   std::to_string(model_version) +
                                   " is not found");
}
Status
ModelLifeCycle::GetModel(
    const std::string& model_name, const int64_t version,
    std::shared_ptr<Model>* model)
{
  LOG_VERBOSE(2) << "GetModel() '" << model_name << "' version " << version;
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto mit = map_.find(model_name);
  if (mit == map_.end()) {
    return Status(
        Status::Code::NOT_FOUND, "'" + model_name + "' is not found");
  }

  auto vit = mit->second.find(version);
  if (vit == mit->second.end()) {
    if (version != -1) {
      return Status(
          Status::Code::NOT_FOUND, "'" + model_name + "' version " +
                                       std::to_string(version) +
                                       " is not found");
    }

    // The case where the request is asking for latest version
    int64_t latest = -1;
    for (auto& version_model : mit->second) {
      if (version_model.first > latest) {
        std::lock_guard<std::mutex> lock(version_model.second->mtx_);
        if (version_model.second->state_ == ModelReadyState::READY) {
          latest = version_model.first;
          // Tedious, but have to set handle for any "latest" version
          // at the moment to avoid edge case like the following:
          // "versions : 1 3 2", version 3 is latest but is requested
          // to be unloaded when the iterator is examining version 2,
          // then 'model' will ensure version 3 is still valid
          *model = version_model.second->model_;
        }
      }
    }

    if (latest == -1) {
      return Status(
          Status::Code::NOT_FOUND,
          "'" + model_name + "' has no available versions");
    }
  } else {
    std::lock_guard<std::mutex> lock(vit->second->mtx_);
    if (vit->second->state_ == ModelReadyState::READY) {
      *model = vit->second->model_;
    } else {
      return Status(
          Status::Code::UNAVAILABLE, "'" + model_name + "' version " +
                                         std::to_string(version) +
                                         " is not at ready state");
    }
  }

  return Status::Success;
}
Status
ModelLifeCycle::AsyncUnload(const std::string& model_name)
{
  LOG_VERBOSE(2) << "AsyncUnload() '" << model_name << "'";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto it = map_.find(model_name);
  if (it == map_.end()) {
    return Status(
        Status::Code::INVALID_ARG,
        "Model to be unloaded has not been served");
  }

  // Get the existing agent models and notify the unload action
  const uint64_t now_ns =
      std::chrono::duration_cast<std::chrono::nanoseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();
  for (auto& version : it->second) {
    auto& model_info = version.second;
    std::lock_guard<std::mutex> lock(model_info->mtx_);
    model_info->last_update_ns_ = now_ns;
    // Unload the serving model. For a model that is in LOADING state, the
    // updated timestamp will be recognized as a newer update on the model
    // info and the load will be aborted.
    if (model_info->state_ == ModelReadyState::READY) {
      if (model_info->agent_model_list_ != nullptr) {
        // Only log the error because the model should be unloaded regardless
        auto status = model_info->agent_model_list_->InvokeAgentModels(
            TRITONREPOAGENT_ACTION_UNLOAD);
        if (!status.IsOk()) {
          LOG_ERROR
              << "Agent model returns error on TRITONREPOAGENT_ACTION_UNLOAD: "
              << status.AsString();
        }
      }

      // unload
      model_info->Release();
    }
  }

  return Status::Success;
}
Status
ModelLifeCycle::AsyncLoad(
    const std::string& model_name, const std::string& model_path,
    const inference::ModelConfig& model_config, const bool is_config_provided,
    const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
    std::function<void(Status)>&& OnComplete)
{
  LOG_VERBOSE(2) << "AsyncLoad() '" << model_name << "'";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto it = map_.find(model_name);
  if (it == map_.end()) {
    it = map_.emplace(std::make_pair(model_name, VersionMap())).first;
  }

  std::set<int64_t> versions;
  RETURN_IF_ERROR(
      VersionsToLoad(model_path, model_name, model_config, &versions));
  if (versions.empty()) {
    return Status(
        Status::Code::INVALID_ARG,
        "at least one version must be available under the version policy of "
        "model '" + model_name + "'");
  }

  const uint64_t now_ns =
      std::chrono::duration_cast<std::chrono::nanoseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();
  std::shared_ptr<LoadTracker> load_tracker(
      new LoadTracker(versions.size(), now_ns));
  for (const auto& version : versions) {
    std::unique_ptr<ModelInfo> linfo(
        new ModelInfo(model_path, model_config, now_ns));
    ModelInfo* model_info = linfo.get();

    LOG_INFO << "loading: " << model_name << ":" << version;
    model_info->state_ = ModelReadyState::LOADING;
    model_info->state_reason_.clear();
    model_info->agent_model_list_ = agent_model_list;

    auto res = it->second.emplace(
        std::make_pair(version, std::unique_ptr<ModelInfo>()));
    if (res.second) {
      res.first->second = std::move(linfo);
    } else {
      // There is already a record of this model version. Check if the version
      // model is being served; if so, the re-load of the version
      // should be performed in background to avoid version downtime.
      // Otherwise, swap and monitor state for newly loading model.
      auto& serving_model = res.first->second;
      std::lock_guard<std::mutex> lock(serving_model->mtx_);
      if (serving_model->state_ == ModelReadyState::READY) {
        background_models_[(uintptr_t)model_info] = std::move(linfo);
      } else {
        // swap the monitoring model info
        serving_model.swap(linfo);
        // further check the state, put to 'background_models_' to keep
        // the object valid if the model is LOADING / UNLOADING, because
        // the model info will be accessed by a different thread once the
        // operation is completed
        if ((linfo->state_ == ModelReadyState::LOADING) ||
            (linfo->state_ == ModelReadyState::UNLOADING)) {
          ModelInfo* key = linfo.get();
          background_models_[(uintptr_t)key] = std::move(linfo);
        }
      }
    }

    // Load model asynchronously via thread pool
    load_pool_->Enqueue([this, model_name, version, model_info, OnComplete,
                         load_tracker, is_config_provided]() {
      CreateModel(model_name, version, model_info, is_config_provided);
      OnLoadComplete(model_name, version, model_info, OnComplete, load_tracker);
    });
  }

  return Status::Success;
}
void
ModelLifeCycle::CreateModel(
    const std::string& model_name, const int64_t version,
    ModelInfo* model_info, const bool is_config_provided)
{
  LOG_VERBOSE(2) << "CreateModel() '" << model_name << "' version " << version;

  const auto& model_config = model_info->model_config_;

  // Create model
  Status status;
  std::unique_ptr<Model> is;

  // If 'backend' is specified in the config then use the new triton
  // backend.
  if (!model_config.backend().empty()) {
    std::unique_ptr<TritonModel> model;
    status = TritonModel::Create(
        server_, model_info->model_path_, cmdline_config_map_,
        host_policy_map_, model_name, version, model_config,
        is_config_provided, &model);
    is.reset(model.release());
  } else {
#ifdef TRITON_ENABLE_ENSEMBLE
    if (model_info->is_ensemble_) {
      status = EnsembleModel::Create(
          server_, model_info->model_path_, version, model_config,
          is_config_provided, min_compute_capability_, &is);
      // Complete the label provider with label information from the involved
      // models. Must be done here because the involved models may not be
      // obtainable from the server, as this may happen during server
      // initialization.
      if (status.IsOk()) {
        std::set<std::string> no_label_outputs;
        const auto& label_provider = is->GetLabelProvider();
        for (const auto& output : model_config.output()) {
          if (label_provider->GetLabel(output.name(), 0).empty()) {
            no_label_outputs.emplace(output.name());
          }
        }

        for (const auto& element : model_config.ensemble_scheduling().step()) {
          for (const auto& pair : element.output_map()) {
            // Found a model that produces one of the missing outputs
            if (no_label_outputs.find(pair.second) != no_label_outputs.end()) {
              std::shared_ptr<Model> model;
              // Safe to obtain model because the ensemble can't be loaded
              // until the involved models are ready
              GetModel(element.model_name(), element.model_version(), &model);
              label_provider->AddLabels(
                  pair.second,
                  model->GetLabelProvider()->GetLabels(pair.first));
            }
          }
        }
      }
    } else
#endif  // TRITON_ENABLE_ENSEMBLE
    {
      status = Status(
          Status::Code::INVALID_ARG,
          "unknown platform '" + model_config.platform() + "'");
    }
  }

  std::lock_guard<std::mutex> lock(model_info->mtx_);
  if (status.IsOk()) {
    // [FIXME] better way to manage agent model lifecycle
    // Let the deleter also hold a shared pointer copy of the agent model
    // list, because the reference in ModelInfo can be cleared before the
    // Model object is destroyed, and we want the agent model to be valid for
    // receiving the UNLOAD_COMPLETE signal (see ~TritonRepoAgentModelList
    // for detail)
    auto agent_model_list = model_info->agent_model_list_;
    model_info->model_.reset(
        is.release(),
        ModelDeleter([this, model_name, version, model_info,
                      agent_model_list]() mutable {
          LOG_VERBOSE(2) << "OnDestroy callback() '" << model_name
                         << "' version " << version;
          LOG_INFO << "successfully unloaded '" << model_name << "' version "
                   << version;
          // Update model state as it is fully unloaded
          {
            std::lock_guard<std::mutex> lock(model_info->mtx_);
            model_info->state_ = ModelReadyState::UNAVAILABLE;
            model_info->state_reason_ = "unloaded";
          }
          // Check if the model info is in background; if so, remove it from
          // the map
          std::lock_guard<std::mutex> lk(this->map_mtx_);
          auto it = this->background_models_.find((uintptr_t)model_info);
          if (it != this->background_models_.end()) {
            this->background_models_.erase(it);
          }
        }));
  } else {
    LOG_ERROR << "failed to load '" << model_name << "' version " << version
              << ": " << status.AsString();
    model_info->state_ = ModelReadyState::UNAVAILABLE;
    model_info->state_reason_ = status.AsString();
  }
}
void
ModelLifeCycle::OnLoadComplete(
    const std::string& model_name, const int64_t version,
    ModelInfo* model_info, std::function<void(Status)> OnComplete,
    std::shared_ptr<LoadTracker> load_tracker)
{
  std::lock_guard<std::mutex> tracker_lock(load_tracker->mtx_);
  ++load_tracker->completed_version_cnt_;
  load_tracker->load_set_[version] = model_info;
  // A version will not be marked ready until all versions are ready; this
  // simplifies the unloading when one version fails to load, as all other
  // versions won't have inflight requests
  if (model_info->state_ != ModelReadyState::LOADING) {
    load_tracker->load_failed_ = true;
    load_tracker->reason_ +=
        ("version " + std::to_string(version) + " is at " +
         ModelReadyStateString(model_info->state_) +
         " state: " + model_info->state_reason_ + ";");
  }

  // Check if all versions are completed and finish the load
  if (load_tracker->completed_version_cnt_ ==
      load_tracker->affected_version_cnt_) {
    // hold 'map_mtx_' as there will be changes to the model info map
    std::lock_guard<std::mutex> map_lock(map_mtx_);
    auto it = map_.find(model_name);
    // Check if the load is the latest foreground action on the model
    for (const auto& version_info : it->second) {
      if (version_info.second->last_update_ns_ >
          load_tracker->last_update_ns_) {
        load_tracker->load_failed_ = true;
        load_tracker->reason_ =
            "Newer operation has been applied to the model lifecycle, current "
            "load operation is out-dated.";
        break;
      }
    }

    if (load_tracker->load_failed_) {
      // Move the agent list out of ModelInfo as it needs to be invoked
      // after all ModelInfos are reset
      std::shared_ptr<TritonRepoAgentModelList> lagent_list;
      if (model_info->agent_model_list_) {
        lagent_list = std::move(model_info->agent_model_list_);
      }
      // If any of the versions fails to load, abort the load and unload
      // all newly loaded versions
      for (auto& loaded : load_tracker->load_set_) {
        // Unload directly; the object is being managed either in foreground
        // or background
        std::lock_guard<std::mutex> lock(loaded.second->mtx_);
        if (loaded.second->model_ != nullptr) {
          loaded.second->Release();
        }
      }
      if (lagent_list) {
        auto status =
            lagent_list->InvokeAgentModels(TRITONREPOAGENT_ACTION_LOAD_FAIL);
        if (!status.IsOk()) {
          LOG_ERROR << "Agent model returns error on "
                       "TRITONREPOAGENT_ACTION_LOAD_FAIL: "
                    << status.AsString();
        }
      }
    } else {
      // Unload any previously loaded versions that are still available
      for (auto& version_info : it->second) {
        auto& mi = version_info.second;
        std::lock_guard<std::mutex> info_lk(mi->mtx_);
        if ((mi->state_ == ModelReadyState::READY) &&
            (mi->last_update_ns_ < load_tracker->last_update_ns_)) {
          if (mi->agent_model_list_ != nullptr) {
            auto status = mi->agent_model_list_->InvokeAgentModels(
                TRITONREPOAGENT_ACTION_UNLOAD);
            if (!status.IsOk()) {
              LOG_ERROR << "Agent model returns error on "
                           "TRITONREPOAGENT_ACTION_UNLOAD: "
                        << status.AsString();
            }
          }
          mi->Release();
        }
      }
      // Mark current versions ready and track info in foreground
      for (auto& loaded : load_tracker->load_set_) {
        std::lock_guard<std::mutex> curr_info_lk(loaded.second->mtx_);
        loaded.second->state_ = ModelReadyState::READY;
        model_info->state_reason_.clear();
        LOG_INFO << "successfully loaded '" << model_name << "' version "
                 << version;
        auto bit = background_models_.find((uintptr_t)loaded.second);
        // Check if the version model is loaded in background; if so,
        // replace and unload the currently serving version
        if (bit != background_models_.end()) {
          auto vit = it->second.find(loaded.first);
          // Need to lock the previous model info in case the model is
          // loading / unloading; this ensures the model state is consistent
          // even when the load / unload is completed.
          std::lock_guard<std::mutex> prev_info_lk(vit->second->mtx_);
          // swap previous info into local unique pointer
          auto linfo = std::move(bit->second);
          vit->second.swap(linfo);
          background_models_.erase(bit);
          // if previous info is under change, put into 'background_models_'
          if ((linfo->state_ == ModelReadyState::LOADING) ||
              (linfo->state_ == ModelReadyState::UNLOADING)) {
            ModelInfo* key = linfo.get();
            background_models_[(uintptr_t)key] = std::move(linfo);
          }
        }
      }
      if (model_info->agent_model_list_) {
        auto status = model_info->agent_model_list_->InvokeAgentModels(
            TRITONREPOAGENT_ACTION_LOAD_COMPLETE);
        if (!status.IsOk()) {
          LOG_ERROR << "Agent model returns error on "
                       "TRITONREPOAGENT_ACTION_LOAD_COMPLETE: "
                    << status.AsString();
        }
      }
    }
    if (OnComplete != nullptr) {
      OnComplete(
          load_tracker->load_failed_
              ? Status(Status::Code::INVALID_ARG, load_tracker->reason_)
              : Status::Success);
    }
  }
}
}}  // namespace triton::core
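A note on the ModelDeleter used by CreateModel() above: the last shared_ptr reference to a Model may be dropped from a thread inside the model's own callstack, so the delete is handed to a detached thread to avoid self-deadlock. A self-contained sketch of that pattern (illustrative only; the Resource type and on_gone callback are hypothetical names):

#include <functional>
#include <memory>
#include <thread>

struct Resource {
  // The destructor may join worker threads, so it must never run on one
  // of those workers.
  ~Resource() {}
};

std::shared_ptr<Resource>
MakeGuarded(Resource* raw, std::function<void()> on_gone)
{
  // Custom deleter: hand the delete off to a detached thread, then fire a
  // completion callback -- mirroring ModelDeleter in model_lifecycle.cc.
  return std::shared_ptr<Resource>(raw, [on_gone](Resource* r) {
    std::thread([r, on_gone]() {
      delete r;   // destructor runs off the releasing thread
      on_gone();  // e.g. flip the bookkeeping state to UNAVAILABLE
    }).detach();
  });
}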
3rdparty/core-r22.12/src/model_lifecycle.h  (new file, mode 100644)
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <functional>
#include <map>
#include <mutex>
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "repo_agent.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "triton/common/thread_pool.h"
namespace triton { namespace core {
struct ModelLifeCycleOptions {
  explicit ModelLifeCycleOptions(
      const double min_compute_capability,
      const triton::common::BackendCmdlineConfigMap&
          backend_cmdline_config_map,
      const triton::common::HostPolicyCmdlineConfigMap& host_policy_map,
      const unsigned int model_load_thread_count)
      : min_compute_capability_(min_compute_capability),
        backend_cmdline_config_map_(backend_cmdline_config_map),
        host_policy_map_(host_policy_map),
        model_load_thread_count_(model_load_thread_count)
  {
  }
  // The minimum supported CUDA compute capability.
  const double min_compute_capability_;
  // The backend configuration settings specified on the command-line
  const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map_;
  // The host policy setting used when loading models.
  const triton::common::HostPolicyCmdlineConfigMap& host_policy_map_;
  // Number of the threads to use for concurrently loading models
  const unsigned int model_load_thread_count_;
};
/// Readiness status for models.
enum class ModelReadyState {
  // The model is in an unknown state. The model is not available for
  // inferencing.
  UNKNOWN,

  // The model is ready and available for inferencing.
  READY,

  // The model is unavailable, indicating that the model failed to
  // load or has been implicitly or explicitly unloaded. The model is
  // not available for inferencing.
  UNAVAILABLE,

  // The model is being loaded by the inference server. The model is
  // not available for inferencing.
  LOADING,

  // The model is being unloaded by the inference server. The model is
  // not available for inferencing.
  UNLOADING
};
/// Get the string representation for a ModelReadyState
const std::string& ModelReadyStateString(ModelReadyState state);
using VersionStateMap =
    std::map<int64_t, std::pair<ModelReadyState, std::string>>;
using ModelStateMap = std::map<std::string, VersionStateMap>;
// Helper class to manage the lifecycle of a list of associated agent models
class TritonRepoAgentModelList {
 public:
  TritonRepoAgentModelList()
      : last_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE){};
  ~TritonRepoAgentModelList()
  {
    // Using destructor to finish the unload lifecycle without
    // explicitly managing the last step in ModelLifecycle.
    if (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD) {
      InvokeAgentModels(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE);
    }
  }
  Status AddAgentModel(std::unique_ptr<TritonRepoAgentModel>&& agent_model)
  {
    agent_models_.emplace_back(std::move(agent_model));
    return Status::Success;
  }

  size_t Size() { return agent_models_.size(); }

  TritonRepoAgentModel* Back() { return agent_models_.back().get(); }

  Status InvokeAgentModels(const TRITONREPOAGENT_ActionType action_type)
  {
    // Special handling for the current model lifecycle implementation:
    // the repo agent may be asked to perform the UNLOAD action multiple
    // times, and the requests after the first should be ignored.
    const bool first_unload =
        (action_type == TRITONREPOAGENT_ACTION_UNLOAD) &&
        (last_action_type_ != TRITONREPOAGENT_ACTION_UNLOAD);
    // Only repeated UNLOAD requests are skipped; all other actions proceed.
    if ((action_type == TRITONREPOAGENT_ACTION_UNLOAD) && !first_unload) {
      return Status::Success;
    }
    last_action_type_ = action_type;
    switch (action_type) {
      case TRITONREPOAGENT_ACTION_LOAD:
      case TRITONREPOAGENT_ACTION_UNLOAD: {
        for (size_t idx = 0; idx < agent_models_.size(); ++idx) {
          RETURN_IF_ERROR(agent_models_[idx]->InvokeAgent(action_type));
        }
        break;
      }
      case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
      case TRITONREPOAGENT_ACTION_LOAD_FAIL:
      case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: {
        // reverse order
        for (size_t one_pass_idx = agent_models_.size(); one_pass_idx > 0;
             --one_pass_idx) {
          RETURN_IF_ERROR(
              agent_models_[one_pass_idx - 1]->InvokeAgent(action_type));
        }
        break;
      }
    }
    return Status::Success;
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModelList);
  std::vector<std::unique_ptr<TritonRepoAgentModel>> agent_models_;
  TRITONREPOAGENT_ActionType last_action_type_;
};
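// Added illustration (not part of the original source): for a list built in
// order [A, B, C], InvokeAgentModels() walks forward for LOAD / UNLOAD and
// backward for the completion notifications:
//   TRITONREPOAGENT_ACTION_LOAD           -> A, B, C
//   TRITONREPOAGENT_ACTION_LOAD_COMPLETE  -> C, B, A
// so the agent that acted last is the first to learn the outcome, mirroring
// constructor/destructor nesting.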
class InferenceServer;
class Model;
class ModelLifeCycle {
 public:
  static Status Create(
      InferenceServer* server, const ModelLifeCycleOptions& options,
      std::unique_ptr<ModelLifeCycle>* life_cycle);

  ~ModelLifeCycle()
  {
    // Explicitly clean up thread pool first to clean up any pending callbacks
    // that may modify model lifecycle members
    load_pool_.reset();
    map_.clear();
  }

  // Start loading the model with the specified versions asynchronously.
  // All versions that are being served will be unloaded only after
  // the load is finished successfully.
  Status AsyncLoad(
      const std::string& model_name, const std::string& model_path,
      const inference::ModelConfig& model_config,
      const bool is_config_provided,
      const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
      std::function<void(Status)>&& OnComplete);

  // Unload model asynchronously.
  Status AsyncUnload(const std::string& model_name);

  // Get the specified version of the model. The latest ready version will
  // be retrieved if 'version' is -1. Return an error if the specified
  // version is not found or is not ready.
  Status GetModel(
      const std::string& model_name, const int64_t version,
      std::shared_ptr<Model>* model);

  // Get the ModelStateMap representation of the live models. A model is
  // live if at least one of the versions is not unknown nor unavailable.
  // If 'strict_readiness' is true, a model is only live if
  // at least one of the versions is ready.
  const ModelStateMap LiveModelStates(bool strict_readiness = false);

  // Get the ModelStateMap representation of the models.
  const ModelStateMap ModelStates();

  // Get the VersionStateMap representation of the specified model.
  const VersionStateMap VersionStates(const std::string& model_name);

  // Get the state of a specific model version.
  Status ModelState(
      const std::string& model_name, const int64_t model_version,
      ModelReadyState* state);

  // Instruct the models to stop accepting new inference requests.
  Status StopAllModels();

  // Return the number of in-flight inferences, if any; model versions
  // that don't have in-flight inferences will not be included.
  const std::set<std::tuple<std::string, int64_t, size_t>> InflightStatus();

 private:
  struct ModelInfo {
    ModelInfo(
        const std::string& model_path,
        const inference::ModelConfig& model_config,
        const uint64_t last_update_ns)
        : model_config_(model_config), model_path_(model_path),
#ifdef TRITON_ENABLE_ENSEMBLE
          is_ensemble_(model_config.platform() == kEnsemblePlatform),
#else
          is_ensemble_(false),
#endif  // TRITON_ENABLE_ENSEMBLE
          last_update_ns_(last_update_ns),
          state_(ModelReadyState::UNKNOWN)
    {
    }

    // Release the flyweight in the ModelInfo object, reflected as
    // 'UNLOADING' in the model state. Note that 'mtx_' should be acquired
    // before invoking this function to prevent a possible data race.
    void Release()
    {
      state_ = ModelReadyState::UNLOADING;
      state_reason_.clear();
      agent_model_list_.reset();
      model_.reset();
    }

    const inference::ModelConfig model_config_;
    const std::string model_path_;
    const bool is_ensemble_;

    std::mutex mtx_;
    uint64_t last_update_ns_;
    ModelReadyState state_;
    std::string state_reason_;
    // flyweight
    std::shared_ptr<TritonRepoAgentModelList> agent_model_list_;
    std::shared_ptr<Model> model_;
  };

  struct LoadTracker {
    LoadTracker(
        const size_t affected_version_cnt, const uint64_t last_update_ns)
        : last_update_ns_(last_update_ns),
          affected_version_cnt_(affected_version_cnt), load_failed_(false),
          completed_version_cnt_(0)
    {
    }

    const uint64_t last_update_ns_;
    const size_t affected_version_cnt_;

    std::mutex mtx_;
    bool load_failed_;
    std::string reason_;
    size_t completed_version_cnt_;
    std::map<int64_t, ModelInfo*> load_set_;
  };

  ModelLifeCycle(InferenceServer* server, const ModelLifeCycleOptions& options)
      : server_(server),
        min_compute_capability_(options.min_compute_capability_),
        cmdline_config_map_(options.backend_cmdline_config_map_),
        host_policy_map_(options.host_policy_map_)
  {
    load_pool_.reset(new triton::common::ThreadPool(
        std::max(1u, options.model_load_thread_count_)));
  }

  void CreateModel(
      const std::string& model_name, const int64_t version,
      ModelInfo* model_info, const bool is_config_provided);

  // Callback function template for model load.
  // 'OnComplete' needs to be passed by value for now as there can be
  // multiple versions to be loaded and each holds a copy of
  // the 'OnComplete' callback.
  void OnLoadComplete(
      const std::string& model_name, const int64_t version,
      ModelInfo* model_info, std::function<void(Status)> OnComplete,
      std::shared_ptr<LoadTracker> load_tracker);

  // Mutex for 'map_' and 'background_models_'
  std::mutex map_mtx_;

  using VersionMap = std::map<int64_t, std::unique_ptr<ModelInfo>>;
  using ModelMap = std::map<std::string, VersionMap>;
  ModelMap map_;
  // Models that are being loaded / unloaded in background
  std::map<uintptr_t, std::unique_ptr<ModelInfo>> background_models_;

  InferenceServer* server_;
  const double min_compute_capability_;
  const triton::common::BackendCmdlineConfigMap cmdline_config_map_;
  const triton::common::HostPolicyCmdlineConfigMap host_policy_map_;

  // Fixed-size thread pool to load models at specified concurrency
  std::unique_ptr<triton::common::ThreadPool> load_pool_;
};
}}  // namespace triton::core
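ModelLifeCycle's load API is fire-and-forget with a completion callback, and OnComplete fires exactly once, after every selected version has finished loading (or failed). A minimal sketch of a blocking wrapper (illustrative only; LoadAndWait is a hypothetical helper and no repo-agent list is passed):

#include <future>
#include <string>
#include <utility>

#include "model_lifecycle.h"

triton::core::Status
LoadAndWait(
    triton::core::ModelLifeCycle* life_cycle, const std::string& name,
    const std::string& path, const inference::ModelConfig& config)
{
  std::promise<triton::core::Status> done;
  auto fut = done.get_future();
  auto status = life_cycle->AsyncLoad(
      name, path, config, true /* is_config_provided */,
      nullptr /* no repo agents */,
      [&done](triton::core::Status s) { done.set_value(std::move(s)); });
  if (!status.IsOk()) {
    return status;  // the load could not even be enqueued
  }
  return fut.get();  // resolves once all selected versions complete
}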
3rdparty/core-r22.12/src/model_repository_manager.cc  (new file, mode 100644)
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_repository_manager.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "ensemble_utils.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace triton { namespace core {

namespace {

static std::string file_prefix = "file:";
// Internal repo agent used for model file override
class LocalizeRepoAgent : public TritonRepoAgent {
 public:
  LocalizeRepoAgent()
      : TritonRepoAgent("ModelRepositoryManager::LocalizeRepoAgent")
  {
    // The callbacks below interact with TritonRepoAgentModel directly,
    // knowing that it is the internal implementation of
    // TRITONREPOAGENT_AgentModel.
    model_action_fn_ = [](TRITONREPOAGENT_Agent* agent,
                          TRITONREPOAGENT_AgentModel* model,
                          const TRITONREPOAGENT_ActionType action_type)
        -> TRITONSERVER_Error* {
      auto agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
      switch (action_type) {
        case TRITONREPOAGENT_ACTION_LOAD: {
          // Localize the override files for model loading, as the model is
          // currently expected to be loaded from a local directory.
          const char* temp_dir_cstr = nullptr;
          RETURN_TRITONSERVER_ERROR_IF_ERROR(
              agent_model->AcquireMutableLocation(
                  TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &temp_dir_cstr));
          const std::string temp_dir = temp_dir_cstr;
          const auto& files =
              *reinterpret_cast<std::vector<const InferenceParameter*>*>(
                  agent_model->State());
          bool found_config = false;
          for (const auto& file : files) {
            if (file->Name() == "config") {
              if (file->Type() != TRITONSERVER_PARAMETER_STRING) {
                return TRITONSERVER_ErrorNew(
                    TRITONSERVER_ERROR_INVALID_ARG,
                    "Config parameter 'config' must have string type for its "
                    "value");
              }
              inference::ModelConfig config;
              RETURN_TRITONSERVER_ERROR_IF_ERROR(JsonToModelConfig(
                  file->ValueString(), 1 /* config_version */, &config));
              RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteTextProto(
                  JoinPath({temp_dir, kModelConfigPbTxt}), config));
              found_config = true;
            } else if (file->Name().rfind(file_prefix, 0) == 0) {
              if (file->Type() != TRITONSERVER_PARAMETER_BYTES) {
                return TRITONSERVER_ErrorNew(
                    TRITONSERVER_ERROR_INVALID_ARG,
                    (std::string("File parameter '") + file->Name() +
                     "' must have bytes type for its value")
                        .c_str());
              }
              // Save the model file to the instructed directory, creating
              // intermediate directories as needed.
              const std::string file_path = JoinPath(
                  {temp_dir, file->Name().substr(file_prefix.size())});
              const std::string dir = DirName(file_path);
              bool dir_exist = false;
              RETURN_TRITONSERVER_ERROR_IF_ERROR(FileExists(dir, &dir_exist));
              if (dir_exist) {
                bool is_dir = false;
                RETURN_TRITONSERVER_ERROR_IF_ERROR(IsDirectory(dir, &is_dir));
                if (!is_dir) {
                  return TRITONSERVER_ErrorNew(
                      TRITONSERVER_ERROR_INVALID_ARG,
                      (std::string("Invalid file parameter '") + file->Name() +
                       "', directory has been created as a file")
                          .c_str());
                }
              } else {
                RETURN_TRITONSERVER_ERROR_IF_ERROR(
                    MakeDirectory(dir, true /* recursive */));
              }
              // Write the file content.
              RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteBinaryFile(
                  file_path,
                  reinterpret_cast<const char*>(file->ValuePointer()),
                  file->ValueByteSize()));
            }
          }
          if (!found_config) {
            return TRITONSERVER_ErrorNew(
                TRITONSERVER_ERROR_INVALID_ARG,
                "Load parameter 'config' must be specified for model file "
                "override");
          }
          // Commit the temporary directory.
          RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->SetLocation(
              TRITONREPOAGENT_ARTIFACT_FILESYSTEM, temp_dir_cstr));
          break;
        }
        default:
          break;
      }
      return nullptr;  // success
    };

    model_fini_fn_ = [](TRITONREPOAGENT_Agent* agent,
                        TRITONREPOAGENT_AgentModel* model)
        -> TRITONSERVER_Error* {
      auto agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
      RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->DeleteMutableLocation());
      return nullptr;  // success
    };
  }
};
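// For reference, the override parameters this agent consumes follow the
// naming convention visible in the code above: a "config" parameter carries
// the model configuration as a JSON string, and each "file:"-prefixed
// parameter carries raw bytes whose relative destination path is the part of
// the name after the prefix. A call sketch (illustrative only; the model name
// and the InferenceParameter constructor signature are assumptions):
//
//   std::vector<const InferenceParameter*> files;
//   InferenceParameter config_param(
//       "config", R"({"name": "my_model", "backend": "onnxruntime"})");
//   files.emplace_back(&config_param);
//   // a parameter named "file:1/model.onnx" would be written to
//   // <temp_dir>/1/model.onnx by the LOAD action above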
Status
CreateAgentModelListWithLoadAction(
    const inference::ModelConfig& original_model_config,
    const std::string& original_model_path,
    std::shared_ptr<TritonRepoAgentModelList>* agent_model_list)
{
  if (original_model_config.has_model_repository_agents()) {
    // Trick to append user specified repo agent on top of internal ones
    std::shared_ptr<TritonRepoAgentModelList> lagent_model_list;
    if (*agent_model_list != nullptr) {
      lagent_model_list = std::move(*agent_model_list);
    } else {
      lagent_model_list.reset(new TritonRepoAgentModelList());
    }

    FileSystemType filesystem_type;
    RETURN_IF_ERROR(GetFileSystemType(original_model_path, &filesystem_type));
    TRITONREPOAGENT_ArtifactType artifact_type =
        TRITONREPOAGENT_ARTIFACT_FILESYSTEM;
    if (filesystem_type != FileSystemType::LOCAL) {
      artifact_type = TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM;
    }
    const char* location = original_model_path.c_str();
    inference::ModelConfig model_config = original_model_config;
    for (const auto& agent_config :
         original_model_config.model_repository_agents().agents()) {
      std::shared_ptr<TritonRepoAgent> agent;
      RETURN_IF_ERROR(
          TritonRepoAgentManager::CreateAgent(agent_config.name(), &agent));
      TritonRepoAgent::Parameters agent_params;
      for (const auto& parameter : agent_config.parameters()) {
        agent_params.emplace_back(parameter.first, parameter.second);
      }
      std::unique_ptr<TritonRepoAgentModel> agent_model;
      if (lagent_model_list->Size() != 0) {
        lagent_model_list->Back()->Location(&artifact_type, &location);
        const auto config_path = JoinPath({location, kModelConfigPbTxt});
        if (!ReadTextProto(config_path, &model_config).IsOk()) {
          model_config.Clear();
        }
      }
      RETURN_IF_ERROR(TritonRepoAgentModel::Create(
          artifact_type, location, model_config, agent, agent_params,
          &agent_model));
      RETURN_IF_ERROR(agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD));
      lagent_model_list->AddAgentModel(std::move(agent_model));
    }
    *agent_model_list = std::move(lagent_model_list);
  }
  return Status::Success;
}
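// A model opts into repo agents through its config.pbtxt; the loop above
// instantiates one TritonRepoAgentModel per listed agent and chains their
// output locations. A rough sketch of the corresponding config fragment
// (agent name and parameter are illustrative, not a claim about which agents
// exist):
//
//   model_repository_agents {
//     agents [
//       {
//         name: "my_agent"
//         parameters { key: "my_key" value: "my_value" }
//       }
//     ]
//   }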
int64_t
GetModifiedTime(const std::string& path)
{
  // If there is an error in any step the fall-back default
  // modification time is 0. This means that in error cases 'path'
  // will show as not modified. This is the safe fall-back to avoid
  // assuming a model is constantly being modified.
  bool path_is_dir;
  Status status = IsDirectory(path, &path_is_dir);
  if (!status.IsOk()) {
    LOG_ERROR << "Failed to determine modification time for '" << path
              << "': " << status.AsString();
    return 0;
  }

  // If 'path' is a file return its mtime. Otherwise, use the modification
  // time of the directory as a baseline in case of file deletion.
  int64_t mtime = 0;
  status = FileModificationTime(path, &mtime);
  if (!status.IsOk()) {
    LOG_ERROR << "Failed to determine modification time for '" << path
              << "': " << status.AsString();
    return 0;
  }
  if (!path_is_dir) {
    return mtime;
  }

  // 'path' is a directory. Return the most recent mtime of the
  // contents of the directory.
  std::set<std::string> contents;
  status = GetDirectoryContents(path, &contents);
  if (!status.IsOk()) {
    LOG_ERROR << "Failed to determine modification time for '" << path
              << "': " << status.AsString();
    return 0;
  }
  for (const auto& child : contents) {
    const auto full_path = JoinPath({path, child});
    mtime = std::max(mtime, GetModifiedTime(full_path));
  }

  return mtime;
}
// Return true if any file in the subdirectory rooted at 'path' has been
// modified more recently than 'last_ns'. Return the most-recent modified
// time in 'last_ns'.
bool
IsModified(const std::string& path, int64_t* last_ns)
{
  const int64_t repo_ns = GetModifiedTime(path);
  bool modified = repo_ns > *last_ns;
  *last_ns = repo_ns;
  return modified;
}
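// Usage sketch for the pair of helpers above (illustrative): a caller keeps
// the last-seen timestamp per model directory and acts only when IsModified()
// observes a newer mtime anywhere under the directory tree.
//
//   int64_t last_ns = 0;  // 0 means "never seen"; errors also report 0
//   if (IsModified("/models/my_model", &last_ns)) {
//     // schedule a reload; 'last_ns' now holds the newest observed mtime
//   }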
}  // namespace
struct ModelRepositoryManager::ModelInfo {
  ModelInfo(
      const int64_t mtime_nsec, const int64_t prev_mtime_ns,
      const std::string& model_path)
      : mtime_nsec_(mtime_nsec), prev_mtime_ns_(prev_mtime_ns),
        explicitly_load_(true), model_path_(model_path),
        is_config_provided_(false)
  {
  }
  ModelInfo()
      : mtime_nsec_(0), prev_mtime_ns_(0), explicitly_load_(true),
        is_config_provided_(false)
  {
  }
  int64_t mtime_nsec_;
  int64_t prev_mtime_ns_;
  bool explicitly_load_;
  inference::ModelConfig model_config_;
  std::string model_path_;
  // Temporary location to hold the agent model list before creating the
  // model; ownership must transfer to ModelLifeCycle to ensure the agent
  // model life cycle is handled properly.
  std::shared_ptr<TritonRepoAgentModelList> agent_model_list_;
  bool is_config_provided_;
};
ModelRepositoryManager::ModelRepositoryManager(
    const std::set<std::string>& repository_paths, const bool autofill,
    const bool polling_enabled, const bool model_control_enabled,
    const double min_compute_capability,
    std::unique_ptr<ModelLifeCycle> life_cycle)
    : repository_paths_(repository_paths), autofill_(autofill),
      polling_enabled_(polling_enabled),
      model_control_enabled_(model_control_enabled),
      min_compute_capability_(min_compute_capability),
      model_life_cycle_(std::move(life_cycle))
{
}

ModelRepositoryManager::~ModelRepositoryManager() {}
Status
ModelRepositoryManager::Create(
    InferenceServer* server, const std::string& server_version,
    const std::set<std::string>& repository_paths,
    const std::set<std::string>& startup_models,
    const bool strict_model_config, const bool polling_enabled,
    const bool model_control_enabled,
    const ModelLifeCycleOptions& life_cycle_options,
    std::unique_ptr<ModelRepositoryManager>* model_repository_manager)
{
  // The rest only matters if the repository path is a valid directory
  for (const auto& path : repository_paths) {
    bool path_is_dir;
    RETURN_IF_ERROR(IsDirectory(path, &path_is_dir));
    if (!path_is_dir) {
      return Status(
          Status::Code::INVALID_ARG,
          "repository path is not a valid directory");
    }
  }

  if (polling_enabled && model_control_enabled) {
    return Status(
        Status::Code::INVALID_ARG,
        "cannot enable both polling and explicit model control");
  }

  std::unique_ptr<ModelLifeCycle> life_cycle;
  RETURN_IF_ERROR(
      ModelLifeCycle::Create(server, life_cycle_options, &life_cycle));

  // Not setting the smart pointer directly to simplify clean up
  std::unique_ptr<ModelRepositoryManager> local_manager(
      new ModelRepositoryManager(
          repository_paths, !strict_model_config, polling_enabled,
          model_control_enabled, life_cycle_options.min_compute_capability_,
          std::move(life_cycle)));
  *model_repository_manager = std::move(local_manager);

  // Support loading all models at startup in explicit model control mode
  // with the special startup_model name "*". This does not imply support for
  // pattern matching in model names.
  bool load_all_models_on_startup = false;
  if ((startup_models.find("*") != startup_models.end()) &&
      model_control_enabled) {
    if (startup_models.size() > 1) {
      return Status(
          Status::Code::INVALID_ARG,
          "Wildcard model name '*' must be the ONLY startup model "
          "if specified at all.");
    }
    load_all_models_on_startup = true;
  }

  bool all_models_polled = true;
  if (!model_control_enabled || load_all_models_on_startup) {
    // Only errors that happen before model load / unload will be returned;
    // model loading / unloading errors are printed but ignored.
    RETURN_IF_ERROR(
        (*model_repository_manager)
            ->PollAndUpdateInternal(&all_models_polled));
  } else {
    // Load each specified startup_model
    std::unordered_map<std::string, std::vector<const InferenceParameter*>>
        models;
    for (const auto& model_name : startup_models) {
      models[model_name];
    }
    RETURN_IF_ERROR((*model_repository_manager)
                        ->LoadUnloadModels(
                            models, ActionType::LOAD, false,
                            &all_models_polled));
  }
  if (!all_models_polled) {
    return Status(Status::Code::INTERNAL, "failed to load all models");
  }
  // Some models may fail to load after the model manager is created; return
  // a proper error and let the caller decide whether to proceed.
  for (const auto& model : (*model_repository_manager)->infos_) {
    const auto version_states =
        (*model_repository_manager)
            ->model_life_cycle_->VersionStates(model.first);
    // Return a general error message; the detail of each model's loading
    // state is logged separately.
    if (version_states.empty()) {
      return Status(Status::Code::INTERNAL, "failed to load all models");
    }
    for (const auto& state : version_states) {
      if (state.second.first != ModelReadyState::READY) {
        return Status(Status::Code::INTERNAL, "failed to load all models");
      }
    }
  }
  return Status::Success;
}
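// As the wildcard handling above implies, explicit model control can still
// load everything at startup when "*" is the only startup model. With the
// tritonserver CLI this would look roughly like the following (sketch; flag
// spellings as commonly documented for Triton, not verified against this
// exact build):
//
//   tritonserver --model-repository=/models \
//                --model-control-mode=explicit \
//                --load-model=*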
Status
ModelRepositoryManager::PollAndUpdate()
{
  if (!polling_enabled_) {
    return Status(Status::Code::UNAVAILABLE, "polling is disabled");
  }
  bool all_models_polled;
  return PollAndUpdateInternal(&all_models_polled);
}
Status
ModelRepositoryManager::PollAndUpdateInternal(bool* all_models_polled)
{
  // Serialize all operations that change model state
  std::lock_guard<std::mutex> lock(poll_mu_);

  std::set<std::string> added, deleted, modified, unmodified;

  // We don't modify 'infos_' in place, to minimize how long we need to hold
  // the lock and also to prevent partial changes due to an error during
  // processing.
  ModelInfoMap new_infos;

  // Each subdirectory of the repository path is a model directory from
  // which we read the model configuration.
  std::unordered_map<std::string, std::vector<const InferenceParameter*>>
      subdirs;
  RETURN_IF_ERROR(Poll(
      subdirs, &added, &deleted, &modified, &unmodified, &new_infos,
      all_models_polled));

  // Anything in 'infos_' that is not in "added", "modified", or
  // "unmodified" is deleted.
  for (const auto& pr : infos_) {
    if ((added.find(pr.first) == added.end()) &&
        (modified.find(pr.first) == modified.end()) &&
        (unmodified.find(pr.first) == unmodified.end())) {
      deleted.insert(pr.first);
    }
  }

  // Nothing to do if there are no model adds, deletes or modifications.
  if (added.empty() && deleted.empty() && modified.empty()) {
    return Status::Success;
  }

  infos_.swap(new_infos);

  UpdateDependencyGraph(added, deleted, modified);

  for (const auto& name : deleted) {
    model_life_cycle_->AsyncUnload(name);
  }

  // Model loading / unloading errors are printed but ignored
  LoadModelByDependency();

  return Status::Success;
}
std::map<std::string, Status>
ModelRepositoryManager::LoadModelByDependency()
{
  std::map<std::string, Status> res;
  struct ModelState {
    ModelState(DependencyNode* node) : node_(node), status_(Status::Success)
    {
    }
    DependencyNode* node_;
    Status status_;
    std::promise<void> ready_;
  };
  NodeSet loaded_models;
  auto set_pair = ModelsToLoadUnload(loaded_models);
  // Loop until all models are loaded / unloaded
  while ((!set_pair.first.empty()) || (!set_pair.second.empty())) {
    loaded_models.clear();
    // Unload invalid models first
    for (auto& invalid_model : set_pair.second) {
      model_life_cycle_->AsyncUnload(invalid_model->model_name_);
      LOG_ERROR << invalid_model->status_.AsString();
      invalid_model->loaded_versions_ = std::set<int64_t>();
      loaded_models.emplace(invalid_model);
    }
    // Load valid models and wait for the load results
    std::vector<std::unique_ptr<ModelState>> model_states;
    for (auto& valid_model : set_pair.first) {
      model_states.emplace_back(new ModelState(valid_model));
      auto model_state = model_states.back().get();
      const auto itr = infos_.find(valid_model->model_name_);
      auto status = model_life_cycle_->AsyncLoad(
          valid_model->model_name_, itr->second->model_path_,
          valid_model->model_config_, itr->second->is_config_provided_,
          itr->second->agent_model_list_, [model_state](Status load_status) {
            model_state->status_ = load_status;
            model_state->ready_.set_value();
          });
      if (!status.IsOk()) {
        model_state->status_ = status;
        model_state->ready_.set_value();
        LOG_ERROR << "failed to load model '" << valid_model->model_name_
                  << "': " << status.Message();
      }
      loaded_models.emplace(valid_model);
    }
    for (auto& model_state : model_states) {
      model_state->ready_.get_future().wait();
      res[model_state->node_->model_name_] = model_state->status_;
      const auto version_state =
          model_life_cycle_->VersionStates(model_state->node_->model_name_);
      model_state->node_->loaded_versions_.clear();
      for (const auto& vs : version_state) {
        if (vs.second.first == ModelReadyState::READY) {
          model_state->node_->loaded_versions_.emplace(vs.first);
        }
      }
      // If the model failed to load, revert the timestamp so that the next
      // load request will attempt to load the model again, for operational
      // consistency.
      if (!model_state->status_.IsOk()) {
        auto& model_info =
            infos_.find(model_state->node_->model_name_)->second;
        model_info->mtime_nsec_ = model_info->prev_mtime_ns_;
      }
    }
    set_pair = ModelsToLoadUnload(loaded_models);
  }
  // Clear the temporarily stored agent model lists after all loads are
  // triggered
  for (auto& info : infos_) {
    info.second->agent_model_list_.reset();
  }
  return res;
}
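// The loop above is a fixpoint iteration: load every node whose dependencies
// are satisfied, block on a per-model promise until its load callback fires,
// then recompute the ready set from the freshly loaded nodes. A minimal
// standalone sketch of the completion handshake (assumption: simplified, not
// upstream code; 'async_load' stands in for ModelLifeCycle::AsyncLoad):
//
//   std::promise<void> ready;
//   std::future<void> done = ready.get_future();
//   async_load("my_model", [&ready](Status s) { ready.set_value(); });
//   done.wait();  // returns once the load callback has run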
Status
ModelRepositoryManager::LoadUnloadModel(
    const std::unordered_map<
        std::string, std::vector<const InferenceParameter*>>& models,
    const ActionType type, const bool unload_dependents)
{
  if (!model_control_enabled_) {
    return Status(
        Status::Code::UNAVAILABLE,
        "explicit model load / unload is not allowed if polling is enabled");
  }
  if (models.size() > 1) {
    return Status(
        Status::Code::UNSUPPORTED,
        "explicit load / unload of multiple models is not currently "
        "supported");
  }

  // Serialize all operations that change model state
  std::lock_guard<std::mutex> lock(poll_mu_);

  bool polled = true;
  RETURN_IF_ERROR(LoadUnloadModels(models, type, unload_dependents, &polled));

  // Check whether the model was loaded / unloaded properly
  const auto& model_name = models.begin()->first;
  if (!polled) {
    return Status(
        Status::Code::INTERNAL, "failed to load '" + model_name +
                                    "', failed to poll from model repository");
  }

  const auto version_states = model_life_cycle_->VersionStates(model_name);
  if (type == ActionType::LOAD) {
    if (version_states.empty()) {
      return Status(
          Status::Code::INTERNAL,
          "failed to load '" + model_name + "', no version is available");
    }
    auto it = infos_.find(model_name);
    if (it == infos_.end()) {
      return Status(
          Status::Code::INTERNAL,
          "failed to load '" + model_name +
              "', failed to poll from model repository");
    }
  } else {
    std::string ready_version_str;
    for (const auto& version_state : version_states) {
      if (version_state.second.first == ModelReadyState::READY) {
        ready_version_str += std::to_string(version_state.first);
        ready_version_str += ",";
      }
    }
    if (!ready_version_str.empty()) {
      ready_version_str.pop_back();
      return Status(
          Status::Code::INTERNAL,
          "failed to unload '" + model_name +
              "', versions that are still available: " + ready_version_str);
    }
  }

  return Status::Success;
}
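// Call sketch for the public API above (illustrative): an explicit load of a
// single model, polling its config from the repository. An empty parameter
// vector means no config or file overrides.
//
//   std::unordered_map<std::string, std::vector<const InferenceParameter*>>
//       models;
//   models["my_model"];  // no parameters
//   manager->LoadUnloadModel(
//       models, ActionType::LOAD, false /* unload_dependents */);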
Status
ModelRepositoryManager::LoadUnloadModels(
    const std::unordered_map<
        std::string, std::vector<const InferenceParameter*>>& models,
    const ActionType type, const bool unload_dependents,
    bool* all_models_polled)
{
  auto status = Status::Success;
  *all_models_polled = true;
  // Update the file-system related ModelInfo accordingly
  std::set<std::string> added, deleted, modified, unmodified;
  {
    if (type == ActionType::UNLOAD) {
      for (const auto& model : models) {
        deleted.insert(model.first);
      }
    }
    // ActionType::LOAD and in model control mode
    else {
      std::set<std::string> checked_models;
      auto current_models = models;
      for (const auto& model : models) {
        checked_models.emplace(model.first);
      }

      ModelInfoMap new_infos;
#ifdef TRITON_ENABLE_ENSEMBLE
      bool first_iteration = true;
#endif  // TRITON_ENABLE_ENSEMBLE
      while (!current_models.empty()) {
        bool polled = true;
        RETURN_IF_ERROR(Poll(
            current_models, &added, &deleted, &modified, &unmodified,
            &new_infos, &polled));
        *all_models_polled &= polled;
        // More models must be polled if the polled models are ensembles
        std::unordered_map<std::string, std::vector<const InferenceParameter*>>
            next_models;
#ifdef TRITON_ENABLE_ENSEMBLE
        for (const auto& model : current_models) {
          auto it = new_infos.find(model.first);
          // Some models may be marked as deleted and not be in 'new_infos'
          if (it != new_infos.end()) {
            it->second->explicitly_load_ = first_iteration;
            const auto& config = it->second->model_config_;
            if (config.has_ensemble_scheduling()) {
              for (const auto& step : config.ensemble_scheduling().step()) {
                bool need_poll =
                    checked_models.emplace(step.model_name()).second;
                if (need_poll) {
                  next_models[step.model_name()];
                }
              }
            }
          }
        }
        first_iteration = false;
#endif  // TRITON_ENABLE_ENSEMBLE
        current_models.swap(next_models);
      }

      // Only update the infos when all validation is completed
      for (const auto& model_name : added) {
        auto nitr = new_infos.find(model_name);
        infos_.emplace(model_name, std::move(nitr->second));
      }
      for (const auto& model_name : modified) {
        auto nitr = new_infos.find(model_name);
        auto itr = infos_.find(model_name);
        itr->second = std::move(nitr->second);
      }
    }
  }
  std::set<std::string> deleted_dependents;
  // Update the dependency graph and load
  UpdateDependencyGraph(
      added, deleted, modified,
      unload_dependents ? &deleted_dependents : nullptr);
  // Models are in 'deleted' either because they were asked to be unloaded or
  // because they are not found / are duplicated across all model
  // repositories. In all cases they should be unloaded and removed from
  // 'infos_' explicitly.
  for (const auto& name : (unload_dependents ? deleted_dependents : deleted)) {
    infos_.erase(name);
    model_life_cycle_->AsyncUnload(name);
  }
  // Load / unload the affected models, then check the load status of the
  // requested models
  const auto& load_status = LoadModelByDependency();
  if (status.IsOk() && (type == ActionType::LOAD)) {
    std::string load_error_message = "";
    for (const auto& model : models) {
      auto it = load_status.find(model.first);
      // If 'model.first' is not in the load status, the (re-)load was not
      // necessary because there was no change in the model's directory
      if ((it != load_status.end()) && !it->second.IsOk()) {
        load_error_message +=
            ("load failed for model '" + model.first +
             "': " + it->second.Message() + "\n");
      }
    }
    if (!load_error_message.empty()) {
      status = Status(Status::Code::INVALID_ARG, load_error_message);
    }
  }
  return status;
}
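// The ensemble expansion above exists because an ensemble's config names the
// models it composes, and those must be polled too. For orientation, the
// relevant config.pbtxt fragment looks roughly like this (step contents are
// illustrative; -1 means "use the latest version"):
//
//   ensemble_scheduling {
//     step [
//       { model_name: "preprocess" model_version: -1 }
//     ]
//   }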
Status
ModelRepositoryManager::UnloadAllModels()
{
  Status status;
  for (const auto& name_info : infos_) {
    Status unload_status = model_life_cycle_->AsyncUnload(name_info.first);
    if (!unload_status.IsOk()) {
      status = Status(
          unload_status.ErrorCode(),
          "Failed to gracefully unload models: " + unload_status.Message());
    }
  }
  return status;
}
Status
ModelRepositoryManager::StopAllModels()
{
  return model_life_cycle_->StopAllModels();
}
const std::set<std::tuple<std::string, int64_t, size_t>>
ModelRepositoryManager::InflightStatus()
{
  return model_life_cycle_->InflightStatus();
}

const ModelStateMap
ModelRepositoryManager::LiveModelStates(bool strict_readiness)
{
  return model_life_cycle_->LiveModelStates(strict_readiness);
}

const ModelStateMap
ModelRepositoryManager::ModelStates()
{
  return model_life_cycle_->ModelStates();
}

const VersionStateMap
ModelRepositoryManager::VersionStates(const std::string& model_name)
{
  return model_life_cycle_->VersionStates(model_name);
}
Status
ModelRepositoryManager::ModelState(
    const std::string& model_name, const int64_t model_version,
    ModelReadyState* state)
{
  return model_life_cycle_->ModelState(model_name, model_version, state);
}
Status
ModelRepositoryManager::RepositoryIndex(
    const bool ready_only, std::vector<ModelIndex>* index)
{
  std::set<std::string> seen_models;
  std::set<std::string> duplicate_models;
  for (const auto& repository_path : repository_paths_) {
    // For any mapped models in this repository, save the mapping
    // from their subdirectory name to model name.
    std::map<std::string, std::string> models_in_repo;
    for (const auto& mapping_it : model_mappings_) {
      if (mapping_it.second.first == repository_path) {
        models_in_repo.emplace(
            BaseName(mapping_it.second.second), mapping_it.first);
      }
    }
    std::set<std::string> subdirs;
    RETURN_IF_ERROR(GetDirectorySubdirs(repository_path, &subdirs));
    for (const auto& subdir : subdirs) {
      auto model = subdir;
      auto model_it = models_in_repo.find(subdir);
      if (model_it != models_in_repo.end()) {
        model = model_it->second;
      }
      if (seen_models.find(model) != seen_models.end()) {
        duplicate_models.insert(model);
      }
      seen_models.insert(model);
    }
  }

  ModelStateMap states = ModelStates();

  for (const auto& model : seen_models) {
    // If the same model appears in multiple repositories then show it
    // as unavailable since duplicate models are not allowed to load.
    if (duplicate_models.find(model) != duplicate_models.end()) {
      index->emplace_back(
          model, -1 /* version */, ModelReadyState::UNAVAILABLE,
          MODEL_READY_REASON_DUPLICATE);
      continue;
    }

    // If there is any version/state/reason associated with the model
    // then include that in the index.
    auto sitr = states.find(model);
    if (sitr == states.end()) {
      if (!ready_only) {
        index->emplace_back(model);
      }
    } else {
      for (const auto& pr : sitr->second) {
        if (!ready_only || (pr.second.first == ModelReadyState::READY)) {
          index->emplace_back(
              model, pr.first, pr.second.first, pr.second.second);
        }
      }
    }
  }

  return Status::Success;
}
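// Result sketch: each ModelIndex entry is either name-only (the model exists
// on disk but has no recorded state) or a <name, version, state, reason>
// tuple, e.g. <"my_model", 1, READY, ""> for a loaded model, or
// <"dup_model", -1, UNAVAILABLE, MODEL_READY_REASON_DUPLICATE> for a model
// that appears in more than one repository (model names illustrative).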
Status
ModelRepositoryManager::GetModel(
    const std::string& model_name, const int64_t model_version,
    std::shared_ptr<Model>* model)
{
  Status status = model_life_cycle_->GetModel(model_name, model_version, model);
  if (!status.IsOk()) {
    model->reset();
    status = Status(
        status.ErrorCode(), "Request for unknown model: " + status.Message());
  }
  return status;
}
Status
ModelRepositoryManager::Poll(
    const std::unordered_map<
        std::string, std::vector<const InferenceParameter*>>& models,
    std::set<std::string>* added, std::set<std::string>* deleted,
    std::set<std::string>* modified, std::set<std::string>* unmodified,
    ModelInfoMap* updated_infos, bool* all_models_polled)
{
  *all_models_polled = true;
  // An empty path is the special case indicating that the model should be
  // loaded from the override file content in 'models'.
  std::map<std::string, std::string> model_to_path;

  // If no model is specified, poll all models in all model repositories.
  // Otherwise, only poll the specified models.
  if (models.empty()) {
    std::set<std::string> duplicated_models;
    for (const auto& repository_path : repository_paths_) {
      std::set<std::string> subdirs;
      Status status = GetDirectorySubdirs(repository_path, &subdirs);
      if (!status.IsOk()) {
        LOG_ERROR << "failed to poll model repository '" << repository_path
                  << "': " << status.Message();
        *all_models_polled = false;
      } else {
        for (const auto& subdir : subdirs) {
          if (!model_to_path
                   .emplace(subdir, JoinPath({repository_path, subdir}))
                   .second) {
            duplicated_models.insert(subdir);
            *all_models_polled = false;
          }
        }
      }
    }
    // If a model is not unique, mark it as deleted to unload it
    for (const auto& model : duplicated_models) {
      model_to_path.erase(model);
      deleted->insert(model);
      LOG_ERROR << "failed to poll model '" << model
                << "': not unique across all model repositories";
    }
  }
  // If models are specified, this is explicit model control mode.
  else {
    for (const auto& model : models) {
      // Skip repository polling if model files are overridden
      if (ModelDirectoryOverride(model.second)) {
        model_to_path.emplace(model.first, "");
        continue;
      }

      // Check the model mapping first to see if it matches a model to load.
      bool exists = false;
      auto model_it = model_mappings_.find(model.first);
      if (model_it != model_mappings_.end()) {
        bool exists_in_this_repo = false;
        auto full_path = model_it->second.second;
        Status status = FileExists(full_path, &exists_in_this_repo);
        if (!status.IsOk()) {
          LOG_ERROR << "failed to poll mapped path '" << full_path
                    << "' for model '" << model.first
                    << "': " << status.Message();
          *all_models_polled = false;
        }
        if (exists_in_this_repo) {
          model_to_path.emplace(model.first, model_it->second.second);
          exists = true;
        } else {
          LOG_ERROR << "mapped path '" << full_path
                    << "' does not exist for model '" << model.first << "'";
          exists = false;
        }
      } else {
        for (const auto& repository_path : repository_paths_) {
          bool exists_in_this_repo = false;
          const auto full_path = JoinPath({repository_path, model.first});
          Status status = FileExists(full_path, &exists_in_this_repo);
          if (!status.IsOk()) {
            LOG_ERROR << "failed to poll model repository '"
                      << repository_path << "' for model '" << model.first
                      << "': " << status.Message();
            *all_models_polled = false;
          } else if (exists_in_this_repo) {
            // Check to make sure this directory is not mapped.
            // If mapped, continue to the next repository path.
            bool mapped = false;
            for (auto const& mapping : model_mappings_) {
              if (mapping.second.second == full_path) {
                mapped = true;
                break;
              }
            }
            if (mapped) {
              continue;
            }
            auto res = model_to_path.emplace(
                model.first, JoinPath({repository_path, model.first}));
            if (res.second) {
              exists = true;
            } else {
              exists = false;
              model_to_path.erase(res.first);
              LOG_ERROR << "failed to poll model '" << model.first
                        << "': not unique across all model repositories";
              break;
            }
          }
        }
      }
      // For an explicitly specified model that doesn't exist, we don't mark
      // it as deleted; we simply record that we couldn't poll all models.
      if (!exists) {
        *all_models_polled = false;
      }
    }
  }

  // Poll each of the models. If an error happens while polling a model, its
  // state falls back to the state before the poll.
  for (const auto& pair : model_to_path) {
    std::unique_ptr<ModelInfo> model_info;
    const auto& mit = models.find(pair.first);
    static std::vector<const InferenceParameter*> empty_params;
    auto status = InitializeModelInfo(
        pair.first, pair.second,
        ((mit == models.end()) ? empty_params : mit->second), &model_info);

    const auto& iitr = infos_.find(pair.first);
    const bool invalid_add = (!status.IsOk()) && (iitr == infos_.end());
    if (!invalid_add) {
      const auto& ret = updated_infos->emplace(pair.first, nullptr);
      if (!ret.second) {
        return Status(
            Status::Code::ALREADY_EXISTS,
            "unexpected model info for model '" + pair.first + "'");
      }

      // Classify the load state and set the updated info
      if (model_info == nullptr) {
        ret.first->second.reset(new ModelInfo(*iitr->second));
        unmodified->insert(pair.first);
      } else {
        ret.first->second = std::move(model_info);
        if (iitr != infos_.end()) {
          modified->insert(pair.first);
        } else {
          added->insert(pair.first);
        }
      }
    }

    if (!status.IsOk()) {
      LOG_ERROR << "Poll failed for model directory '" << pair.first
                << "': " << status.Message();
      *all_models_polled = false;
    }
  }

  return Status::Success;
}
bool
ModelRepositoryManager::ModelDirectoryOverride(
    const std::vector<const InferenceParameter*>& model_params)
{
  for (const auto& param : model_params) {
    if (param->Name().rfind(file_prefix, 0) == 0) {
      // A parameter name starting with the prefix means the user provided an
      // override file
      return true;
    }
  }
  return false;
}
Status
ModelRepositoryManager::InitializeModelInfo(
    const std::string& name, const std::string& path,
    const std::vector<const InferenceParameter*>& params,
    std::unique_ptr<ModelInfo>* info)
{
  std::unique_ptr<ModelInfo> linfo(new ModelInfo());
  linfo->model_path_ = path;

  bool unmodified = false;

  const auto iitr = infos_.find(name);
  // Set 'prev_mtime_ns_' if there is an existing ModelInfo
  if (iitr != infos_.end()) {
    linfo->prev_mtime_ns_ = iitr->second->mtime_nsec_;
  } else {
    linfo->prev_mtime_ns_ = 0;
  }

  // Set 'mtime_nsec_' and override 'model_path_' if the current path is
  // empty (a file override is specified)
  if (linfo->model_path_.empty()) {
    // The override files need to be localized; use a repo agent to manage
    // the lifecycle of the localized files
    std::shared_ptr<TritonRepoAgent> localize_agent(new LocalizeRepoAgent());
    std::unique_ptr<TritonRepoAgentModel> localize_agent_model;
    RETURN_IF_ERROR(TritonRepoAgentModel::Create(
        TRITONREPOAGENT_ARTIFACT_FILESYSTEM, "", inference::ModelConfig(),
        localize_agent, {}, &localize_agent_model));

    // Set the agent model state so the repo agent can access the encoded
    // files. const_cast is used here, but this is safe as the repo agent
    // will not modify the state.
    localize_agent_model->SetState(
        const_cast<void*>(reinterpret_cast<const void*>(&params)));
    RETURN_IF_ERROR(
        localize_agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD));

    const char* location;
    TRITONREPOAGENT_ArtifactType type;
    RETURN_IF_ERROR(localize_agent_model->Location(&type, &location));

    // For a file override, set 'mtime_nsec_' to the minimum value so that
    // the next load without an override triggers a re-load to undo the
    // override, even though the local files may be unchanged.
    linfo->mtime_nsec_ = 0;
    linfo->model_path_ = location;
    linfo->agent_model_list_.reset(new TritonRepoAgentModelList());
    linfo->agent_model_list_->AddAgentModel(std::move(localize_agent_model));
  } else {
    if (iitr == infos_.end()) {
      linfo->mtime_nsec_ = GetModifiedTime(std::string(linfo->model_path_));
    } else {
      // Check the current timestamps to determine if the model has actually
      // been modified
      linfo->mtime_nsec_ = linfo->prev_mtime_ns_;
      unmodified =
          !IsModified(std::string(linfo->model_path_), &linfo->mtime_nsec_);
    }
  }

  // Set 'model_config_'
  bool parsed_config = false;
  // Check if there is a config override
  for (const auto& override_parameter : params) {
    if ((override_parameter->Name() == "config") &&
        (override_parameter->Type() == TRITONSERVER_PARAMETER_STRING)) {
      // When an override happens, set 'mtime_nsec_' to the minimum value so
      // that the next load without an override triggers a re-load to undo
      // the override, even though the local files may be unchanged.
      linfo->mtime_nsec_ = 0;
      unmodified = false;

      const std::string& override_config = override_parameter->ValueString();
      auto err = JsonToModelConfig(
          override_config, 1 /* config_version */, &linfo->model_config_);
      if (!err.IsOk()) {
        return Status(
            Status::Code::INVALID_ARG,
            "Invalid config override: " + std::string(err.Message()));
      }
      parsed_config = true;
      break;
    } else if (override_parameter->Name().rfind(file_prefix, 0) != 0) {
      return Status(
          Status::Code::INVALID_ARG,
          "Unrecognized load parameter '" + override_parameter->Name() +
              "' with type '" +
              TRITONSERVER_ParameterTypeString(override_parameter->Type()) +
              "'");
    }
  }

  // The polled model is considered unmodified by this point and can be
  // returned with 'info' == nullptr
  if (unmodified) {
    return Status::Success;
  }

  // Create the associated repo agent models when a model is to be loaded;
  // this must be done before normalizing the model config, as agents might
  // redirect to a model config at a different location
  if (!parsed_config) {
    const auto config_path = JoinPath({linfo->model_path_, kModelConfigPbTxt});
    bool model_config_exists = false;
    RETURN_IF_ERROR(FileExists(config_path, &model_config_exists));
    // The model config may be missing if autofill is enabled
    if (autofill_ && !model_config_exists) {
      linfo->model_config_.Clear();
    } else {
      RETURN_IF_ERROR(ReadTextProto(config_path, &linfo->model_config_));
      parsed_config = true;
    }
  }
  if (parsed_config) {
    RETURN_IF_ERROR(CreateAgentModelListWithLoadAction(
        linfo->model_config_, linfo->model_path_, &linfo->agent_model_list_));
    if (linfo->agent_model_list_ != nullptr) {
      // Get the latest repository path
      const char* location;
      TRITONREPOAGENT_ArtifactType artifact_type;
      RETURN_IF_ERROR(linfo->agent_model_list_->Back()->Location(
          &artifact_type, &location));
      auto latest_path = std::string(location);
      linfo->model_path_ = latest_path;
    }
  }
  linfo->is_config_provided_ = parsed_config;

  // Try to automatically generate missing parts of the model configuration
  // (autofill) that don't require model detail
  RETURN_IF_ERROR(GetNormalizedModelConfig(
      name, linfo->model_path_, min_compute_capability_,
      &linfo->model_config_));

  // Note that the model inputs and outputs are not validated until the model
  // is initialized, as they may not be auto-completed until then.
  RETURN_IF_ERROR(
      ValidateModelConfig(linfo->model_config_, min_compute_capability_));
  if (!autofill_) {
    RETURN_IF_ERROR(ValidateModelIOConfig(linfo->model_config_));
  }

  // If the model is mapped, update its config name based on the mapping.
  if (model_mappings_.find(name) != model_mappings_.end()) {
    linfo->model_config_.set_name(name);
  } else {
    // If there is no model mapping, make sure the name of the model matches
    // the name of the directory. This is a somewhat arbitrary requirement
    // but seems like good practice to require it of the user. It also acts
    // as a check to make sure we don't have two different models with the
    // same name.
    if (linfo->model_config_.name() != name) {
      return Status(
          Status::Code::INVALID_ARG,
          "unexpected directory name '" + name + "' for model '" +
              linfo->model_config_.name() +
              "', directory name must equal model name");
    }
  }

  *info = std::move(linfo);
  return Status::Success;
}
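// The "config" override handled above is plain model-config JSON as accepted
// by JsonToModelConfig. A minimal example payload (field values are
// illustrative):
//
//   {"name": "my_model", "backend": "onnxruntime", "max_batch_size": 8}
//
// Supplying it forces 'mtime_nsec_' to 0, so a later load without the
// override always re-loads and undoes the override.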
Status
ModelRepositoryManager::UpdateDependencyGraph(
    const std::set<std::string>& added, const std::set<std::string>& deleted,
    const std::set<std::string>& modified,
    std::set<std::string>* deleted_dependents)
{
  // Update the dependency graph. If the state of a node changes, all of its
  // downstream nodes are affected. For a deleted node: drop it from
  // 'dependency_graph_' and add it to 'missing_nodes_' if its downstreams
  // are not empty. The affected nodes are all ensembles, as only ensembles
  // depend on other models.
  std::set<DependencyNode*> affected_nodes;
  std::set<DependencyNode*> updated_nodes;
  std::set<std::string> current_deleted = deleted;
  while (!current_deleted.empty()) {
    std::set<std::string> next_deleted;
    for (const auto& model_name : current_deleted) {
      auto it = dependency_graph_.find(model_name);
      if (it != dependency_graph_.end()) {
        // Capture the raw pointer first; 'it->second' may be moved into
        // 'missing_nodes_' below.
        DependencyNode* node_ptr = it->second.get();
        // Remove this node from its upstreams
        for (auto& upstream : it->second->upstreams_) {
          upstream.first->downstreams_.erase(node_ptr);
          // Check if the upstream should be removed as well
          if ((deleted_dependents != nullptr) &&
              (upstream.first->downstreams_.empty()) &&
              (!upstream.first->explicitly_load_)) {
            next_deleted.emplace(upstream.first->model_name_);
          }
        }
        it->second->upstreams_.clear();

        if (!it->second->downstreams_.empty()) {
          UncheckDownstream(&it->second->downstreams_, &affected_nodes);
          // Mark this node as a missing upstream in its downstreams
          for (auto& downstream : it->second->downstreams_) {
            downstream->missing_upstreams_.emplace(node_ptr);
          }
          missing_nodes_.emplace(
              std::make_pair(model_name, std::move(it->second)));
        }

        // Make sure the deleted node will not be in the affected nodes
        affected_nodes.erase(node_ptr);
        dependency_graph_.erase(it);
      }
      if (deleted_dependents != nullptr) {
        deleted_dependents->emplace(model_name);
      }
    }
    current_deleted.swap(next_deleted);
  }

  // Modified: invalidate (uncheck) all downstreams
  for (const auto& model_name : modified) {
    auto it = dependency_graph_.find(model_name);
    if (it != dependency_graph_.end()) {
      UncheckDownstream(&it->second->downstreams_, &affected_nodes);
      ModelInfo* info = nullptr;
      GetModelInfo(model_name, &info);
      it->second->model_config_ = info->model_config_;
      it->second->explicitly_load_ = info->explicitly_load_;
      // Remove this node from its upstream nodes
      for (auto& upstream : it->second->upstreams_) {
        upstream.first->downstreams_.erase(it->second.get());
      }
      it->second->upstreams_.clear();
      it->second->checked_ = false;
      it->second->status_ = Status::Success;
      updated_nodes.emplace(it->second.get());
    }
  }

  // Added: add to 'dependency_graph_'. If the node is in 'missing_nodes_',
  // invalidate (uncheck) and associate all downstreams, then remove it from
  // 'missing_nodes_'.
  for (const auto& model_name : added) {
    std::unique_ptr<DependencyNode> added_node;
    auto it = missing_nodes_.find(model_name);
    if (it != missing_nodes_.end()) {
      UncheckDownstream(&it->second->downstreams_, &affected_nodes);
      // Remove this node from the missing upstream nodes of its downstreams
      for (auto& downstream : it->second->downstreams_) {
        downstream->missing_upstreams_.erase(it->second.get());
      }
      it->second->checked_ = false;
      added_node = std::move(it->second);
      missing_nodes_.erase(it);
    } else {
      // Right now, nothing is going to be filled until validation
      added_node.reset(new DependencyNode(model_name));
    }
    ModelInfo* info = nullptr;
    GetModelInfo(model_name, &info);
    added_node->model_config_ = info->model_config_;
    added_node->explicitly_load_ = info->explicitly_load_;
    updated_nodes.emplace(added_node.get());
    dependency_graph_.emplace(
        std::make_pair(model_name, std::move(added_node)));
  }

  auto& affected_ensembles = affected_nodes;
  for (auto& updated_node : updated_nodes) {
    bool is_ensemble = ConnectDependencyGraph(updated_node);
    if (is_ensemble) {
      affected_ensembles.emplace(updated_node);
    }
  }

#ifdef TRITON_ENABLE_ENSEMBLE
  // After the dependency graph is updated, check ensemble dependencies
  for (auto& ensemble : affected_ensembles) {
    if (ensemble->status_.IsOk()) {
      if (!ensemble->missing_upstreams_.empty()) {
        std::string name_list;
        for (auto it = ensemble->missing_upstreams_.begin();
             it != ensemble->missing_upstreams_.end(); it++) {
          if (it != ensemble->missing_upstreams_.begin()) {
            name_list += ", ";
          }
          name_list += (*it)->model_name_;
        }
        ensemble->status_ = Status(
            Status::Code::INVALID_ARG,
            "ensemble " + ensemble->model_name_ +
                " contains models that are not available: " + name_list);
      } else {
        ensemble->status_ = CircularcyCheck(ensemble, ensemble);
      }
    }
  }
#endif  // TRITON_ENABLE_ENSEMBLE
  return Status::Success;
}
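// Worked example of the delta handling above (hypothetical models): suppose
// ensemble E depends on model A. Deleting A removes A from
// 'dependency_graph_', unchecks E, records A in E's missing_upstreams_, and
// parks A's node in 'missing_nodes_'. If A is later re-added, the added
// branch pulls the node back out of 'missing_nodes_', reconnects E, and E is
// re-validated on the next LoadModelByDependency() pass.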
Status
ModelRepositoryManager::RegisterModelRepository(
    const std::string& repository,
    const std::unordered_map<std::string, std::string>& model_mapping)
{
  if (!model_control_enabled_) {
    return Status(
        Status::Code::UNSUPPORTED,
        "repository registration is not allowed if model control mode is not "
        "EXPLICIT");
  }
  bool is_directory = false;
  auto status = IsDirectory(repository, &is_directory);
  if (!status.IsOk() || !is_directory) {
    return Status(
        Status::Code::INVALID_ARG,
        (std::string("failed to register '") + repository +
         "', repository not found")
            .c_str());
  }
  {
    // Serialize all operations that change model state
    std::lock_guard<std::mutex> lock(poll_mu_);

    // Check that the repository and mapped models do not yet exist.
    if (repository_paths_.find(repository) != repository_paths_.end()) {
      return Status(
          Status::Code::ALREADY_EXISTS,
          "model repository '" + repository + "' has already been registered");
    }

    for (const auto& mapping : model_mapping) {
      if (model_mappings_.find(mapping.first) != model_mappings_.end()) {
        return Status(
            Status::Code::ALREADY_EXISTS,
            (std::string("failed to register '") + mapping.first +
             "', there is a conflicting mapping for '" + mapping.first + "'")
                .c_str());
      }
    }

    repository_paths_.emplace(repository);
    for (const auto& mapping : model_mapping) {
      model_mappings_.emplace(
          mapping.first,
          std::make_pair(repository, JoinPath({repository, mapping.second})));
    }
  }

  LOG_INFO << "Model repository registered: " << repository;
  return Status::Success;
}
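// Call sketch (illustrative): register a second repository whose
// subdirectory "my_model_v1" should be served under the name "my_model".
//
//   std::unordered_map<std::string, std::string> mapping{
//       {"my_model", "my_model_v1"}};  // served name -> subdirectory
//   manager->RegisterModelRepository("/extra_repo", mapping);
//
// The mapping is stored as served name -> ("/extra_repo",
// "/extra_repo/my_model_v1"), matching the JoinPath() above.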
Status
ModelRepositoryManager::UnregisterModelRepository(const std::string& repository)
{
  if (!model_control_enabled_) {
    return Status(
        Status::Code::UNSUPPORTED,
        "repository unregistration is not allowed if model control mode is "
        "not EXPLICIT");
  }
  {
    std::lock_guard<std::mutex> lock(poll_mu_);
    if (repository_paths_.erase(repository) != 1) {
      return Status(
          Status::Code::INVALID_ARG,
          "failed to unregister '" + repository + "', repository not found");
    }

    std::set<std::string> models_to_delete;
    for (auto const& mapping : model_mappings_) {
      if (mapping.second.first == repository) {
        models_to_delete.insert(mapping.first);
      }
    }
    for (auto const& model : models_to_delete) {
      model_mappings_.erase(model);
    }
  }

  LOG_INFO << "Model repository unregistered: " << repository;
  return Status::Success;
}
Status
ModelRepositoryManager::CircularcyCheck(
    DependencyNode* current_node, const DependencyNode* start_node)
{
  for (auto& downstream : current_node->downstreams_) {
    if (downstream->model_name_ == start_node->model_name_) {
      return Status(
          Status::Code::INVALID_ARG,
          "circular dependency between ensembles: " + start_node->model_name_ +
              " -> ... -> " + current_node->model_name_ + " -> " +
              start_node->model_name_);
    } else {
      const auto status = CircularcyCheck(downstream, start_node);
      if (!status.IsOk() && current_node->status_.IsOk()) {
        current_node->status_ = status;
        return status;
      }
    }
  }
  return Status::Success;
}
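// CircularcyCheck (the spelling is upstream's) is a depth-first walk over
// 'downstreams_' edges: revisiting the start node means a cycle. Worked
// example with hypothetical ensembles E1 and E2 where E1 -> E2 -> E1: the
// check fails with
// "circular dependency between ensembles: E1 -> ... -> E2 -> E1".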
void
ModelRepositoryManager::UncheckDownstream(
    NodeSet* downstreams, NodeSet* updated_nodes)
{
  // Mark downstream nodes as unchecked recursively
  for (auto& node : *downstreams) {
    if (node->checked_) {
      node->checked_ = false;
      node->status_ = Status::Success;
      UncheckDownstream(&node->downstreams_, updated_nodes);
      updated_nodes->emplace(node);
    }
  }
}
bool
ModelRepositoryManager::ConnectDependencyGraph(DependencyNode* updated_node)
{
  // Check the node's model config to determine if it depends on other models
  // and whether those models are present
  updated_node->upstreams_.clear();
  updated_node->missing_upstreams_.clear();
  if (updated_node->model_config_.has_ensemble_scheduling()) {
    for (const auto& step :
         updated_node->model_config_.ensemble_scheduling().step()) {
      DependencyNode* upstream_node = nullptr;
      const auto& model_name = step.model_name();
      auto dit = dependency_graph_.find(model_name);
      if (dit == dependency_graph_.end()) {
        auto mit = missing_nodes_.find(model_name);
        if (mit == missing_nodes_.end()) {
          std::unique_ptr<DependencyNode> node(new DependencyNode(model_name));
          updated_node->missing_upstreams_.emplace(node.get());
          mit = missing_nodes_.emplace(model_name, std::move(node)).first;
        }
        // Add the node to the missing node's downstreams so that when the
        // missing node is added, the downstreams can be found easily.
        mit->second->downstreams_.emplace(updated_node);
        upstream_node = mit->second.get();
      } else {
        dit->second->downstreams_.emplace(updated_node);
        upstream_node = dit->second.get();
      }
      auto res = updated_node->upstreams_.emplace(
          upstream_node, std::set<int64_t>({step.model_version()}));
      // If the map insertion doesn't happen, the same model is required in
      // different steps; insert the version into the existing required
      // version set.
      if (!res.second) {
        res.first->second.insert(step.model_version());
      }
    }
    return true;
  }
  return false;
}
Status
ModelRepositoryManager::GetModelInfo(
    const std::string& name, ModelInfo** model_info)
{
  const auto itr = infos_.find(name);
  if (itr == infos_.end()) {
    return Status(
        Status::Code::NOT_FOUND, "no configuration for model '" + name + "'");
  }
  *model_info = itr->second.get();
  return Status::Success;
}
std::pair<ModelRepositoryManager::NodeSet, ModelRepositoryManager::NodeSet>
ModelRepositoryManager::ModelsToLoadUnload(const NodeSet& loaded_models)
{
  // <valid model set, invalid model set>
  std::pair<NodeSet, NodeSet> res;
  // First call to this function
  if (loaded_models.empty()) {
    for (auto& pair : dependency_graph_) {
      auto node = pair.second.get();
      // Only care about nodes that are affected by the update
      if (!node->checked_) {
        if (CheckNode(node)) {
          if (node->status_.IsOk()) {
            res.first.emplace(node);
          } else {
            res.second.emplace(node);
          }
        }
      }
    }
  } else {
    for (const auto& model : loaded_models) {
      for (auto node : model->downstreams_) {
        // Only care about nodes that are affected by the update
        if (!node->checked_) {
          if (CheckNode(node)) {
            if (node->status_.IsOk()) {
              res.first.emplace(node);
            } else {
              res.second.emplace(node);
            }
          }
        }
      }
    }
  }
  for (auto& node : res.first) {
    node->checked_ = true;
  }
  for (auto& node : res.second) {
    node->checked_ = true;
  }
  return res;
}
bool
ModelRepositoryManager::CheckNode(DependencyNode* node)
{
  bool node_ready = true;
  // If the node is in an invalid state, mark it as ready, as we already know
  // it should not be loaded
  if (node->status_.IsOk()) {
    for (auto& upstream : node->upstreams_) {
      if (!upstream.first->checked_) {
        node_ready = false;
        break;
      }
      if (!upstream.first->status_.IsOk()) {
        node->status_ = Status(
            Status::Code::INVALID_ARG,
            "ensemble '" + node->model_name_ + "' depends on '" +
                upstream.first->model_name_ + "' which is not valid");
      } else if (upstream.first->loaded_versions_.empty()) {
        node->status_ = Status(
            Status::Code::INVALID_ARG,
            "ensemble '" + node->model_name_ + "' depends on '" +
                upstream.first->model_name_ + "' which has no loaded version");
      } else {
        for (const auto& required_version : upstream.second) {
          if (required_version == -1) {
            continue;
          }
          auto it = upstream.first->loaded_versions_.find(required_version);
          if (it == upstream.first->loaded_versions_.end()) {
            node->status_ = Status(
                Status::Code::INVALID_ARG,
                "ensemble '" + node->model_name_ + "' depends on '" +
                    upstream.first->model_name_ + "' whose required version " +
                    std::to_string(required_version) + " is not loaded");
          }
        }
      }
      if (!node->status_.IsOk()) {
        break;
      }
    }
#ifdef TRITON_ENABLE_ENSEMBLE
    // Validate the ensemble config if the node is ready. By this point the
    // models it depends on are loaded and their configs are complete
    if (node_ready && node->status_.IsOk()) {
      node->status_ = ValidateEnsembleConfig(this, node);
    }
#endif  // TRITON_ENABLE_ENSEMBLE
  }
  return node_ready;
}

}}  // namespace triton::core
3rdparty/core-r22.12/src/model_repository_manager.h
0 → 100644
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <functional>
#include <map>
#include <mutex>
#include <set>
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "model_lifecycle.h"
#include "status.h"
#include "triton/common/model_config.h"
namespace triton { namespace core {

class InferenceServer;
class Model;

// [FIXME] should have separate load / unload functions for clarity
enum ActionType { NO_ACTION, LOAD, UNLOAD };
/// Predefined reason strings
#define MODEL_READY_REASON_DUPLICATE "model appears in two or more repositories"
/// An object to manage the model repository active in the server.
class ModelRepositoryManager {
 public:
  // Index information for a model.
  struct ModelIndex {
    ModelIndex(const std::string& n)
        : name_only_(true), name_(n), version_(-1),
          state_(ModelReadyState::UNKNOWN)
    {
    }
    ModelIndex(
        const std::string& n, const int64_t v, const ModelReadyState s,
        const std::string& r)
        : name_only_(false), name_(n), version_(v), state_(s), reason_(r)
    {
    }
    const bool name_only_;
    const std::string name_;
    const int64_t version_;
    const ModelReadyState state_;
    const std::string reason_;
  };
  /// A basic unit in the dependency graph that records the models seen by
  /// the model repository manager.
  struct DependencyNode {
    DependencyNode(const std::string& model_name)
        : model_name_(model_name), status_(Status::Success), checked_(false)
    {
    }
    std::string model_name_;
    Status status_;
    bool checked_;
    bool explicitly_load_;
    inference::ModelConfig model_config_;
    std::set<int64_t> loaded_versions_;
    std::set<DependencyNode*> missing_upstreams_;
    std::unordered_map<DependencyNode*, std::set<int64_t>> upstreams_;
    std::set<DependencyNode*> downstreams_;
  };
  ~ModelRepositoryManager();
/// Create a manager for a repository.
/// \param server The pointer to the inference server.
/// \param server_version The version of the inference server.
/// \param repository_paths A set of file-system paths of the repositories.
/// \param startup_models A set of models to be loaded at startup
/// if model control is enabled.
/// \param strict_model_config If false attempt to autofill missing required
/// information in each model configuration.
/// \param polling_enabled If true, then PollAndUpdate() is allowed.
/// Otherwise, it is not allowed.
/// \param model_control_enabled If true, then LoadUnloadModel() is allowed
/// and the models in the model repository will not be loaded at startup.
/// Otherwise, LoadUnloadModel() is not allowed and the models will be loaded.
/// Cannot be set to true if polling_enabled is true.
/// \param life_cycle_options The options to configure ModelLifeCycle.
/// \param model_repository_manager Return the model repository manager.
/// \return The error status.
  static Status Create(
      InferenceServer* server, const std::string& server_version,
      const std::set<std::string>& repository_paths,
      const std::set<std::string>& startup_models,
      const bool strict_model_config, const bool polling_enabled,
      const bool model_control_enabled,
      const ModelLifeCycleOptions& life_cycle_options,
      std::unique_ptr<ModelRepositoryManager>* model_repository_manager);
/// Poll the model repository to determine the new set of models and
/// compare with the current set. And serve the new set of models based
/// on their version policy.
  Status PollAndUpdate();
/// Load or unload a specified model.
/// \param models The models and the parameters to be loaded or unloaded
/// \param type The type action to be performed. If the action is LOAD and
/// the model has been loaded, the model will be re-loaded.
/// \return error status. Return "NOT_FOUND" if it tries to load
/// a non-existing model or if it tries to unload a model that hasn't been
/// loaded.
  Status LoadUnloadModel(
      const std::unordered_map<
          std::string, std::vector<const InferenceParameter*>>& models,
      const ActionType type, const bool unload_dependents);
/// Unload all models. This function should be called before shutting down
/// the model repository manager.
/// \return error status.
  Status UnloadAllModels();
/// Instruct all models to stop accepting new inference requests. However,
/// the models are still capable of processing inference requests
/// if the model considers them as part of the in-flight inference.
/// \return error status.
  Status StopAllModels();
/// \return the number of in-flight inferences for the all versions of all
/// models. The set element will be a tuple of <model_name, model_version,
/// in-flight inference count>. Note that a model version will not be included
/// if it doesn't have in-flight inferences.
  const std::set<std::tuple<std::string, int64_t, size_t>> InflightStatus();
/// \param strict_readiness If true, only models that have at least one
/// ready version will be considered as live. Otherwise, the models that
/// have loading / unloading versions will also be live.
/// \return the state of all versions of all live models.
  const ModelStateMap LiveModelStates(bool strict_readiness = false);
  /// \return the state of all versions of all models that have ever
  /// been (attempted) loaded over the lifetime of the server.
  const ModelStateMap ModelStates();
/// \return the states of all versions of a specific model.
  const VersionStateMap VersionStates(const std::string& model_name);
/// \return the ready-state of a specific model version.
  Status ModelState(
      const std::string& model_name, const int64_t model_version,
      ModelReadyState* state);
/// Get the index of all models in all repositories.
/// \param ready_only If true return only index of models that are ready.
/// \param index Returns the index.
/// \return error status.
  Status RepositoryIndex(const bool ready_only, std::vector<ModelIndex>* index);
/// Obtain the specified model.
/// \param model_name The name of the model.
/// \param model_version The version of the model.
/// \param model Return the model object.
/// \return error status.
  Status GetModel(
      const std::string& model_name, const int64_t model_version,
      std::shared_ptr<Model>* model);
// Register model repository path.
/// \param repository Path to model repository.
/// \param model_mapping Mapping with (overridden) model name as key, subdir
/// name as value.
/// \return error status
  Status RegisterModelRepository(
      const std::string& repository,
      const std::unordered_map<std::string, std::string>& model_mapping);
// Unregister model repository path.
/// \param repository Path to model repository.
/// \return error status
  Status UnregisterModelRepository(const std::string& repository);
 private:
  struct ModelInfo;

  // Map from model name to information about the model.
  using ModelInfoMap =
      std::unordered_map<std::string, std::unique_ptr<ModelInfo>>;

  // Set of DependencyNode
  using NodeSet = std::set<DependencyNode*>;

  ModelRepositoryManager(
      const std::set<std::string>& repository_paths, const bool autofill,
      const bool polling_enabled, const bool model_control_enabled,
      const double min_compute_capability,
      std::unique_ptr<ModelLifeCycle> life_cycle);
/// The internal function that are called in Create() and PollAndUpdate().
Status
PollAndUpdateInternal
(
bool
*
all_models_polled
);
/// The internal function that load or unload a set of models.
Status
LoadUnloadModels
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>&
models
,
const
ActionType
type
,
const
bool
unload_dependents
,
bool
*
all_models_polled
);
/// Poll the requested models in the model repository and
/// compare with the current set. Return the additions, deletions,
/// and modifications that have occurred. This function will not updated
/// the current model info, it is caller's responsibility to do so.
/// \param models The map from models to be polled to their associated
/// parameters.
/// \param added The names of the models added to the repository.
/// \param deleted The names of the models removed from the repository.
/// \param modified The names of the models remaining in the
/// repository that have been changed.
/// \param unmodified The names of the models remaining in the
/// repository that have not changed.
/// \param updated_infos The model infos retrieved from the poll.
/// \param all_models_polled Return true if all models are polled and
/// their model configuration are validated successfully. Instead of aborting
/// the polling, the models that fail will be ignored and their model infos
/// will stay in the previous state.
/// \return The error status.
Status
Poll
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>&
models
,
std
::
set
<
std
::
string
>*
added
,
std
::
set
<
std
::
string
>*
deleted
,
std
::
set
<
std
::
string
>*
modified
,
std
::
set
<
std
::
string
>*
unmodified
,
ModelInfoMap
*
updated_infos
,
bool
*
all_models_polled
);
/// Helper function for Poll() to initialize ModelInfo for the model.
/// \param name The name of the model.
/// \param path The model path. Empty path means the model is provided via
/// 'params'
/// \param params The model parameters provided for polling model.
/// \param info Return the updated ModelInfo. 'nullptr' will be returned if
/// existing ModelInfo for the model should be reused.
/// \return The error status.
Status
InitializeModelInfo
(
const
std
::
string
&
name
,
const
std
::
string
&
path
,
const
std
::
vector
<
const
InferenceParameter
*>&
params
,
std
::
unique_ptr
<
ModelInfo
>*
info
);
/// Load models based on the dependency graph. The function will iteratively
/// load models that all the models they depend on has been loaded, and unload
/// models if their dependencies are no longer satisfied.
/// \return The status of the model loads.
std
::
map
<
std
::
string
,
Status
>
LoadModelByDependency
();
/// Helper function to update the dependency graph based on the poll result
/// \param added The names of the models added to the repository.
/// \param deleted The names of the models removed from the repository.
/// \param modified The names of the models remaining in the
/// repository that have been changed.
/// \param deleted_dependents The names of dependent models to be removed
/// from the repository.
/// \return The error status.
Status
UpdateDependencyGraph
(
const
std
::
set
<
std
::
string
>&
added
,
const
std
::
set
<
std
::
string
>&
deleted
,
const
std
::
set
<
std
::
string
>&
modified
,
std
::
set
<
std
::
string
>*
deleted_dependents
=
nullptr
);
/// Helper function to uncheck the nodes because the model that they depends
/// on has changed. The unchecked nodes will be validated again.
/// The function will be call recursively to uncheck all downstreams.
/// \param downstreams The nodes to be unchecked.
/// \param updated_nodes Return the nodes that have been unchecked
void
UncheckDownstream
(
NodeSet
*
downstreams
,
NodeSet
*
updated_nodes
);
/// Helper function to construct the edges between nodes in dependency graph.
/// \param updated_node The node that is newly added or modified.
/// \return True if the node represents an ensemble model. False otherwise.
bool
ConnectDependencyGraph
(
DependencyNode
*
updated_node
);
/// Get the model info for a named model.
/// \param name The model name.
/// \param model_info Returns the model information.
/// \return OK if found, NOT_FOUND otherwise.
Status
GetModelInfo
(
const
std
::
string
&
name
,
ModelInfo
**
model_info
);
/// Get the models to be loaded / unloaded based on the model loaded in
/// previous iteration.
/// \param loaded_models The models loaded / unloaded in previous iteration.
/// Unloaded models will be represented as models with no loaded versions.
/// \return A pair of node set containing models to be loaded and models to be
/// unloaded for the next iteration.
std
::
pair
<
NodeSet
,
NodeSet
>
ModelsToLoadUnload
(
const
NodeSet
&
loaded_models
);
/// Check if the node is ready for the next iteration. A node is ready if the
/// node is invalid (containing invalid model config or its depdencies failed
/// to load) or all of its dependencies are satisfied.
/// \param node The node to be checked.
/// \return True if the node is ready. False otherwise.
bool
CheckNode
(
DependencyNode
*
node
);
Status
CircularcyCheck
(
DependencyNode
*
current_node
,
const
DependencyNode
*
start_node
);
bool
ModelDirectoryOverride
(
const
std
::
vector
<
const
InferenceParameter
*>&
model_params
);
std
::
set
<
std
::
string
>
repository_paths_
;
const
bool
autofill_
;
const
bool
polling_enabled_
;
const
bool
model_control_enabled_
;
const
double
min_compute_capability_
;
std
::
mutex
poll_mu_
;
ModelInfoMap
infos_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
DependencyNode
>>
dependency_graph_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
DependencyNode
>>
missing_nodes_
;
// Mappings from (overridden) model names to a pair of their repository and
// absolute path
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
std
::
string
,
std
::
string
>>
model_mappings_
;
std
::
unique_ptr
<
ModelLifeCycle
>
model_life_cycle_
;
};
}}
// namespace triton::core
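The load / unload entry points above are typically driven by the server's model-control handlers. A minimal caller-side sketch, hedged: 'manager' stands in for a ModelRepositoryManager constructed elsewhere during server startup, and ActionType::LOAD is assumed to be the load enumerator declared in the earlier (not shown) part of this header.

  // Hypothetical usage sketch; 'manager' and ActionType::LOAD are assumed.
  std::unordered_map<std::string, std::vector<const InferenceParameter*>>
      models;
  models.emplace("densenet", std::vector<const InferenceParameter*>());

  // Load (or re-load) the model; unload_dependents=false leaves any
  // dependent ensemble models in place.
  Status status = manager->LoadUnloadModel(
      models, ActionType::LOAD, false /* unload_dependents */);
  if (!status.IsOk()) {
    LOG_ERROR << status.AsString();
  }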
3rdparty/core-r22.12/src/numa_utils.cc  0 → 100644
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

#include "numa_utils.h"

#ifndef _WIN32
#include <numa.h>
#include <numaif.h>
#endif
#include "triton/common/logging.h"

namespace triton { namespace core {

namespace {

std::string
VectorToString(const std::vector<int>& vec)
{
  std::string str("[");
  for (const auto& element : vec) {
    str += std::to_string(element);
    str += ",";
  }
  str += "]";
  return str;
}

Status
ParseIntOption(const std::string& msg, const std::string& arg, int* value)
{
  try {
    *value = std::stoi(arg);
  }
  catch (const std::invalid_argument& ia) {
    return Status(
        Status::Code::INVALID_ARG,
        msg + ": Can't parse '" + arg + "' to integer");
  }
  return Status::Success;
}

}  // namespace

// NUMA settings will be ignored on the Windows platform
#ifdef _WIN32
Status
SetNumaConfigOnThread(
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  return Status::Success;
}

Status
SetNumaMemoryPolicy(
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  return Status::Success;
}

Status
GetNumaMemoryPolicyNodeMask(unsigned long* node_mask)
{
  *node_mask = 0;
  return Status::Success;
}

Status
ResetNumaMemoryPolicy()
{
  return Status::Success;
}

Status
SetNumaThreadAffinity(
    std::thread::native_handle_type thread,
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  return Status::Success;
}

#else
// Use a variable to make sure no NUMA-related function is actually called
// if Triton is not running with NUMA awareness (extra docker permission is
// needed to call the NUMA functions); this ensures backward compatibility.
thread_local bool numa_set = false;

Status
SetNumaConfigOnThread(
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  // Set thread affinity
  RETURN_IF_ERROR(SetNumaThreadAffinity(pthread_self(), host_policy));

  // Set memory policy
  RETURN_IF_ERROR(SetNumaMemoryPolicy(host_policy));

  return Status::Success;
}

Status
SetNumaMemoryPolicy(
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  const auto it = host_policy.find("numa-node");
  if (it != host_policy.end()) {
    int node_id;
    RETURN_IF_ERROR(
        ParseIntOption("Parsing 'numa-node' value", it->second, &node_id));
    LOG_VERBOSE(1) << "Thread is binding to NUMA node " << it->second
                   << ". Max NUMA node count: " << (numa_max_node() + 1);
    numa_set = true;
    unsigned long node_mask = 1UL << node_id;
    if (set_mempolicy(MPOL_BIND, &node_mask, (numa_max_node() + 1) + 1) != 0) {
      return Status(
          Status::Code::INTERNAL,
          std::string("Unable to set NUMA memory policy: ") + strerror(errno));
    }
  }
  return Status::Success;
}

Status
GetNumaMemoryPolicyNodeMask(unsigned long* node_mask)
{
  *node_mask = 0;
  int mode;
  if (numa_set &&
      get_mempolicy(&mode, node_mask, numa_max_node() + 1, NULL, 0) != 0) {
    return Status(
        Status::Code::INTERNAL,
        std::string("Unable to get NUMA node for current thread: ") +
            strerror(errno));
  }
  return Status::Success;
}

Status
ResetNumaMemoryPolicy()
{
  if (numa_set && (set_mempolicy(MPOL_DEFAULT, nullptr, 0) != 0)) {
    return Status(
        Status::Code::INTERNAL,
        std::string("Unable to reset NUMA memory policy: ") + strerror(errno));
  }
  numa_set = false;
  return Status::Success;
}

Status
SetNumaThreadAffinity(
    std::thread::native_handle_type thread,
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  const auto it = host_policy.find("cpu-cores");
  if (it != host_policy.end()) {
    // Parse CPUs
    std::vector<int> cpus;
    {
      const auto& cpu_str = it->second;
      auto delim_cpus = cpu_str.find(",");
      int current_pos = 0;
      while (true) {
        auto delim_range = cpu_str.find("-", current_pos);
        if (delim_range == std::string::npos) {
          return Status(
              Status::Code::INVALID_ARG,
              std::string(
                  "host policy setting 'cpu-cores' format is "
                  "'<lower_cpu_core_id>-<upper_cpu_core_id>'. Got ") +
                  cpu_str.substr(
                      current_pos, ((delim_cpus == std::string::npos)
                                        ? (cpu_str.length() + 1)
                                        : delim_cpus) -
                                       current_pos));
        }
        int lower, upper;
        RETURN_IF_ERROR(ParseIntOption(
            "Parsing 'cpu-cores' value",
            cpu_str.substr(current_pos, delim_range - current_pos), &lower));
        RETURN_IF_ERROR(ParseIntOption(
            "Parsing 'cpu-cores' value",
            (delim_cpus == std::string::npos)
                ? cpu_str.substr(delim_range + 1)
                : cpu_str.substr(
                      delim_range + 1, delim_cpus - (delim_range + 1)),
            &upper));
        for (; lower <= upper; ++lower) {
          cpus.push_back(lower);
        }
        // Break if the processed range is the last specified range
        if (delim_cpus != std::string::npos) {
          current_pos = delim_cpus + 1;
          delim_cpus = cpu_str.find(",", current_pos);
        } else {
          break;
        }
      }
    }
    LOG_VERBOSE(1) << "Thread is binding to one of the CPUs: "
                   << VectorToString(cpus);
    numa_set = true;
    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    for (int cpu : cpus) {
      CPU_SET(cpu, &cpuset);
    }
    if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset) != 0) {
      return Status(
          Status::Code::INTERNAL,
          std::string("Unable to set NUMA thread affinity: ") +
              strerror(errno));
    }
  }
  return Status::Success;
}
#endif

}}  // namespace triton::core
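The 'cpu-cores' parser above accepts one or more comma-separated '<lower>-<upper>' ranges. A short sketch of the expected host-policy input, grounded in the keys parsed above (HostPolicyCmdlineConfig behaves as a string-to-string map, per its use in the code):

  // Bind the thread to CPUs 0-3 and 8-11, and its allocations to NUMA node 0.
  triton::common::HostPolicyCmdlineConfig host_policy;
  host_policy["cpu-cores"] = "0-3,8-11";  // parsed into {0,1,2,3,8,9,10,11}
  host_policy["numa-node"] = "0";         // node mask becomes 1UL << 0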
3rdparty/core-r22.12/src/numa_utils.h  0 → 100644
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

#pragma once

#include <map>
#include <thread>
#include <vector>
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"

namespace triton { namespace core {

// Helper function to set memory policy and thread affinity on current thread
Status SetNumaConfigOnThread(
    const triton::common::HostPolicyCmdlineConfig& host_policy);

// Restrict the memory allocation to specific NUMA node.
Status SetNumaMemoryPolicy(
    const triton::common::HostPolicyCmdlineConfig& host_policy);

// Retrieve the node mask used to set memory policy for the current thread
Status GetNumaMemoryPolicyNodeMask(unsigned long* node_mask);

// Reset the memory allocation setting.
Status ResetNumaMemoryPolicy();

// Set a thread affinity to be on specific cpus.
Status SetNumaThreadAffinity(
    std::thread::native_handle_type thread,
    const triton::common::HostPolicyCmdlineConfig& host_policy);

}}  // namespace triton::core
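A minimal usage sketch for these helpers, assuming a worker thread that should execute and allocate on one NUMA node; the host_policy contents follow the keys parsed in numa_utils.cc:

  #include "numa_utils.h"

  void WorkerEntry(const triton::common::HostPolicyCmdlineConfig& host_policy)
  {
    // Pin the calling thread and its memory allocations per the host policy;
    // a no-op when neither "cpu-cores" nor "numa-node" is set.
    Status status = triton::core::SetNumaConfigOnThread(host_policy);
    if (!status.IsOk()) {
      // ... handle / log the failure ...
    }

    // ... NUMA-local work ...

    // Restore the default memory policy before doing unrelated work.
    triton::core::ResetNumaMemoryPolicy();
  }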
3rdparty/core-r22.12/src/payload.cc  0 → 100644
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

#include "payload.h"

namespace triton { namespace core {

Payload::Payload()
    : op_type_(Operation::INFER_RUN),
      requests_(std::vector<std::unique_ptr<InferenceRequest>>()),
      OnCallback_([]() {}), instance_(nullptr), state_(State::UNINITIALIZED),
      batcher_start_ns_(0), saturated_(false)
{
  exec_mu_.reset(new std::mutex());
}

const Status&
Payload::MergePayload(std::shared_ptr<Payload>& payload)
{
  if ((payload->GetOpType() != Operation::INFER_RUN) ||
      (op_type_ != Operation::INFER_RUN)) {
    static Status op_type_error(
        Status::Code::INTERNAL,
        "Attempted to merge payloads of type that are not INFER_RUN");
    return op_type_error;
  }
  if (payload->GetInstance() != instance_) {
    static Status instance_error(
        Status::Code::INTERNAL,
        "Attempted to merge payloads of mismatching instance");
    return instance_error;
  }
  if ((payload->GetState() != State::EXECUTING) ||
      (state_ != State::EXECUTING)) {
    static Status state_error(
        Status::Code::INTERNAL,
        "Attempted to merge payloads that are not in executing state");
    return state_error;
  }
  // Skip the comparison if 'required_equal_inputs_' is not initialized;
  // assume that either all payloads are initialized or none are.
  if (required_equal_inputs_.Initialized() &&
      !required_equal_inputs_.HasEqualInputs(*payload->Requests().begin())) {
    static Status shape_error(
        Status::Code::INVALID_ARG,
        "Attempted to merge payloads that has non-equal inputs");
    return shape_error;
  }
  requests_.insert(
      requests_.end(), std::make_move_iterator(payload->Requests().begin()),
      std::make_move_iterator(payload->Requests().end()));
  payload->Callback();
  return Status::Success;
}

void
Payload::Reset(const Operation op_type, TritonModelInstance* instance)
{
  op_type_ = op_type;
  requests_.clear();
  OnCallback_ = []() {};
  release_callbacks_.clear();
  instance_ = instance;
  state_ = State::UNINITIALIZED;
  status_.reset(new std::promise<Status>());
  required_equal_inputs_ = RequiredEqualInputs();
  batcher_start_ns_ = 0;
  saturated_ = false;
}

void
Payload::Release()
{
  op_type_ = Operation::INFER_RUN;
  requests_.clear();
  OnCallback_ = []() {};
  release_callbacks_.clear();
  instance_ = nullptr;
  state_ = State::RELEASED;
  required_equal_inputs_ = RequiredEqualInputs();
  batcher_start_ns_ = 0;
  saturated_ = false;
}

size_t
Payload::BatchSize()
{
  size_t batch_size = 0;
  for (const auto& request : requests_) {
    batch_size += std::max(1U, request->BatchSize());
  }
  return batch_size;
}

void
Payload::ReserveRequests(size_t size)
{
  requests_.reserve(size);
}

void
Payload::AddRequest(std::unique_ptr<InferenceRequest> request)
{
  if ((batcher_start_ns_ == 0) ||
      (batcher_start_ns_ > request->BatcherStartNs())) {
    batcher_start_ns_ = request->BatcherStartNs();
  }
  requests_.push_back(std::move(request));
}

void
Payload::SetCallback(std::function<void()> OnCallback)
{
  OnCallback_ = OnCallback;
}

void
Payload::SetInstance(TritonModelInstance* model_instance)
{
  instance_ = model_instance;
}

void
Payload::AddInternalReleaseCallback(std::function<void()>&& callback)
{
  release_callbacks_.emplace_back(std::move(callback));
}

void
Payload::MarkSaturated()
{
  saturated_ = true;
}

void
Payload::SetState(Payload::State state)
{
  state_ = state;
}

Status
Payload::Wait()
{
  return status_->get_future().get();
}

void
Payload::Callback()
{
  OnCallback_();
}

void
Payload::OnRelease()
{
  // Invoke the release callbacks added internally before releasing the
  // request to the user-provided callback.
  for (auto it = release_callbacks_.rbegin(); it != release_callbacks_.rend();
       it++) {
    (*it)();
  }
  release_callbacks_.clear();
}

void
Payload::Execute(bool* should_exit)
{
  *should_exit = false;
  Status status;
  switch (op_type_) {
    case Operation::INFER_RUN:
      instance_->Schedule(std::move(requests_), OnCallback_);
      break;
    case Operation::INIT:
      status = instance_->Initialize();
      break;
    case Operation::WARM_UP:
      status = instance_->WarmUp();
      break;
    case Operation::EXIT:
      *should_exit = true;
  }
  status_->set_value(status);
}

}}  // namespace triton::core
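MergePayload() only succeeds when both payloads are INFER_RUN, bound to the same instance, in EXECUTING state, and (when initialized) have equal required inputs; on success the source payload's requests are moved over and its callback fires. A hedged sketch of the calling pattern, where both payloads are assumed to have been prepared by the surrounding scheduler code:

  // 'dst' and 'src' are assumed to be EXECUTING INFER_RUN payloads on the
  // same model instance, e.g. produced by the dynamic batcher.
  const Status& merge_status = dst->MergePayload(src);
  if (merge_status.IsOk()) {
    // 'src' requests now live in 'dst'; 'src' can be recycled.
  } else {
    // Payloads were not mergeable; execute them separately.
  }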
3rdparty/core-r22.12/src/payload.h  0 → 100644
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

#pragma once

#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <vector>
#include "backend_model_instance.h"
#include "infer_request.h"
#include "scheduler_utils.h"
#include "status.h"

namespace triton { namespace core {

class Payload {
 public:
  enum Operation { INFER_RUN = 0, INIT = 1, WARM_UP = 2, EXIT = 3 };
  enum State {
    UNINITIALIZED = 0,
    READY = 1,
    REQUESTED = 2,
    SCHEDULED = 3,
    EXECUTING = 4,
    RELEASED = 5
  };

  Payload();
  void Reset(const Operation op_type, TritonModelInstance* instance = nullptr);
  const Status& MergePayload(std::shared_ptr<Payload>& payload);

  Operation GetOpType() { return op_type_; }
  std::mutex* GetExecMutex() { return exec_mu_.get(); }
  size_t RequestCount() { return requests_.size(); }
  size_t BatchSize();
  void ReserveRequests(size_t size);
  void AddRequest(std::unique_ptr<InferenceRequest> request);
  std::vector<std::unique_ptr<InferenceRequest>>& Requests()
  {
    return requests_;
  }
  uint64_t BatcherStartNs() { return batcher_start_ns_; }
  void SetCallback(std::function<void()> OnCallback);
  void Callback();
  void AddInternalReleaseCallback(std::function<void()>&& callback);
  void OnRelease();
  void SetInstance(TritonModelInstance* model_instance);
  TritonModelInstance* GetInstance() { return instance_; }
  void MarkSaturated();
  bool IsSaturated() { return saturated_; }
  RequiredEqualInputs* MutableRequiredEqualInputs()
  {
    return &required_equal_inputs_;
  }
  State GetState() { return state_; }
  void SetState(State state);
  void Execute(bool* should_exit);
  Status Wait();
  void Release();

 private:
  Operation op_type_;
  std::vector<std::unique_ptr<InferenceRequest>> requests_;
  std::function<void()> OnCallback_;
  std::vector<std::function<void()>> release_callbacks_;
  TritonModelInstance* instance_;
  State state_;
  std::unique_ptr<std::promise<Status>> status_;
  std::unique_ptr<std::mutex> exec_mu_;
  uint64_t batcher_start_ns_;
  RequiredEqualInputs required_equal_inputs_;
  bool saturated_;
};

}}  // namespace triton::core
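The enums above encode the payload lifecycle: a payload is Reset() into UNINITIALIZED, filled with requests, moved through REQUESTED/SCHEDULED by the rate limiter, and finally Execute()d. A hedged end-to-end sketch; in the server, Execute() is normally invoked by a payload worker thread, and 'instance' and 'request' are assumed inputs from the surrounding scheduler code:

  auto payload = std::make_shared<Payload>();
  payload->Reset(Payload::Operation::INFER_RUN, instance);
  payload->AddRequest(std::move(request));
  payload->SetCallback([]() { /* notify the scheduler thread */ });

  bool should_exit = false;
  payload->Execute(&should_exit);   // dispatches to instance->Schedule(...)
  Status status = payload->Wait();  // blocks on the promise set by Execute()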
3rdparty/core-r22.12/src/pinned_memory_manager.cc  0 → 100644
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

#include "pinned_memory_manager.h"

#include <sstream>
#include "numa_utils.h"
#include "triton/common/logging.h"

#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#endif  // TRITON_ENABLE_GPU

namespace triton { namespace core {

namespace {

std::string
PointerToString(void* ptr)
{
  std::stringstream ss;
  ss << ptr;
  return ss.str();
}

Status
ParseIntOption(const std::string& msg, const std::string& arg, int* value)
{
  try {
    *value = std::stoi(arg);
  }
  catch (const std::invalid_argument& ia) {
    return Status(
        Status::Code::INVALID_ARG,
        msg + ": Can't parse '" + arg + "' to integer");
  }
  return Status::Success;
}

}  // namespace

std::unique_ptr<PinnedMemoryManager> PinnedMemoryManager::instance_;
uint64_t PinnedMemoryManager::pinned_memory_byte_size_;

PinnedMemoryManager::PinnedMemory::PinnedMemory(
    void* pinned_memory_buffer, uint64_t size)
    : pinned_memory_buffer_(pinned_memory_buffer)
{
  if (pinned_memory_buffer_ != nullptr) {
    managed_pinned_memory_ = boost::interprocess::managed_external_buffer(
        boost::interprocess::create_only_t{}, pinned_memory_buffer_, size);
  }
}

PinnedMemoryManager::PinnedMemory::~PinnedMemory()
{
#ifdef TRITON_ENABLE_GPU
  if (pinned_memory_buffer_ != nullptr) {
    cudaFreeHost(pinned_memory_buffer_);
  }
#endif  // TRITON_ENABLE_GPU
}

PinnedMemoryManager::~PinnedMemoryManager()
{
  // Clean up
  for (const auto& memory_info : memory_info_) {
    const auto& is_pinned = memory_info.second.first;
    if (!is_pinned) {
      free(memory_info.first);
    }
  }
}

void
PinnedMemoryManager::AddPinnedMemoryBuffer(
    const std::shared_ptr<PinnedMemory>& pinned_memory_buffer,
    unsigned long node_mask)
{
  pinned_memory_buffers_[node_mask] = pinned_memory_buffer;
}

Status
PinnedMemoryManager::AllocInternal(
    void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
    bool allow_nonpinned_fallback, PinnedMemory* pinned_memory_buffer)
{
  auto status = Status::Success;
  if (pinned_memory_buffer->pinned_memory_buffer_ != nullptr) {
    std::lock_guard<std::mutex> lk(pinned_memory_buffer->buffer_mtx_);
    *ptr = pinned_memory_buffer->managed_pinned_memory_.allocate(
        size, std::nothrow_t{});
    *allocated_type = TRITONSERVER_MEMORY_CPU_PINNED;
    if (*ptr == nullptr) {
      status = Status(
          Status::Code::INTERNAL, "failed to allocate pinned system memory");
    }
  } else {
    status = Status(
        Status::Code::INTERNAL,
        "failed to allocate pinned system memory: no pinned memory pool");
  }

  bool is_pinned = true;
  if ((!status.IsOk()) && allow_nonpinned_fallback) {
    static bool warning_logged = false;
    if (!warning_logged) {
      LOG_WARNING << status.Message()
                  << ", falling back to non-pinned system memory";
      warning_logged = true;
    }
    *ptr = malloc(size);
    *allocated_type = TRITONSERVER_MEMORY_CPU;
    is_pinned = false;
    if (*ptr == nullptr) {
      status = Status(
          Status::Code::INTERNAL,
          "failed to allocate non-pinned system memory");
    } else {
      status = Status::Success;
    }
  }

  // Keep track of the allocated buffer, or clean up
  {
    std::lock_guard<std::mutex> lk(info_mtx_);
    if (status.IsOk()) {
      auto res = memory_info_.emplace(
          *ptr, std::make_pair(is_pinned, pinned_memory_buffer));
      if (!res.second) {
        status = Status(
            Status::Code::INTERNAL, "unexpected memory address collision, '" +
                                        PointerToString(*ptr) +
                                        "' has been managed");
      }
      LOG_VERBOSE(1) << (is_pinned ? "" : "non-")
                     << "pinned memory allocation: "
                     << "size " << size << ", addr " << *ptr;
    }
  }

  if ((!status.IsOk()) && (*ptr != nullptr)) {
    if (is_pinned) {
      std::lock_guard<std::mutex> lk(pinned_memory_buffer->buffer_mtx_);
      pinned_memory_buffer->managed_pinned_memory_.deallocate(*ptr);
    } else {
      free(*ptr);
    }
  }

  return status;
}

Status
PinnedMemoryManager::FreeInternal(void* ptr)
{
  bool is_pinned = true;
  PinnedMemory* pinned_memory_buffer = nullptr;
  {
    std::lock_guard<std::mutex> lk(info_mtx_);
    auto it = memory_info_.find(ptr);
    if (it != memory_info_.end()) {
      is_pinned = it->second.first;
      pinned_memory_buffer = it->second.second;
      LOG_VERBOSE(1) << (is_pinned ? "" : "non-")
                     << "pinned memory deallocation: "
                     << "addr " << ptr;
      memory_info_.erase(it);
    } else {
      return Status(
          Status::Code::INTERNAL, "unexpected memory address '" +
                                      PointerToString(ptr) +
                                      "' is not being managed");
    }
  }

  if (is_pinned) {
    std::lock_guard<std::mutex> lk(pinned_memory_buffer->buffer_mtx_);
    pinned_memory_buffer->managed_pinned_memory_.deallocate(ptr);
  } else {
    free(ptr);
  }

  return Status::Success;
}

void
PinnedMemoryManager::Reset()
{
  instance_.reset();
}

Status
PinnedMemoryManager::Create(const Options& options)
{
  if (instance_ != nullptr) {
    LOG_WARNING << "New pinned memory pool of size "
                << options.pinned_memory_pool_byte_size_
                << " could not be created since one already exists"
                << " of size " << pinned_memory_byte_size_;
    return Status::Success;
  }

  instance_.reset(new PinnedMemoryManager());
  if (options.host_policy_map_.empty()) {
    void* buffer = nullptr;
#ifdef TRITON_ENABLE_GPU
    auto err = cudaHostAlloc(
        &buffer, options.pinned_memory_pool_byte_size_,
        cudaHostAllocPortable);
    if (err != cudaSuccess) {
      buffer = nullptr;
      LOG_WARNING << "Unable to allocate pinned system memory, pinned memory "
                     "pool will not be available: "
                  << std::string(cudaGetErrorString(err));
    } else if (options.pinned_memory_pool_byte_size_ != 0) {
      LOG_INFO << "Pinned memory pool is created at '"
               << PointerToString(buffer) << "' with size "
               << options.pinned_memory_pool_byte_size_;
    } else {
      LOG_INFO << "Pinned memory pool disabled";
    }
#endif  // TRITON_ENABLE_GPU
    try {
      instance_->AddPinnedMemoryBuffer(
          std::shared_ptr<PinnedMemory>(
              new PinnedMemory(buffer, options.pinned_memory_pool_byte_size_)),
          0);
    }
    catch (const std::exception& ex) {
      return Status(
          Status::Code::INTERNAL,
          "Failed to add Pinned Memory buffer: " + std::string(ex.what()));
    }
  } else {
    // Only one buffer / manager should be created per NUMA node, and all
    // associated devices should request memory from the shared manager.
    std::map<int32_t, std::string> numa_map;
    for (const auto& host_policy : options.host_policy_map_) {
      const auto numa_it = host_policy.second.find("numa-node");
      if (numa_it != host_policy.second.end()) {
        int32_t numa_id;
        if (ParseIntOption("Parsing NUMA node", numa_it->second, &numa_id)
                .IsOk()) {
          numa_map.emplace(numa_id, host_policy.first);
        }
      }
    }
    for (const auto& node_policy : numa_map) {
      auto status =
          SetNumaMemoryPolicy(options.host_policy_map_.at(node_policy.second));
      if (!status.IsOk()) {
        LOG_WARNING << "Unable to allocate pinned system memory for NUMA node "
                    << node_policy.first << ": " << status.AsString();
        continue;
      }
      unsigned long node_mask;
      status = GetNumaMemoryPolicyNodeMask(&node_mask);
      if (!status.IsOk()) {
        LOG_WARNING << "Unable to get NUMA node set for current thread: "
                    << status.AsString();
        continue;
      }
      void* buffer = nullptr;
#ifdef TRITON_ENABLE_GPU
      auto err = cudaHostAlloc(
          &buffer, options.pinned_memory_pool_byte_size_,
          cudaHostAllocPortable);
      if (err != cudaSuccess) {
        buffer = nullptr;
        LOG_WARNING
            << "Unable to allocate pinned system memory, pinned memory "
               "pool will not be available: "
            << std::string(cudaGetErrorString(err));
      } else if (options.pinned_memory_pool_byte_size_ != 0) {
        LOG_INFO << "Pinned memory pool is created at '"
                 << PointerToString(buffer) << "' with size "
                 << options.pinned_memory_pool_byte_size_;
      } else {
        LOG_INFO << "Pinned memory pool disabled";
      }
#endif  // TRITON_ENABLE_GPU
      ResetNumaMemoryPolicy();
      try {
        instance_->AddPinnedMemoryBuffer(
            std::shared_ptr<PinnedMemory>(new PinnedMemory(
                buffer, options.pinned_memory_pool_byte_size_)),
            node_mask);
      }
      catch (const std::exception& ex) {
        return Status(
            Status::Code::INTERNAL,
            "Failed to add Pinned Memory buffer with host policy: " +
                std::string(ex.what()));
      }
    }
    // If no pinned memory is allocated, add an empty entry where all
    // allocations will be on normal system memory
    if (instance_->pinned_memory_buffers_.empty()) {
      try {
        instance_->AddPinnedMemoryBuffer(
            std::shared_ptr<PinnedMemory>(new PinnedMemory(
                nullptr, options.pinned_memory_pool_byte_size_)),
            0);
      }
      catch (const std::exception& ex) {
        return Status(
            Status::Code::INTERNAL,
            "Failed to add empty Pinned Memory entry: " +
                std::string(ex.what()));
      }
    }
  }

  pinned_memory_byte_size_ = options.pinned_memory_pool_byte_size_;
  return Status::Success;
}

Status
PinnedMemoryManager::Alloc(
    void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
    bool allow_nonpinned_fallback)
{
  if (instance_ == nullptr) {
    return Status(
        Status::Code::UNAVAILABLE, "PinnedMemoryManager has not been created");
  }

  auto pinned_memory_buffer =
      instance_->pinned_memory_buffers_.begin()->second.get();
  if (instance_->pinned_memory_buffers_.size() > 1) {
    unsigned long node_mask;
    if (GetNumaMemoryPolicyNodeMask(&node_mask).IsOk()) {
      auto it = instance_->pinned_memory_buffers_.find(node_mask);
      if (it != instance_->pinned_memory_buffers_.end()) {
        pinned_memory_buffer = it->second.get();
      }
    }
  }

  return instance_->AllocInternal(
      ptr, size, allocated_type, allow_nonpinned_fallback,
      pinned_memory_buffer);
}

Status
PinnedMemoryManager::Free(void* ptr)
{
  if (instance_ == nullptr) {
    return Status(
        Status::Code::UNAVAILABLE, "PinnedMemoryManager has not been created");
  }

  return instance_->FreeInternal(ptr);
}

}}  // namespace triton::core
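Create() above builds one pinned pool per distinct 'numa-node' found in the host policy map, keyed by the memory-policy node mask, and Alloc() then routes to the pool matching the calling thread's current mask. A worked sketch of that mapping, with illustrative (hypothetical) policy names:

  // Illustrative host policies (the names "gpu_0"/"gpu_1" are hypothetical):
  //   "gpu_0": { "numa-node": "0" }  -> pool keyed by node mask 1UL << 0 = 1
  //   "gpu_1": { "numa-node": "1" }  -> pool keyed by node mask 1UL << 1 = 2
  //
  // A thread that previously ran SetNumaMemoryPolicy() for node 1 reports
  // node_mask == 2 from GetNumaMemoryPolicyNodeMask(), so Alloc() picks the
  // node-1 pool; any other mask falls back to the first pool in the map.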
3rdparty/core-r22.12/src/pinned_memory_manager.h  0 → 100644
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

#pragma once

#include <boost/interprocess/managed_external_buffer.hpp>
#include <map>
#include <memory>
#include <mutex>
#include "status.h"
#include "triton/common/model_config.h"

namespace triton { namespace core {

// This is a singleton class responsible for maintaining the pinned memory
// pool used by the inference server. Pinned memory allocations and
// deallocations must be requested via functions provided by this class.
class PinnedMemoryManager {
 public:
  // Options to configure the pinned memory manager.
  struct Options {
    Options(
        uint64_t b = 0,
        const triton::common::HostPolicyCmdlineConfigMap& host_policy_map = {})
        : pinned_memory_pool_byte_size_(b), host_policy_map_(host_policy_map)
    {
    }

    uint64_t pinned_memory_pool_byte_size_;
    triton::common::HostPolicyCmdlineConfigMap host_policy_map_;
  };

  ~PinnedMemoryManager();

  // Create the pinned memory manager based on the 'options' specified.
  // Return Status object indicating success or failure.
  static Status Create(const Options& options);

  // Allocate pinned memory with the requested 'size' and return the pointer
  // in 'ptr'. If 'allow_nonpinned_fallback' is true, regular system memory
  // will be allocated as a fallback in the case where pinned memory fails to
  // be allocated.
  // Return Status object indicating success or failure.
  static Status Alloc(
      void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
      bool allow_nonpinned_fallback);

  // Free the memory allocated by the pinned memory manager.
  // Return Status object indicating success or failure.
  static Status Free(void* ptr);

 protected:
  // Provide explicit control over the lifecycle of the pinned memory
  // manager, for testing only.
  static void Reset();

 private:
  class PinnedMemory {
   public:
    PinnedMemory(void* pinned_memory_buffer, uint64_t size);
    ~PinnedMemory();
    void* pinned_memory_buffer_;
    std::mutex buffer_mtx_;
    boost::interprocess::managed_external_buffer managed_pinned_memory_;
  };

  PinnedMemoryManager() = default;

  Status AllocInternal(
      void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
      bool allow_nonpinned_fallback, PinnedMemory* pinned_memory_buffer);
  Status FreeInternal(void* ptr);
  void AddPinnedMemoryBuffer(
      const std::shared_ptr<PinnedMemory>& pinned_memory_buffer,
      unsigned long node_mask);

  static std::unique_ptr<PinnedMemoryManager> instance_;
  static uint64_t pinned_memory_byte_size_;

  std::mutex info_mtx_;
  std::map<void*, std::pair<bool, PinnedMemory*>> memory_info_;
  std::map<unsigned long, std::shared_ptr<PinnedMemory>>
      pinned_memory_buffers_;
};

}}  // namespace triton::core
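A minimal usage sketch of the static interface above; the 256 MB pool size is an arbitrary example value:

  // One-time setup, typically during server initialization.
  PinnedMemoryManager::Options options(256 * 1024 * 1024 /* 256 MB pool */);
  Status status = PinnedMemoryManager::Create(options);

  // Per-transfer allocation; because 'allow_nonpinned_fallback' is true,
  // this falls back to malloc'd memory if the pinned pool is exhausted.
  void* buffer = nullptr;
  TRITONSERVER_MemoryType allocated_type;
  status = PinnedMemoryManager::Alloc(
      &buffer, 1024 /* bytes */, &allocated_type,
      true /* allow_nonpinned_fallback */);

  // ... use 'buffer' for a host<->device copy ...

  status = PinnedMemoryManager::Free(buffer);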
3rdparty/core-r22.12/src/rate_limiter.cc
0 → 100644
View file @
374c78ca
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "rate_limiter.h"
#include <limits>
#include "triton/common/logging.h"
namespace
triton
{
namespace
core
{
constexpr
size_t
MAX_PAYLOAD_BUCKET_COUNT
=
1000
;
//=========================================================================
// Core Implementation
//=========================================================================
Status
RateLimiter
::
Create
(
const
bool
ignore_resources_and_priority
,
const
RateLimiter
::
ResourceMap
&
resource_map
,
std
::
unique_ptr
<
RateLimiter
>*
rate_limiter
)
{
std
::
unique_ptr
<
RateLimiter
>
local_rate_limiter
(
new
RateLimiter
(
ignore_resources_and_priority
,
resource_map
));
*
rate_limiter
=
std
::
move
(
local_rate_limiter
);
return
Status
::
Success
;
}
Status
RateLimiter
::
RegisterModelInstance
(
TritonModelInstance
*
triton_model_instance
,
const
RateLimiterConfig
&
rate_limiter_config
)
{
{
std
::
lock_guard
<
std
::
mutex
>
lk1
(
model_ctx_mtx_
);
std
::
lock_guard
<
std
::
mutex
>
lk2
(
model_instance_ctx_mtx_
);
auto
&
model_context
=
model_contexts_
[
triton_model_instance
->
Model
()];
auto
&
model_instances
=
model_instance_ctxs_
[
triton_model_instance
->
Model
()];
model_instances
.
push_back
(
std
::
shared_ptr
<
ModelInstanceContext
>
(
new
ModelInstanceContext
(
triton_model_instance
,
&
model_context
,
rate_limiter_config
,
[
this
](
ModelInstanceContext
*
instance
)
{
OnStage
(
instance
);
},
[
this
](
ModelInstanceContext
*
instance
)
{
OnRelease
(
instance
);
})));
model_context
.
AddAvailableInstance
(
model_instances
.
back
().
get
());
model_context
.
AddSpecificRequestQueue
();
if
(
!
ignore_resources_and_priority_
)
{
resource_manager_
->
AddModelInstance
(
model_instances
.
back
().
get
());
RETURN_IF_ERROR
(
resource_manager_
->
UpdateResourceLimits
());
}
}
InitializePayloadQueues
(
triton_model_instance
);
return
Status
::
Success
;
}
Status
RateLimiter
::
UnregisterModel
(
const
TritonModel
*
model
)
{
{
std
::
lock_guard
<
std
::
mutex
>
lk1
(
model_ctx_mtx_
);
std
::
lock_guard
<
std
::
mutex
>
lk2
(
model_instance_ctx_mtx_
);
auto
&
model_context
=
model_contexts_
[
model
];
model_context
.
RequestRemoval
();
for
(
const
auto
&
instance
:
model_instance_ctxs_
[
model
])
{
instance
->
WaitForRemoval
();
if
(
!
ignore_resources_and_priority_
)
{
resource_manager_
->
RemoveModelInstance
(
instance
.
get
());
}
}
model_instance_ctxs_
.
erase
(
model
);
model_contexts_
.
erase
(
model
);
}
if
(
!
ignore_resources_and_priority_
)
{
RETURN_IF_ERROR
(
resource_manager_
->
UpdateResourceLimits
());
}
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
payload_queues_mu_
);
if
(
payload_queues_
.
find
(
model
)
!=
payload_queues_
.
end
())
{
payload_queues_
.
erase
(
model
);
}
}
return
Status
::
Success
;
}
bool
RateLimiter
::
PayloadSlotAvailable
(
const
TritonModel
*
model
)
{
bool
result
;
PayloadQueue
*
payload_queue
=
payload_queues_
[
model
].
get
();
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
payload_queue
->
mu_
);
result
=
payload_queue
->
queue_
->
Size
()
<
2
*
payload_queue
->
specific_queues_
.
size
();
}
return
result
;
}
Status
RateLimiter
::
EnqueuePayload
(
const
TritonModel
*
model
,
std
::
shared_ptr
<
Payload
>
payload
)
{
auto
pinstance
=
payload
->
GetInstance
();
if
(
payload_queues_
.
find
(
model
)
==
payload_queues_
.
end
())
{
LOG_INFO
<<
"Should not print this "
;
}
PayloadQueue
*
payload_queue
=
payload_queues_
[
model
].
get
();
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
payload_queue
->
mu_
);
payload
->
SetState
(
Payload
::
State
::
REQUESTED
);
if
(
ignore_resources_and_priority_
)
{
SchedulePayload
(
pinstance
,
payload_queue
,
payload
);
}
}
if
(
ignore_resources_and_priority_
)
{
if
(
pinstance
==
nullptr
)
{
payload_queue
->
cv_
.
notify_one
();
}
else
{
payload_queue
->
cv_
.
notify_all
();
}
}
else
{
StandardScheduleFunc
sched_func
=
[
this
,
payload_queue
,
payload
](
ModelInstanceContext
*
mi
)
{
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
payload_queue
->
mu_
);
this
->
SchedulePayload
(
mi
->
RawInstance
(),
payload_queue
,
payload
);
}
auto
cb
=
[
mi
]()
{
mi
->
Release
();
};
payload
->
AddInternalReleaseCallback
(
cb
);
if
(
mi
->
RawInstance
()
==
nullptr
)
{
payload_queue
->
cv_
.
notify_one
();
}
else
{
payload_queue
->
cv_
.
notify_all
();
}
};
DeferPayloadSchedule
(
sched_func
,
model
,
payload
->
GetInstance
());
}
return
Status
::
Success
;
}
void
RateLimiter
::
DequeuePayload
(
std
::
deque
<
TritonModelInstance
*>&
instances
,
std
::
shared_ptr
<
Payload
>*
payload
)
{
payload
->
reset
();
if
(
payload_queues_
.
find
(
instances
[
0
]
->
Model
())
==
payload_queues_
.
end
())
{
LOG_INFO
<<
"Should not print this "
;
}
PayloadQueue
*
payload_queue
=
payload_queues_
[
instances
[
0
]
->
Model
()].
get
();
std
::
vector
<
std
::
shared_ptr
<
Payload
>>
merged_payloads
;
size_t
instance_index
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
payload_queue
->
mu_
);
payload_queue
->
cv_
.
wait
(
lk
,
[
&
instances
,
&
instance_index
,
payload_queue
]()
{
bool
empty
=
payload_queue
->
queue_
->
Empty
();
if
(
empty
)
{
instance_index
=
0
;
for
(
const
auto
instance
:
instances
)
{
empty
=
payload_queue
->
specific_queues_
[
instance
]
->
Empty
();
if
(
empty
)
{
instance_index
++
;
}
else
{
break
;
}
}
}
return
!
empty
;
});
if
(
instance_index
<
instances
.
size
())
{
TritonModelInstance
*
instance
=
instances
[
instance_index
];
if
(
!
payload_queue
->
specific_queues_
[
instance
]
->
Empty
())
{
payload_queue
->
specific_queues_
[
instance
]
->
Dequeue
(
payload
,
&
merged_payloads
);
}
}
else
{
payload_queue
->
queue_
->
Dequeue
(
payload
,
&
merged_payloads
);
}
}
for
(
auto
&
merge_payload
:
merged_payloads
)
{
PayloadRelease
(
merge_payload
);
}
(
*
payload
)
->
Callback
();
if
((
*
payload
)
->
GetInstance
()
==
nullptr
)
{
(
*
payload
)
->
SetInstance
(
instances
.
front
());
instances
.
pop_front
();
}
else
{
instances
.
erase
(
instances
.
begin
()
+
instance_index
);
}
}
std
::
shared_ptr
<
Payload
>
RateLimiter
::
GetPayload
(
const
Payload
::
Operation
op_type
,
TritonModelInstance
*
instance
)
{
std
::
shared_ptr
<
Payload
>
payload
;
if
(
max_payload_bucket_count_
>
0
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
payload_mu_
);
if
(
!
payload_bucket_
.
empty
())
{
payload
=
payload_bucket_
.
back
();
payload_bucket_
.
pop_back
();
}
if
(
payload
.
get
()
==
nullptr
&&
(
!
payloads_in_use_
.
empty
()))
{
// Just checking the front of the queue instead the entire queue for
// an available payload to save time.
if
(
payloads_in_use_
.
front
().
use_count
()
==
1
)
{
payload
=
payloads_in_use_
.
front
();
payloads_in_use_
.
pop_front
();
}
}
}
if
(
payload
.
get
()
==
nullptr
)
{
payload
.
reset
(
new
Payload
());
}
payload
->
Reset
(
op_type
,
instance
);
return
payload
;
}
void
RateLimiter
::
PayloadRelease
(
std
::
shared_ptr
<
Payload
>&
payload
)
{
payload
->
OnRelease
();
if
(
max_payload_bucket_count_
>
0
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
payload_mu_
);
if
(
payloads_in_use_
.
size
()
+
payload_bucket_
.
size
()
<
max_payload_bucket_count_
)
{
// Release iff the payload shared_ptr is uniquely held.
if
(
payload
.
use_count
()
==
1
)
{
payload
->
Release
();
payload_bucket_
.
push_back
(
std
::
move
(
payload
));
return
;
}
else
{
payloads_in_use_
.
push_back
(
std
::
move
(
payload
));
}
}
}
}
RateLimiter
::
RateLimiter
(
const
bool
ignore_resources_and_priority
,
const
ResourceMap
&
resource_map
)
:
ignore_resources_and_priority_
(
ignore_resources_and_priority
),
max_payload_bucket_count_
(
MAX_PAYLOAD_BUCKET_COUNT
)
{
ResourceManager
::
Create
(
resource_map
,
&
resource_manager_
);
}
void
RateLimiter
::
InitializePayloadQueues
(
const
TritonModelInstance
*
instance
)
{
auto
&
config
=
instance
->
Model
()
->
Config
();
uint64_t
max_queue_delay_microseconds
;
if
(
config
.
has_sequence_batching
())
{
const
auto
&
batcher_config
=
config
.
sequence_batching
();
if
(
batcher_config
.
has_oldest
())
{
max_queue_delay_microseconds
=
batcher_config
.
oldest
().
max_queue_delay_microseconds
();
}
else
{
max_queue_delay_microseconds
=
0
;
}
}
else
if
(
config
.
has_dynamic_batching
())
{
max_queue_delay_microseconds
=
config
.
dynamic_batching
().
max_queue_delay_microseconds
();
}
else
{
max_queue_delay_microseconds
=
0
;
}
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
payload_queues_mu_
);
if
(
payload_queues_
.
find
(
instance
->
Model
())
==
payload_queues_
.
end
())
{
payload_queues_
.
emplace
(
instance
->
Model
(),
new
PayloadQueue
(
config
.
max_batch_size
(),
max_queue_delay_microseconds
*
1000
));
}
}
PayloadQueue
*
payload_queue
=
payload_queues_
[
instance
->
Model
()].
get
();
if
(
payload_queue
->
specific_queues_
.
find
(
instance
)
==
payload_queue
->
specific_queues_
.
end
())
{
payload_queue
->
specific_queues_
.
emplace
(
instance
,
new
InstanceQueue
(
config
.
max_batch_size
(),
max_queue_delay_microseconds
*
1000
));
}
}
Status
RateLimiter
::
DeferPayloadSchedule
(
const
StandardScheduleFunc
&
OnSchedule
,
const
TritonModel
*
model
,
TritonModelInstance
*
triton_model_instance
)
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
model_ctx_mtx_
);
auto
itr
=
model_contexts_
.
find
(
model
);
if
(
itr
==
model_contexts_
.
end
())
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Requested model is not yet registered with rate limiter"
);
}
if
(
itr
->
second
.
isRemovalInProgress
())
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"New model requests can not be made to a model that is being "
"removed"
);
}
itr
->
second
.
EnqueueModelInstanceRequest
(
OnSchedule
,
triton_model_instance
);
itr
->
second
.
StageInstanceIfAvailable
(
triton_model_instance
);
return
Status
::
Success
;
}
void
RateLimiter
::
SchedulePayload
(
TritonModelInstance
*
tmi
,
PayloadQueue
*
payload_queue
,
const
std
::
shared_ptr
<
Payload
>&
payload
)
{
if
(
tmi
==
nullptr
)
{
payload_queue
->
queue_
->
Enqueue
(
payload
);
}
else
{
payload_queue
->
specific_queues_
[
tmi
]
->
Enqueue
(
payload
);
}
payload
->
SetState
(
Payload
::
State
::
SCHEDULED
);
}
void
RateLimiter
::
OnStage
(
ModelInstanceContext
*
instance
)
{
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lk
(
staged_instances_mtx_
);
staged_instances_
.
push
(
instance
);
}
AttemptAllocation
();
}
void
RateLimiter
::
OnRelease
(
ModelInstanceContext
*
instance
)
{
auto
&
model_context
=
model_contexts_
[
instance
->
RawInstance
()
->
Model
()];
model_context
.
AddAvailableInstance
(
instance
);
resource_manager_
->
ReleaseResources
(
instance
);
if
(
model_context
.
ContainsPendingRequests
(
instance
->
RawInstance
()
->
Index
()))
{
model_context
.
StageInstanceIfAvailable
(
instance
->
RawInstance
());
}
AttemptAllocation
();
}
void
RateLimiter
::
AttemptAllocation
()
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lk
(
staged_instances_mtx_
);
if
(
!
staged_instances_
.
empty
())
{
ModelInstanceContext
*
instance
=
staged_instances_
.
top
();
if
(
resource_manager_
->
AllocateResources
(
instance
))
{
staged_instances_
.
pop
();
instance
->
Allocate
();
}
}
}
//=========================================================================
// ModelContext Implementation
//=========================================================================
RateLimiter
::
ModelContext
::
ModelContext
()
:
removal_in_progress_
(
false
)
{}
Status
RateLimiter
::
ModelContext
::
EnqueueModelInstanceRequest
(
const
StandardScheduleFunc
&
OnSchedule
,
TritonModelInstance
*
triton_model_instance
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lk
(
sched_request_queue_mtx_
);
if
(
triton_model_instance
==
nullptr
)
{
generic_sched_request_queue_
.
push
(
OnSchedule
);
}
else
if
(
(
uint32_t
)
triton_model_instance
->
Index
()
<
specific_sched_request_queues_
.
size
())
{
specific_sched_request_queues_
[
triton_model_instance
->
Index
()].
push
(
OnSchedule
);
}
else
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"expected instance index between 0 and "
+
std
::
to_string
(
specific_sched_request_queues_
.
size
())
+
", got "
+
std
::
to_string
(
triton_model_instance
->
Index
()));
}
return
Status
::
Success
;
}
void
RateLimiter
::
ModelContext
::
AddAvailableInstance
(
ModelInstanceContext
*
instance
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lk
(
avbl_instances_mtx_
);
avbl_instances_
.
push
(
instance
);
instance
->
MarkAvailable
();
}
void
RateLimiter
::
ModelContext
::
StageInstanceIfAvailable
(
TritonModelInstance
*
req_instance
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lk1
(
sched_request_queue_mtx_
);
std
::
lock_guard
<
std
::
recursive_mutex
>
lk2
(
avbl_instances_mtx_
);
PriorityQueue
backup_queue
;
while
(
!
avbl_instances_
.
empty
())
{
ModelInstanceContext
*
instance
=
avbl_instances_
.
top
();
if
((
req_instance
!=
nullptr
)
&&
(
instance
->
RawInstance
()
!=
req_instance
))
{
backup_queue
.
push
(
instance
);
avbl_instances_
.
pop
();
continue
;
}
if
(
!
specific_sched_request_queues_
[
instance
->
RawInstance
()
->
Index
()]
.
empty
())
{
// Prioritize the specific requests for the available model
// instance highest priority.
const
StandardScheduleFunc
func
=
specific_sched_request_queues_
[
instance
->
RawInstance
()
->
Index
()]
.
front
();
specific_sched_request_queues_
[
instance
->
RawInstance
()
->
Index
()].
pop
();
instance
->
Stage
(
func
);
}
else
if
(
!
generic_sched_request_queue_
.
empty
())
{
// If request is for generic model instance then use the
// instance with the highest priority.
const
StandardScheduleFunc
func
=
generic_sched_request_queue_
.
front
();
generic_sched_request_queue_
.
pop
();
instance
->
Stage
(
func
);
}
else
{
// If there are requests for a specific model instance then backup
// the model instance and keep searching through the available
// model instances. The prioritization will be taken care of in the
// staging priority queue.
backup_queue
.
push
(
instance
);
}
avbl_instances_
.
pop
();
}
// Restore the backup queue
if
(
!
backup_queue
.
empty
())
{
avbl_instances_
.
swap
(
backup_queue
);
}
}
void
RateLimiter::ModelContext::AllocateInstanceIfAvailable()
{
  std::lock_guard<std::recursive_mutex> lk1(sched_request_queue_mtx_);
  std::lock_guard<std::recursive_mutex> lk2(avbl_instances_mtx_);
  PriorityQueue backup_queue;
  while (!avbl_instances_.empty()) {
    ModelInstanceContext* instance = avbl_instances_.top();
    if (!specific_sched_request_queues_[instance->RawInstance()->Index()]
             .empty()) {
      // Give requests that target this specific model instance the
      // highest priority.
      const StandardScheduleFunc func =
          specific_sched_request_queues_[instance->RawInstance()->Index()]
              .front();
      specific_sched_request_queues_[instance->RawInstance()->Index()].pop();
      instance->DirectAllocate(func);
    } else if (!generic_sched_request_queue_.empty()) {
      // A generic request can run on any model instance, so use the
      // available instance with the highest priority.
      const StandardScheduleFunc func = generic_sched_request_queue_.front();
      generic_sched_request_queue_.pop();
      instance->DirectAllocate(func);
    } else {
      // Only requests for other specific model instances remain, so back
      // up this instance and keep searching through the available
      // instances. Prioritization is handled by the staging priority
      // queue.
      backup_queue.push(instance);
    }
    avbl_instances_.pop();
  }
  // Restore the backed-up instances
  if (!backup_queue.empty()) {
    avbl_instances_.swap(backup_queue);
  }
}
void
RateLimiter::ModelContext::AddSpecificRequestQueue()
{
  std::lock_guard<std::recursive_mutex> lk(sched_request_queue_mtx_);
  specific_sched_request_queues_.emplace_back();
}

bool
RateLimiter::ModelContext::ContainsPendingRequests(int index)
{
  std::lock_guard<std::recursive_mutex> lk(sched_request_queue_mtx_);
  return (generic_sched_request_queue_.size() != 0) ||
         (specific_sched_request_queues_[index].size() != 0);
}

void
RateLimiter::ModelContext::RequestRemoval()
{
  removal_in_progress_ = true;
}
//=========================================================================
// ModelInstanceContext Implementation
//=========================================================================

RateLimiter::ModelInstanceContext::ModelInstanceContext(
    TritonModelInstance* triton_model_instance,
    RateLimiter::ModelContext* model_context,
    const RateLimiter::RateLimiterConfig& rate_limiter_config,
    RateLimiter::StandardStageFunc OnStage,
    RateLimiter::StandardReleaseFunc OnRelease)
    : triton_model_instance_(triton_model_instance),
      index_(triton_model_instance->Index()), model_context_(model_context),
      rate_limiter_config_(rate_limiter_config), OnStage_(OnStage),
      OnRelease_(OnRelease), exec_count_(0), state_(AVAILABLE)
{
}
void
RateLimiter::ModelInstanceContext::MarkAvailable()
{
  std::lock_guard<std::mutex> lk(state_mtx_);
  state_ = AVAILABLE;
}

Status
RateLimiter::ModelInstanceContext::Stage(StandardScheduleFunc OnSchedule)
{
  {
    std::lock_guard<std::mutex> lk(state_mtx_);
    if (state_ != AVAILABLE) {
      return Status(
          Status::Code::INTERNAL,
          "Cannot stage a model instance that is not yet available");
    }
    state_ = STAGED;
    OnSchedule_ = OnSchedule;
  }
  OnStage_(this);
  return Status::Success;
}

Status
RateLimiter::ModelInstanceContext::Allocate()
{
  {
    std::lock_guard<std::mutex> lk(state_mtx_);
    if (state_ != STAGED) {
      return Status(
          Status::Code::INTERNAL,
          "Cannot allocate a model instance that is not yet staged");
    }
    state_ = ALLOCATED;
  }
  OnSchedule_(this);
  return Status::Success;
}

Status
RateLimiter::ModelInstanceContext::DirectAllocate(
    StandardScheduleFunc OnSchedule)
{
  {
    std::lock_guard<std::mutex> lk(state_mtx_);
    if (state_ != AVAILABLE) {
      return Status(
          Status::Code::INTERNAL,
          "Cannot allocate a model instance that is not yet available");
    }
    state_ = ALLOCATED;
  }
  OnSchedule(this);
  return Status::Success;
}
void
RateLimiter::ModelInstanceContext::Release()
{
  exec_count_++;
  OnRelease_(this);
  {
    std::lock_guard<std::mutex> lk(state_mtx_);
    if ((model_context_->isRemovalInProgress()) && (state_ == AVAILABLE) &&
        (!model_context_->ContainsPendingRequests(index_))) {
      state_ = REMOVED;
    }
  }
  if (state_ == REMOVED) {
    cv_.notify_all();
  }
}

void
RateLimiter::ModelInstanceContext::RequestRemoval()
{
  std::lock_guard<std::mutex> lk(state_mtx_);
  if ((state_ == AVAILABLE) &&
      (!model_context_->ContainsPendingRequests(index_))) {
    state_ = REMOVED;
  }
}

void
RateLimiter::ModelInstanceContext::WaitForRemoval()
{
  if (!model_context_->isRemovalInProgress()) {
    model_context_->RequestRemoval();
  }
  RequestRemoval();
  // Wait for the instance to be removed
  {
    std::unique_lock<std::mutex> lk(state_mtx_);
    cv_.wait(lk, [this] { return state_ == REMOVED; });
  }
}
double
RateLimiter::ModelInstanceContext::ScaledPriority()
{
  // TODO: Different schemes for prioritizing model instances can be
  // added here. The priority of an instance defaults to 1; a configured
  // priority of 0 is also treated as 1.
  auto priority = std::max(rate_limiter_config_.priority(), 1u);
  return (exec_count_ * priority);
}
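
// Worked example of the scaling above: with instances A (priority 1) and
// B (priority 2), after A has executed 4 times and B twice, the scaled
// values are 4 * 1 = 4 and 2 * 2 = 4, so the limiter considers them
// equally "due". ScaledPriorityComparator (see rate_limiter.h) orders the
// priority queue so that the instance with the smallest scaled value is
// served first; a higher configured priority value therefore means the
// instance is selected proportionally less often.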
//=========================================================================
// ResourceManager Implementation
//=========================================================================

Status
RateLimiter::ResourceManager::Create(
    const ResourceMap& resource_map,
    std::unique_ptr<ResourceManager>* resource_manager)
{
  std::unique_ptr<ResourceManager> local_resource_manager(
      new ResourceManager(resource_map));
  *resource_manager = std::move(local_resource_manager);
  return Status::Success;
}
void
RateLimiter::ResourceManager::AddModelInstance(
    const ModelInstanceContext* instance)
{
  std::lock_guard<std::mutex> lk(model_resources_mtx_);
  auto pr = model_resources_.emplace(std::make_pair(instance, ResourceMap()));
  for (const auto& resource :
       instance->GetRateLimiterConfig()->resources()) {
    if (resource.global()) {
      (pr.first->second[GLOBAL_RESOURCE_KEY])[resource.name()] =
          resource.count();
    } else {
      (pr.first->second[instance->RawInstance()->DeviceId()])[resource.name()] =
          resource.count();
    }
  }
}
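
// Illustration of the resulting ResourceMap layout (keys per the
// RESOURCE_KIND_KEY enum in rate_limiter.h; the resource names here are
// hypothetical): a global resource is stored under GLOBAL_RESOURCE_KEY
// (-2) and a non-global one under the instance's device id. A hand-built
// equivalent, assuming device id 0:
//
//   RateLimiter::ResourceMap rmap;
//   rmap[RateLimiter::GLOBAL_RESOURCE_KEY]["global_tokens"] = 4;  // global()
//   rmap[0]["gpu_slots"] = 2;  // device-specific, device id 0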
Status
RateLimiter::ResourceManager::RemoveModelInstance(
    const ModelInstanceContext* instance)
{
  std::lock_guard<std::mutex> lk(model_resources_mtx_);
  const auto& itr = model_resources_.find(instance);
  if (itr == model_resources_.end()) {
    return Status(
        Status::Code::INTERNAL, "Cannot find the instance to remove");
  }
  model_resources_.erase(instance);
  return Status::Success;
}
Status
RateLimiter::ResourceManager::UpdateResourceLimits()
{
  std::lock_guard<std::mutex> lk1(max_resources_mtx_);
  std::lock_guard<std::mutex> lk2(model_resources_mtx_);
  max_resources_.clear();
  // Obtain the maximum resource count across all the instances and use
  // it as the default available amount.
  for (const auto& instance_resources : model_resources_) {
    for (const auto& resource_device_map : instance_resources.second) {
      auto ditr = max_resources_.find(resource_device_map.first);
      if (ditr == max_resources_.end()) {
        ditr =
            max_resources_
                .emplace(resource_device_map.first, resource_device_map.second)
                .first;
      } else {
        for (const auto resource : resource_device_map.second) {
          auto ritr = ditr->second.find(resource.first);
          if (ritr == ditr->second.end()) {
            ritr = ditr->second.emplace(resource.first, resource.second).first;
          } else {
            if (ritr->second < resource.second) {
              ritr->second = resource.second;
            }
          }
        }
      }
    }
  }
  if (!explicit_max_resources_.empty()) {
    RETURN_IF_ERROR(ParseAndValidateExplicitResources());
  }
  RETURN_IF_ERROR(ValidateMaxResources());
  if (LOG_VERBOSE_IS_ON(1)) {
    std::string resource_map_str{"\nMax Resource Map===>\n"};
    for (const auto& ditr : max_resources_) {
      if (!ditr.second.empty()) {
        std::string device_str{
            (ditr.first == GLOBAL_RESOURCE_KEY) ? "GLOBAL"
                                                : std::to_string(ditr.first)};
        resource_map_str += "\tDevice: " + device_str + "\n";
        for (const auto& ritr : ditr.second) {
          resource_map_str += "\t\tResource: " + ritr.first +
                              "\tCount: " + std::to_string(ritr.second) + "\n";
        }
      }
    }
    LOG_VERBOSE(1) << resource_map_str;
  }
  return Status::Success;
}
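
// Worked example of the default limits computed above: if instance I0
// declares {R1: 4} on device 0 and instance I1 declares {R1: 6, R2: 1} on
// device 0, the default in max_resources_ becomes
// {device 0: {R1: 6, R2: 1}} -- the element-wise maximum across instances.
// When counts were provided explicitly (the resource_map handed to
// Create()), ParseAndValidateExplicitResources() overrides these defaults
// instead.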
Status
RateLimiter::ResourceManager::ValidateMaxResources()
{
  for (const auto& global_resource : max_resources_[GLOBAL_RESOURCE_KEY]) {
    for (const auto& ditr : max_resources_) {
      if (ditr.first != GLOBAL_RESOURCE_KEY) {
        for (const auto& ritr : ditr.second) {
          if (global_resource.first.compare(ritr.first) == 0) {
            return Status(
                Status::Code::INVALID_ARG,
                (std::string("Resource \"") + ritr.first +
                 "\" is present as both global and device-specific resource "
                 "in the model configuration.")
                    .c_str());
          }
        }
      }
    }
  }
  return Status::Success;
}
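
// Example of a configuration this check rejects (model-config protobuf
// text format; field names as in inference::ModelRateLimiter, values
// hypothetical): declaring "R1" as a global resource for one instance
// group and as a device-local resource for another makes its accounting
// ambiguous, so the model fails to load with INVALID_ARG.
//
//   instance_group [
//     { rate_limiter { resources [ { name: "R1" global: true count: 2 } ] } },
//     { rate_limiter { resources [ { name: "R1" count: 3 } ] } }
//   ]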
Status
RateLimiter::ResourceManager::ParseAndValidateExplicitResources()
{
  for (auto& ditr : max_resources_) {
    for (auto& ritr : ditr.second) {
      // If not specified explicitly, consider the resource to be unavailable.
      size_t resource_count = 0;
      if (ditr.first == GLOBAL_RESOURCE_KEY) {
        // Ignore the device specification and search for the resource
        // across the complete map.
        for (const auto& exp_ditr : explicit_max_resources_) {
          for (const auto& exp_ritr : exp_ditr.second) {
            if (ritr.first.compare(exp_ritr.first) == 0) {
              if (resource_count < exp_ritr.second) {
                resource_count = exp_ritr.second;
              }
            }
          }
        }
      } else {
        // Search only the device-specific and per-device resources...
        // device-specific
        for (const auto& exp_ritr : explicit_max_resources_[ditr.first]) {
          if (ritr.first.compare(exp_ritr.first) == 0) {
            if (resource_count < exp_ritr.second) {
              resource_count = exp_ritr.second;
            }
          }
        }
        // per-device
        for (const auto& exp_ritr :
             explicit_max_resources_[PER_DEVICE_RESOURCE_KEY]) {
          if (ritr.first.compare(exp_ritr.first) == 0) {
            if (resource_count < exp_ritr.second) {
              resource_count = exp_ritr.second;
            }
          }
        }
      }
      if (resource_count < ritr.second) {
        return Status(
            Status::Code::INVALID_ARG,
            (std::string("Resource count for \"") + ritr.first +
             "\" is limited to " + std::to_string(resource_count) +
             " which will prevent scheduling of one or more model "
             "instances, the minimum required count is " +
             std::to_string(ritr.second))
                .c_str());
      } else {
        ritr.second = resource_count;
      }
    }
  }
  return Status::Success;
}
bool
RateLimiter::ResourceManager::AllocateResources(
    const ModelInstanceContext* instance)
{
  std::lock_guard<std::mutex> lk1(model_resources_mtx_);
  std::lock_guard<std::mutex> lk2(allocated_resources_mtx_);
  const auto& itr = model_resources_.find(instance);
  if (itr == model_resources_.end()) {
    return false;
  } else {
    // First pass to verify if resources are available
    {
      std::lock_guard<std::mutex> lk3(max_resources_mtx_);
      for (const auto& ditr : itr->second) {
        auto allocated_ditr = allocated_resources_.find(ditr.first);
        if (allocated_ditr == allocated_resources_.end()) {
          allocated_ditr =
              allocated_resources_
                  .emplace(ditr.first, std::map<std::string, size_t>())
                  .first;
        }
        for (const auto& ritr : ditr.second) {
          auto allocated_ritr = allocated_ditr->second.find(ritr.first);
          if (allocated_ritr == allocated_ditr->second.end()) {
            allocated_ritr =
                allocated_ditr->second.emplace(ritr.first, 0).first;
          }
          if ((allocated_ritr->second + ritr.second) >
              (max_resources_[ditr.first])[ritr.first]) {
            return false;
          }
        }
      }
    }
    // Second pass to actually allocate the resources
    for (const auto& ditr : itr->second) {
      for (const auto& ritr : ditr.second) {
        (allocated_resources_[ditr.first])[ritr.first] += ritr.second;
      }
    }
  }
  return true;
}
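
// The two-pass shape above (verify everything, then commit everything)
// makes the allocation all-or-nothing: returning false from the first
// pass leaves allocated_resources_ unchanged except for zero-initialized
// entries. A condensed standalone sketch of the same check-then-commit
// pattern over two maps (TryTake is a hypothetical helper; assumes <map>
// and <string>, and that `limit` contains every requested key):
//
//   bool TryTake(std::map<std::string, size_t>& used,
//                const std::map<std::string, size_t>& want,
//                const std::map<std::string, size_t>& limit) {
//     for (const auto& w : want) {                          // pass 1: verify
//       if (used[w.first] + w.second > limit.at(w.first)) return false;
//     }
//     for (const auto& w : want) used[w.first] += w.second;  // pass 2: commit
//     return true;
//   }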
Status
RateLimiter::ResourceManager::ReleaseResources(
    const ModelInstanceContext* instance)
{
  std::lock_guard<std::mutex> lk1(model_resources_mtx_);
  std::lock_guard<std::mutex> lk2(allocated_resources_mtx_);
  const auto& itr = model_resources_.find(instance);
  if (itr == model_resources_.end()) {
    return Status(
        Status::Code::INTERNAL,
        "Unable to find the instance resources to release");
  } else {
    for (const auto& ditr : itr->second) {
      for (const auto& ritr : ditr.second) {
        (allocated_resources_[ditr.first])[ritr.first] -= ritr.second;
      }
    }
  }
  return Status::Success;
}

RateLimiter::ResourceManager::ResourceManager(const ResourceMap& resource_map)
    : explicit_max_resources_(resource_map)
{
}

}}  // namespace triton::core
3rdparty/core-r22.12/src/rate_limiter.h
0 → 100644
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <atomic>
#include <condition_variable>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <vector>
#include "backend_model.h"
#include "backend_model_instance.h"
#include "instance_queue.h"
#include "model_config.pb.h"
#include "payload.h"
#include "status.h"
namespace triton { namespace core {
// Limits the rate at which requests are dispatched to the model instances
class RateLimiter {
 public:
  using RateLimiterConfig = inference::ModelRateLimiter;
  using ResourceMap = std::map<int, std::map<std::string, size_t>>;
  enum RESOURCE_KIND_KEY {
    // Key for holding global resources
    GLOBAL_RESOURCE_KEY = -2,
    // Key for holding resources per each device
    PER_DEVICE_RESOURCE_KEY = -1
  };

  /// Creates a rate limiter object which will funnel the requests to
  /// the model instances. A typical lifetime of a model instance within
  /// the RateLimiter transitions from available -> staged -> allocated ->
  /// available. The transition from available to staged occurs when a
  /// request is registered for the model. Depending upon the resource
  /// availability and priority, the RateLimiter will transition an
  /// instance to the allocated state at some point in the future. The
  /// staged state is skipped when configured to ignore the resource
  /// constraints. The cycle in this case is
  /// available -> allocated -> available.
  /// \param ignore_resources_and_priority Whether or not to ignore resource
  /// constraints and cross-model priority. An available instance is directly
  /// allocated when true.
  /// \param resource_map The map to the available resource count provided
  /// explicitly.
  /// \return Status object indicating success or failure.
  static Status Create(
      const bool ignore_resources_and_priority,
      const ResourceMap& resource_map,
      std::unique_ptr<RateLimiter>* rate_limiter);
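
  // A minimal creation sketch (hedged illustration; the resource names
  // "tokens" and "slots" are hypothetical and error handling is elided):
  // build an explicit resource map with 10 units of a global resource and
  // 2 units of a per-device resource, then create a limiter that
  // enforces them.
  //
  //   RateLimiter::ResourceMap resources;
  //   resources[RateLimiter::GLOBAL_RESOURCE_KEY]["tokens"] = 10;
  //   resources[RateLimiter::PER_DEVICE_RESOURCE_KEY]["slots"] = 2;
  //   std::unique_ptr<RateLimiter> limiter;
  //   Status status = RateLimiter::Create(
  //       false /* ignore_resources_and_priority */, resources, &limiter);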
  /// Registers the model instance with the rate limiter.
  /// \param instance The pointer to the TritonModelInstance object to
  /// register with the rate limiter.
  /// \param rate_limiter_config The rate limiter configuration associated
  /// with the model instance.
  /// \return Status object indicating success or failure.
  Status RegisterModelInstance(
      TritonModelInstance* instance,
      const RateLimiterConfig& rate_limiter_config);

  /// Removes a model from the set of models being managed by the rate
  /// limiter.
  /// \param model The pointer to the TritonModel object to be removed.
  /// \return Status object indicating success or failure.
  Status UnregisterModel(const TritonModel* model);

  /// Returns true if there is a payload slot available for the given model.
  /// \param model The pointer to the TritonModel object to query.
  /// \return slot availability as a boolean.
  bool PayloadSlotAvailable(const TritonModel* model);

  /// Enqueues the payload to the rate limiter for scheduling on the given
  /// model.
  /// \param model The pointer to the TritonModel object to schedule on.
  /// \param payload The shared pointer to the payload object.
  /// \return Status object indicating success or failure.
  Status EnqueuePayload(
      const TritonModel* model, std::shared_ptr<Payload> payload);

  /// Returns the payload that has been scheduled for the given set of model
  /// instances. Note that this call is blocking and depends upon the
  /// availability of payloads in the rate limiter for the triton model
  /// instances.
  /// \param instance The pointers to the TritonModelInstance objects whose
  /// payload is being requested.
  /// \param payload The shared pointer to the payload object.
  void DequeuePayload(
      std::deque<TritonModelInstance*>& instance,
      std::shared_ptr<Payload>* payload);

  /// Returns a new payload object.
  /// \param op_type The operation type for the payload.
  /// \param instance Optional field that provides the model instance that
  /// must be used for the execution of the payload. Default is nullptr,
  /// which allows any model instance to execute the payload.
  /// \return The shared pointer to a new payload object.
  std::shared_ptr<Payload> GetPayload(
      const Payload::Operation op_type,
      TritonModelInstance* instance = nullptr);

  /// Releases the given payload object back to the rate limiter.
  /// \param payload The payload to release.
  void PayloadRelease(std::shared_ptr<Payload>& payload);
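
  // Sketch of the intended payload round trip through the calls above
  // (hedged illustration: the operation enum value and the model/instance
  // pointers are assumed for the example; blocking behavior and error
  // paths are elided):
  //
  //   auto payload = rate_limiter->GetPayload(
  //       Payload::Operation::INFER_RUN, nullptr /* any instance */);
  //   if (rate_limiter->PayloadSlotAvailable(model)) {
  //     rate_limiter->EnqueuePayload(model, payload);        // producer
  //   }
  //   std::deque<TritonModelInstance*> instances{instance};
  //   std::shared_ptr<Payload> scheduled;
  //   rate_limiter->DequeuePayload(instances, &scheduled);   // consumer
  //   // ... execute the payload ...
  //   rate_limiter->PayloadRelease(scheduled);  // return for reuse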
 private:
  class ModelInstanceContext;
  class ModelContext;
  struct PayloadQueue;

  using StandardReleaseFunc = std::function<void(ModelInstanceContext*)>;
  using StandardScheduleFunc = std::function<void(ModelInstanceContext*)>;
  using StandardStageFunc = std::function<void(ModelInstanceContext*)>;
  // Holds the state of the model instance.
  class ModelInstanceContext {
   public:
    friend class RateLimiter;
    friend class ResourceManager;

    enum State { AVAILABLE, STAGED, ALLOCATED, REMOVED };

    void Release();
    TritonModelInstance* RawInstance() const
    {
      return triton_model_instance_;
    }

   private:
    ModelInstanceContext(
        TritonModelInstance* triton_model_instance,
        ModelContext* model_context,
        const RateLimiterConfig& rate_limiter_config,
        StandardStageFunc OnStage, StandardReleaseFunc OnRelease);

    const RateLimiterConfig* GetRateLimiterConfig() const
    {
      return &rate_limiter_config_;
    }

    void MarkAvailable();
    double ScaledPriority();
    Status Stage(StandardScheduleFunc OnSchedule);
    Status Allocate();
    Status DirectAllocate(StandardScheduleFunc OnSchedule);
    void RequestRemoval();
    void WaitForRemoval();

    TritonModelInstance* triton_model_instance_;
    size_t index_;
    ModelContext* model_context_;
    RateLimiterConfig rate_limiter_config_;
    StandardStageFunc OnStage_;
    StandardReleaseFunc OnRelease_;
    std::atomic<uint64_t> exec_count_;
    State state_;
    bool removal_in_progress_;
    std::mutex state_mtx_;
    StandardScheduleFunc OnSchedule_;
    std::condition_variable cv_;
  };
  class ScaledPriorityComparator {
   public:
    bool operator()(ModelInstanceContext* a, ModelInstanceContext* b)
    {
      return a->ScaledPriority() > b->ScaledPriority();
    }
  };

  using PriorityQueue = std::priority_queue<
      ModelInstanceContext*, std::vector<ModelInstanceContext*>,
      ScaledPriorityComparator>;
  // Holds the active context of a model
  class ModelContext {
   public:
    ModelContext();
    Status EnqueueModelInstanceRequest(
        const StandardScheduleFunc& OnSchedule,
        TritonModelInstance* triton_model_instance);
    void AddAvailableInstance(ModelInstanceContext* instance);
    void StageInstanceIfAvailable(TritonModelInstance* triton_model_instance);
    void AllocateInstanceIfAvailable();
    void AddSpecificRequestQueue();
    bool ContainsPendingRequests(int32_t index);
    void RequestRemoval();
    bool isRemovalInProgress() { return removal_in_progress_; }

   private:
    bool removal_in_progress_;
    // Queues holding pending scheduling requests
    std::queue<StandardScheduleFunc> generic_sched_request_queue_;
    std::vector<std::queue<StandardScheduleFunc>>
        specific_sched_request_queues_;
    std::recursive_mutex sched_request_queue_mtx_;
    // The set of instances that are available at the moment
    PriorityQueue avbl_instances_;
    std::recursive_mutex avbl_instances_mtx_;
  };
  // Manages and keeps track of resource allocation to the model instances.
  class ResourceManager {
   public:
    static Status Create(
        const ResourceMap& resource_map,
        std::unique_ptr<ResourceManager>* resource_manager);
    void AddModelInstance(const ModelInstanceContext* instance);
    Status RemoveModelInstance(const ModelInstanceContext* instance);
    Status UpdateResourceLimits();
    bool AllocateResources(const ModelInstanceContext* instance);
    Status ReleaseResources(const ModelInstanceContext* instance);

   private:
    ResourceManager(const ResourceMap& resource_map);
    Status ValidateMaxResources();
    Status ParseAndValidateExplicitResources();

    ResourceMap explicit_max_resources_;
    std::map<const ModelInstanceContext*, ResourceMap> model_resources_;
    std::mutex model_resources_mtx_;
    ResourceMap max_resources_;
    std::mutex max_resources_mtx_;
    ResourceMap allocated_resources_;
    std::mutex allocated_resources_mtx_;
  };
  RateLimiter(
      const bool ignore_resources_and_priority,
      const ResourceMap& resource_map);

  void InitializePayloadQueues(const TritonModelInstance* instance);
  Status DeferPayloadSchedule(
      const StandardScheduleFunc& OnSchedule, const TritonModel* model,
      TritonModelInstance* instance = nullptr);
  void OnStage(ModelInstanceContext* instance_ptr);
  void OnRelease(ModelInstanceContext* instance_ptr);
  void AttemptAllocation();
  void SchedulePayload(
      TritonModelInstance* tmi, PayloadQueue* payload_queue,
      const std::shared_ptr<Payload>& payload);

  bool ignore_resources_and_priority_;

  // Instance contexts for the models
  std::map<
      const TritonModel*, std::vector<std::shared_ptr<ModelInstanceContext>>>
      model_instance_ctxs_;
  std::mutex model_instance_ctx_mtx_;

  // Running contexts of the models
  std::map<const TritonModel*, ModelContext> model_contexts_;
  std::mutex model_ctx_mtx_;

  // Holds the model instances that have been staged
  PriorityQueue staged_instances_;
  std::recursive_mutex staged_instances_mtx_;

  // Manager to keep track of the resource allocations
  std::unique_ptr<ResourceManager> resource_manager_;

  // Mutex to serialize Payload [de]allocation
  std::mutex payload_mu_;
  // Mutex to serialize Payload Queues deallocation
  std::mutex payload_queues_mu_;

  // Keep some number of Payload objects for reuse to avoid the overhead
  // of creating a Payload for every new request.
  const size_t max_payload_bucket_count_;
  std::vector<std::shared_ptr<Payload>> payload_bucket_;
  std::deque<std::shared_ptr<Payload>> payloads_in_use_;

  struct PayloadQueue {
    explicit PayloadQueue(size_t max_batch_size, uint64_t max_queue_delay_ns)
    {
      queue_.reset(new InstanceQueue(max_batch_size, max_queue_delay_ns));
    }
    std::unique_ptr<InstanceQueue> queue_;
    std::map<const TritonModelInstance*, std::unique_ptr<InstanceQueue>>
        specific_queues_;
    std::mutex mu_;
    std::condition_variable cv_;
  };
  std::map<const TritonModel*, std::unique_ptr<PayloadQueue>> payload_queues_;
};

}}  // namespace triton::core
3rdparty/core-r22.12/src/repo_agent.cc
0 → 100644
// Copyright 2021-2022, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "repo_agent.h"
#include <string>
#include "filesystem.h"
#include "shared_library.h"
#include "triton/common/logging.h"
#include "tritonserver_apis.h"
// For an unknown reason, Windows will not export the TRITONREPOAGENT_*
// functions declared with dllexport in tritonrepoagent.h. To get those
// functions exported it is (also?) necessary to mark the definitions in
// this file with dllexport as well.
#if defined(_MSC_VER)
#define TRITONAPI_DECLSPEC __declspec(dllexport)
#elif defined(__GNUC__)
#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default")))
#else
#define TRITONAPI_DECLSPEC
#endif
namespace triton { namespace core {
std::string
TritonRepoAgentLibraryName(const std::string& agent_name)
{
#ifdef _WIN32
  return std::string("tritonrepoagent_") + agent_name + ".dll";
#else
  return std::string("libtritonrepoagent_") + agent_name + ".so";
#endif
}
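
// For example, an agent named "checksum" resolves to
// "libtritonrepoagent_checksum.so" on Linux and to
// "tritonrepoagent_checksum.dll" on Windows.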
std::string
TRITONREPOAGENT_ActionTypeString(const TRITONREPOAGENT_ActionType type)
{
  switch (type) {
    case TRITONREPOAGENT_ACTION_LOAD:
      return "TRITONREPOAGENT_ACTION_LOAD";
    case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
      return "TRITONREPOAGENT_ACTION_LOAD_COMPLETE";
    case TRITONREPOAGENT_ACTION_LOAD_FAIL:
      return "TRITONREPOAGENT_ACTION_LOAD_FAIL";
    case TRITONREPOAGENT_ACTION_UNLOAD:
      return "TRITONREPOAGENT_ACTION_UNLOAD";
    case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE:
      return "TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE";
  }
  return "Unknown TRITONREPOAGENT_ActionType";
}

std::string
TRITONREPOAGENT_ArtifactTypeString(const TRITONREPOAGENT_ArtifactType type)
{
  switch (type) {
    case TRITONREPOAGENT_ARTIFACT_FILESYSTEM:
      return "TRITONREPOAGENT_ARTIFACT_FILESYSTEM";
    case TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM:
      return "TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM";
  }
  return "Unknown TRITONREPOAGENT_ArtifactType";
}
//
// TritonRepoAgent
//
Status
TritonRepoAgent::Create(
    const std::string& name, const std::string& libpath,
    std::shared_ptr<TritonRepoAgent>* agent)
{
  std::shared_ptr<TritonRepoAgent> lagent(new TritonRepoAgent(name));
  {
    std::unique_ptr<SharedLibrary> slib;
    RETURN_IF_ERROR(SharedLibrary::Acquire(&slib));
    RETURN_IF_ERROR(slib->OpenLibraryHandle(libpath, &lagent->dlhandle_));
    RETURN_IF_ERROR(slib->GetEntrypoint(
        lagent->dlhandle_, "TRITONREPOAGENT_Initialize", true /* optional */,
        reinterpret_cast<void**>(&lagent->init_fn_)));
    RETURN_IF_ERROR(slib->GetEntrypoint(
        lagent->dlhandle_, "TRITONREPOAGENT_Finalize", true /* optional */,
        reinterpret_cast<void**>(&lagent->fini_fn_)));
    RETURN_IF_ERROR(slib->GetEntrypoint(
        lagent->dlhandle_, "TRITONREPOAGENT_ModelInitialize",
        true /* optional */,
        reinterpret_cast<void**>(&lagent->model_init_fn_)));
    RETURN_IF_ERROR(slib->GetEntrypoint(
        lagent->dlhandle_, "TRITONREPOAGENT_ModelFinalize",
        true /* optional */,
        reinterpret_cast<void**>(&lagent->model_fini_fn_)));
    RETURN_IF_ERROR(slib->GetEntrypoint(
        lagent->dlhandle_, "TRITONREPOAGENT_ModelAction",
        false /* optional */,
        reinterpret_cast<void**>(&lagent->model_action_fn_)));
  }

  // Initialize if needed
  if (lagent->init_fn_ != nullptr) {
    RETURN_IF_TRITONSERVER_ERROR(lagent->init_fn_(
        reinterpret_cast<TRITONREPOAGENT_Agent*>(lagent.get())));
  }

  *agent = std::move(lagent);
  return Status::Success;
}
TritonRepoAgent::~TritonRepoAgent()
{
  // Finalize if needed
  if (fini_fn_ != nullptr) {
    auto err = fini_fn_(reinterpret_cast<TRITONREPOAGENT_Agent*>(this));
    if (err != nullptr) {
      LOG_ERROR << "~TritonRepoAgent: "
                << Status(
                       TritonCodeToStatusCode(TRITONSERVER_ErrorCode(err)),
                       TRITONSERVER_ErrorMessage(err))
                       .AsString();
      TRITONSERVER_ErrorDelete(err);
    }
  }
  {
    std::unique_ptr<SharedLibrary> slib;
    LOG_STATUS_ERROR(SharedLibrary::Acquire(&slib), "~TritonRepoAgent");
    LOG_STATUS_ERROR(slib->CloseLibraryHandle(dlhandle_), "~TritonRepoAgent");
  }
}
//
// TritonRepoAgentModel
//
Status
TritonRepoAgentModel::Create(
    const TRITONREPOAGENT_ArtifactType type, const std::string& location,
    const inference::ModelConfig& config,
    const std::shared_ptr<TritonRepoAgent>& agent,
    const TritonRepoAgent::Parameters& agent_parameters,
    std::unique_ptr<TritonRepoAgentModel>* agent_model)
{
  std::unique_ptr<TritonRepoAgentModel> lagent_model(new TritonRepoAgentModel(
      type, location, config, agent, agent_parameters));
  if (agent->AgentModelInitFn() != nullptr) {
    RETURN_IF_TRITONSERVER_ERROR(agent->AgentModelInitFn()(
        reinterpret_cast<TRITONREPOAGENT_Agent*>(agent.get()),
        reinterpret_cast<TRITONREPOAGENT_AgentModel*>(lagent_model.get())));
  }
  *agent_model = std::move(lagent_model);
  return Status::Success;
}
TritonRepoAgentModel::~TritonRepoAgentModel()
{
  // Need to ensure the proper lifecycle is informed
  if (action_type_set_) {
    switch (current_action_type_) {
      case TRITONREPOAGENT_ACTION_LOAD:
        LOG_TRITONSERVER_ERROR(
            agent_->AgentModelActionFn()(
                reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
                reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
                TRITONREPOAGENT_ACTION_LOAD_FAIL),
            "Inform TRITONREPOAGENT_ACTION_LOAD_FAIL");
        break;
      case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
        LOG_TRITONSERVER_ERROR(
            agent_->AgentModelActionFn()(
                reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
                reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
                TRITONREPOAGENT_ACTION_UNLOAD),
            "Inform TRITONREPOAGENT_ACTION_UNLOAD");
        // Fallthrough annotations are not a standardized language feature
        // until C++17, so issue the UNLOAD_COMPLETE action here as well.
        LOG_TRITONSERVER_ERROR(
            agent_->AgentModelActionFn()(
                reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
                reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
                TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE),
            "Inform TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE");
        break;
      case TRITONREPOAGENT_ACTION_UNLOAD:
        LOG_TRITONSERVER_ERROR(
            agent_->AgentModelActionFn()(
                reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
                reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
                TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE),
            "Inform TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE");
        break;
      case TRITONREPOAGENT_ACTION_LOAD_FAIL:
      case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE:
        break;
    }
  }
  if (agent_->AgentModelFiniFn() != nullptr) {
    LOG_TRITONSERVER_ERROR(
        agent_->AgentModelFiniFn()(
            reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
            reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this)),
        "~TritonRepoAgentModel");
  }
  if (!acquired_location_.empty()) {
    DeleteMutableLocation();
  }
}
Status
TritonRepoAgentModel::InvokeAgent(const TRITONREPOAGENT_ActionType action_type)
{
  if ((!action_type_set_) && (action_type != TRITONREPOAGENT_ACTION_LOAD)) {
    return Status(
        Status::Code::INTERNAL,
        "Unexpected lifecycle start state " +
            TRITONREPOAGENT_ActionTypeString(action_type));
  }
  switch (action_type) {
    case TRITONREPOAGENT_ACTION_LOAD:
      if (action_type_set_) {
        return Status(
            Status::Code::INTERNAL,
            "Unexpected lifecycle state transition from " +
                TRITONREPOAGENT_ActionTypeString(current_action_type_) +
                " to " + TRITONREPOAGENT_ActionTypeString(action_type));
      }
      break;
    case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
    case TRITONREPOAGENT_ACTION_LOAD_FAIL:
      if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD) {
        return Status(
            Status::Code::INTERNAL,
            "Unexpected lifecycle state transition from " +
                TRITONREPOAGENT_ActionTypeString(current_action_type_) +
                " to " + TRITONREPOAGENT_ActionTypeString(action_type));
      }
      break;
    case TRITONREPOAGENT_ACTION_UNLOAD:
      if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD_COMPLETE) {
        return Status(
            Status::Code::INTERNAL,
            "Unexpected lifecycle state transition from " +
                TRITONREPOAGENT_ActionTypeString(current_action_type_) +
                " to " + TRITONREPOAGENT_ActionTypeString(action_type));
      }
      break;
    case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE:
      if (current_action_type_ != TRITONREPOAGENT_ACTION_UNLOAD) {
        return Status(
            Status::Code::INTERNAL,
            "Unexpected lifecycle state transition from " +
                TRITONREPOAGENT_ActionTypeString(current_action_type_) +
                " to " + TRITONREPOAGENT_ActionTypeString(action_type));
      }
      break;
  }
  current_action_type_ = action_type;
  action_type_set_ = true;
  RETURN_IF_TRITONSERVER_ERROR(agent_->AgentModelActionFn()(
      reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
      reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this), action_type));
  return Status::Success;
}
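
// The state machine enforced above admits exactly two load-time outcomes
// and one unload-time sequence. A sketch of the happy path (error checks
// elided; agent_model is a previously created TritonRepoAgentModel):
//
//   agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD);
//   agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD_COMPLETE);
//   // ... model serves until unloaded ...
//   agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_UNLOAD);
//   agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE);
//
// A failed load instead ends with TRITONREPOAGENT_ACTION_LOAD_FAIL right
// after TRITONREPOAGENT_ACTION_LOAD.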
Status
TritonRepoAgentModel::SetLocation(
    const TRITONREPOAGENT_ArtifactType type, const std::string& location)
{
  if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD) {
    return Status(
        Status::Code::INVALID_ARG,
        "location can only be updated during TRITONREPOAGENT_ACTION_LOAD, "
        "current action type is " +
            (action_type_set_
                 ? TRITONREPOAGENT_ActionTypeString(current_action_type_)
                 : "not set"));
  }
  type_ = type;
  location_ = location;
  return Status::Success;
}
Status
TritonRepoAgentModel::Location(
    TRITONREPOAGENT_ArtifactType* type, const char** location)
{
  if (location_.empty()) {
    return Status(
        Status::Code::INTERNAL, "Model repository location is not set");
  }
  *type = type_;
  *location = location_.c_str();
  return Status::Success;
}
Status
TritonRepoAgentModel::AcquireMutableLocation(
    const TRITONREPOAGENT_ArtifactType type, const char** location)
{
  if (type != TRITONREPOAGENT_ARTIFACT_FILESYSTEM) {
    return Status(
        Status::Code::INVALID_ARG,
        "Unexpected artifact type, expects "
        "'TRITONREPOAGENT_ARTIFACT_FILESYSTEM'");
  }
  if (acquired_location_.empty()) {
    std::string lacquired_location;
    RETURN_IF_ERROR(MakeTemporaryDirectory(
        FileSystemType::LOCAL, &lacquired_location));
    acquired_location_.swap(lacquired_location);
    acquired_type_ = type;
  }
  *location = acquired_location_.c_str();
  return Status::Success;
}
Status
TritonRepoAgentModel::DeleteMutableLocation()
{
  if (acquired_location_.empty()) {
    return Status(
        Status::Code::UNAVAILABLE, "No mutable location to be deleted");
  }
  auto status = DeletePath(acquired_location_);
  if (!status.IsOk()) {
    LOG_ERROR << "Failed to delete previously acquired location '"
              << acquired_location_ << "': " << status.AsString();
  }
  acquired_location_.clear();
  return Status::Success;
}
//
// TritonRepoAgentManager
//
TritonRepoAgentManager&
TritonRepoAgentManager::Singleton()
{
  static TritonRepoAgentManager triton_repo_agent_manager;
  return triton_repo_agent_manager;
}

Status
TritonRepoAgentManager::SetGlobalSearchPath(const std::string& path)
{
  auto& singleton_manager = Singleton();
  std::lock_guard<std::mutex> lock(singleton_manager.mu_);
  singleton_manager.global_search_path_ = path;
  return Status::Success;
}
Status
TritonRepoAgentManager::CreateAgent(
    const std::string& agent_name, std::shared_ptr<TritonRepoAgent>* agent)
{
  auto& singleton_manager = Singleton();
  std::lock_guard<std::mutex> lock(singleton_manager.mu_);

  // Get the path to the agent shared library. The search path is the
  // global agent directory. FIXME expose global path as Triton option
  const std::vector<std::string> search_paths = {
      JoinPath({singleton_manager.global_search_path_, agent_name})};

  std::string agent_libname = TritonRepoAgentLibraryName(agent_name);
  std::string libpath;
  for (const auto& path : search_paths) {
    const auto full_path = JoinPath({path, agent_libname});
    bool exists = false;
    RETURN_IF_ERROR(FileExists(full_path, &exists));
    if (exists) {
      libpath = full_path;
      break;
    }
  }
  if (libpath.empty()) {
    return Status(
        Status::Code::INVALID_ARG,
        "unable to find '" + agent_libname + "' for repo agent '" +
            agent_name +
            "', searched: " + singleton_manager.global_search_path_);
  }

  const auto& itr = singleton_manager.agent_map_.find(libpath);
  if (itr != singleton_manager.agent_map_.end()) {
    // Found in the map. If the weak_ptr is still valid that means there
    // are other models using the agent and we just reuse the same agent.
    // If the weak_ptr is not valid then the agent has been unloaded, so
    // remove the stale weak_ptr from the map and create the agent again.
    *agent = itr->second.lock();
    if (*agent != nullptr) {
      return Status::Success;
    }
    singleton_manager.agent_map_.erase(itr);
  }

  RETURN_IF_ERROR(TritonRepoAgent::Create(agent_name, libpath, agent));
  singleton_manager.agent_map_.insert({libpath, *agent});
  return Status::Success;
}
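
// The caching above relies on weak_ptr so the manager never extends an
// agent's lifetime: the map entry expires once the last model's shared_ptr
// drops. A condensed standalone sketch of the same get-or-create idiom
// (MakeAgent is a hypothetical factory, not part of this file):
//
//   std::unordered_map<std::string, std::weak_ptr<TritonRepoAgent>> cache;
//   std::shared_ptr<TritonRepoAgent> agent = cache[libpath].lock();
//   if (agent == nullptr) {        // expired or never created
//     agent = MakeAgent(libpath);  // hypothetical factory
//     cache[libpath] = agent;      // store a non-owning reference
//   }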
}
Status
TritonRepoAgentManager
::
AgentState
(
std
::
unique_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
string
>>*
agent_state
)
{
auto
&
singleton_manager
=
Singleton
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
singleton_manager
.
mu_
);
std
::
unique_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
string
>>
agent_state_map
(
new
std
::
unordered_map
<
std
::
string
,
std
::
string
>
);
for
(
const
auto
&
agent_pair
:
singleton_manager
.
agent_map_
)
{
auto
&
libpath
=
agent_pair
.
first
;
auto
agent
=
agent_pair
.
second
.
lock
();
if
(
agent
!=
nullptr
)
{
agent_state_map
->
insert
({
agent
->
Name
(),
libpath
});
}
}
*
agent_state
=
std
::
move
(
agent_state_map
);
return
Status
::
Success
;
}
extern "C" {

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ApiVersion(uint32_t* major, uint32_t* minor)
{
  *major = TRITONREPOAGENT_API_VERSION_MAJOR;
  *minor = TRITONREPOAGENT_API_VERSION_MINOR;
  return nullptr;  // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryLocation(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    TRITONREPOAGENT_ArtifactType* artifact_type, const char** location)
{
  TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
  RETURN_TRITONSERVER_ERROR_IF_ERROR(tam->Location(artifact_type, location));
  return nullptr;  // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryLocationAcquire(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const TRITONREPOAGENT_ArtifactType artifact_type, const char** location)
{
  TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
  RETURN_TRITONSERVER_ERROR_IF_ERROR(
      tam->AcquireMutableLocation(artifact_type, location));
  return nullptr;  // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryLocationRelease(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const char* location)
{
  TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
  RETURN_TRITONSERVER_ERROR_IF_ERROR(tam->DeleteMutableLocation());
  return nullptr;  // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryUpdate(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const TRITONREPOAGENT_ArtifactType artifact_type, const char* location)
{
  TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
  RETURN_TRITONSERVER_ERROR_IF_ERROR(
      tam->SetLocation(artifact_type, location));
  return nullptr;  // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelParameterCount(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    uint32_t* count)
{
  TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
  *count = tam->AgentParameters().size();
  return nullptr;  // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelParameter(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const uint32_t index, const char** parameter_name,
    const char** parameter_value)
{
  TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
  const auto& params = tam->AgentParameters();
  if (index >= params.size()) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG,
        "index out of range for model parameters");
  }
  *parameter_name = params[index].first.c_str();
  *parameter_value = params[index].second.c_str();
  return nullptr;  // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelConfig(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const uint32_t config_version, TRITONSERVER_Message** model_config)
{
  TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
  std::string model_config_json;
  RETURN_TRITONSERVER_ERROR_IF_ERROR(
      ModelConfigToJson(tam->Config(), config_version, &model_config_json));
  return TRITONSERVER_MessageNewFromSerializedJson(
      model_config, model_config_json.c_str(), model_config_json.length());
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelState(TRITONREPOAGENT_AgentModel* model, void** state)
{
  TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
  *state = tam->State();
  return nullptr;  // success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelSetState(TRITONREPOAGENT_AgentModel* model, void* state)
{
  TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
  tam->SetState(state);
  return nullptr;  // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_State(TRITONREPOAGENT_Agent* agent, void** state)
{
  TritonRepoAgent* ta = reinterpret_cast<TritonRepoAgent*>(agent);
  *state = ta->State();
  return nullptr;  // success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_SetState(TRITONREPOAGENT_Agent* agent, void* state)
{
  TritonRepoAgent* ta = reinterpret_cast<TritonRepoAgent*>(agent);
  ta->SetState(state);
  return nullptr;  // success
}

}  // extern C

}}  // namespace triton::core
3rdparty/core-r22.12/src/repo_agent.h
0 → 100644
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "tritonserver_apis.h"
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "constants.h"
#include "model_config_utils.h"
namespace triton { namespace core {

std::string TritonRepoAgentLibraryName(const std::string& agent_name);
std::string TRITONREPOAGENT_ActionTypeString(
    const TRITONREPOAGENT_ActionType type);
std::string TRITONREPOAGENT_ArtifactTypeString(
    const TRITONREPOAGENT_ArtifactType type);
class TritonRepoAgent {
 public:
  using Parameters = std::vector<std::pair<std::string, std::string>>;
  typedef TRITONSERVER_Error* (*TritonRepoAgentInitFn_t)(
      TRITONREPOAGENT_Agent* agent);
  typedef TRITONSERVER_Error* (*TritonRepoAgentFiniFn_t)(
      TRITONREPOAGENT_Agent* agent);
  typedef TRITONSERVER_Error* (*TritonRepoAgentModelInitFn_t)(
      TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model);
  typedef TRITONSERVER_Error* (*TritonRepoAgentModelFiniFn_t)(
      TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model);
  typedef TRITONSERVER_Error* (*TritonRepoAgentModelActionFn_t)(
      TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
      const TRITONREPOAGENT_ActionType action_type);

  static Status Create(
      const std::string& name, const std::string& libpath,
      std::shared_ptr<TritonRepoAgent>* agent);
  ~TritonRepoAgent();

  const std::string& Name() { return name_; }
  void* State() { return state_; }
  void SetState(void* state) { state_ = state; }

  TritonRepoAgentModelActionFn_t AgentModelActionFn() const
  {
    return model_action_fn_;
  }
  TritonRepoAgentModelInitFn_t AgentModelInitFn() const
  {
    return model_init_fn_;
  }
  TritonRepoAgentModelFiniFn_t AgentModelFiniFn() const
  {
    return model_fini_fn_;
  }

 protected:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgent);

  TritonRepoAgent(const std::string& name)
      : name_(name), state_(nullptr), dlhandle_(nullptr), init_fn_(nullptr),
        fini_fn_(nullptr), model_init_fn_(nullptr), model_fini_fn_(nullptr),
        model_action_fn_(nullptr)
  {
  }

  const std::string name_;
  void* state_;

  // dlopen / dlsym handles
  void* dlhandle_;
  TritonRepoAgentInitFn_t init_fn_;
  TritonRepoAgentFiniFn_t fini_fn_;
  TritonRepoAgentModelInitFn_t model_init_fn_;
  TritonRepoAgentModelFiniFn_t model_fini_fn_;
  TritonRepoAgentModelActionFn_t model_action_fn_;
};
class TritonRepoAgentModel {
 public:
  static Status Create(
      const TRITONREPOAGENT_ArtifactType type, const std::string& location,
      const inference::ModelConfig& config,
      const std::shared_ptr<TritonRepoAgent>& agent,
      const TritonRepoAgent::Parameters& agent_parameters,
      std::unique_ptr<TritonRepoAgentModel>* agent_model);
  ~TritonRepoAgentModel();

  void* State() { return state_; }
  void SetState(void* state) { state_ = state; }

  Status InvokeAgent(const TRITONREPOAGENT_ActionType action_type);
  const TritonRepoAgent::Parameters& AgentParameters()
  {
    return agent_parameters_;
  }
  Status SetLocation(
      const TRITONREPOAGENT_ArtifactType type, const std::string& location);
  Status Location(TRITONREPOAGENT_ArtifactType* type, const char** location);
  Status AcquireMutableLocation(
      const TRITONREPOAGENT_ArtifactType type, const char** location);
  Status DeleteMutableLocation();
  const inference::ModelConfig Config() { return config_; }

 private:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModel);

  TritonRepoAgentModel(
      const TRITONREPOAGENT_ArtifactType type, const std::string& location,
      const inference::ModelConfig& config,
      const std::shared_ptr<TritonRepoAgent>& agent,
      const TritonRepoAgent::Parameters& agent_parameters)
      : state_(nullptr), config_(config), agent_(agent),
        agent_parameters_(agent_parameters), type_(type), location_(location),
        action_type_set_(false),
        current_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE)
  {
  }

  void* state_;
  const inference::ModelConfig config_;
  const std::shared_ptr<TritonRepoAgent> agent_;
  const TritonRepoAgent::Parameters agent_parameters_;
  TRITONREPOAGENT_ArtifactType type_;
  std::string location_;
  TRITONREPOAGENT_ArtifactType acquired_type_;
  std::string acquired_location_;
  bool action_type_set_;
  TRITONREPOAGENT_ActionType current_action_type_;
};
class TritonRepoAgentManager {
 public:
  static Status SetGlobalSearchPath(const std::string& path);
  static Status CreateAgent(
      const std::string& agent_name, std::shared_ptr<TritonRepoAgent>* agent);
  static Status AgentState(
      std::unique_ptr<std::unordered_map<std::string, std::string>>*
          agent_state);

 private:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentManager);
  TritonRepoAgentManager()
      : global_search_path_("/opt/tritonserver/repoagents"){};
  static TritonRepoAgentManager& Singleton();

  std::mutex mu_;
  std::string global_search_path_;
  std::unordered_map<std::string, std::weak_ptr<TritonRepoAgent>> agent_map_;
};

}}  // namespace triton::core
3rdparty/core-r22.12/src/response_allocator.h
0 → 100644
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "tritonserver_apis.h"
namespace triton { namespace core {

//
// Implementation for TRITONSERVER_ResponseAllocator.
//
class ResponseAllocator {
 public:
  explicit ResponseAllocator(
      TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn,
      TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn,
      TRITONSERVER_ResponseAllocatorStartFn_t start_fn)
      : alloc_fn_(alloc_fn), buffer_attributes_fn_(nullptr),
        query_fn_(nullptr), release_fn_(release_fn), start_fn_(start_fn)
  {
  }

  void SetQueryFunction(TRITONSERVER_ResponseAllocatorQueryFn_t query_fn)
  {
    query_fn_ = query_fn;
  }
  void SetBufferAttributesFunction(
      TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn)
  {
    buffer_attributes_fn_ = buffer_attributes_fn;
  }

  TRITONSERVER_ResponseAllocatorAllocFn_t AllocFn() const { return alloc_fn_; }
  TRITONSERVER_ResponseAllocatorBufferAttributesFn_t BufferAttributesFn() const
  {
    return buffer_attributes_fn_;
  }
  TRITONSERVER_ResponseAllocatorQueryFn_t QueryFn() const { return query_fn_; }
  TRITONSERVER_ResponseAllocatorReleaseFn_t ReleaseFn() const
  {
    return release_fn_;
  }
  TRITONSERVER_ResponseAllocatorStartFn_t StartFn() const { return start_fn_; }

 private:
  TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn_;
  TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn_;
  TRITONSERVER_ResponseAllocatorQueryFn_t query_fn_;
  TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn_;
  TRITONSERVER_ResponseAllocatorStartFn_t start_fn_;
};

}}  // namespace triton::core
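
// A minimal CPU-only pair of callbacks matching the stored function-pointer
// types, registered through the public C API rather than this internal
// class directly (a hedged sketch: signatures follow tritonserver.h, needs
// <cstdlib>, and error handling is reduced to the happy path):
//
//   TRITONSERVER_Error* CpuAlloc(
//       TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
//       size_t byte_size, TRITONSERVER_MemoryType memory_type,
//       int64_t memory_type_id, void* userp, void** buffer,
//       void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type,
//       int64_t* actual_memory_type_id)
//   {
//     *buffer = (byte_size == 0) ? nullptr : malloc(byte_size);
//     *buffer_userp = nullptr;
//     *actual_memory_type = TRITONSERVER_MEMORY_CPU;  // always answer CPU
//     *actual_memory_type_id = 0;
//     return nullptr;  // success
//   }
//
//   TRITONSERVER_Error* CpuRelease(
//       TRITONSERVER_ResponseAllocator* allocator, void* buffer,
//       void* buffer_userp, size_t byte_size,
//       TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)
//   {
//     free(buffer);
//     return nullptr;  // success
//   }
//
//   TRITONSERVER_ResponseAllocator* allocator = nullptr;
//   TRITONSERVER_ResponseAllocatorNew(
//       &allocator, CpuAlloc, CpuRelease, nullptr /* start_fn */);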