Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Lmdeploy
Commits
0a21fff9
Commit
0a21fff9
authored
Dec 20, 2023
by
xiabo
Browse files
Adapt to 0.1.0
parent
9484fd1c
Changes
158
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
7013 additions
and
0 deletions
+7013
-0
3rdparty/core-r22.12/src/model_lifecycle.cc
3rdparty/core-r22.12/src/model_lifecycle.cc
+740
-0
3rdparty/core-r22.12/src/model_lifecycle.h
3rdparty/core-r22.12/src/model_lifecycle.h
+324
-0
3rdparty/core-r22.12/src/model_repository_manager.cc
3rdparty/core-r22.12/src/model_repository_manager.cc
+1602
-0
3rdparty/core-r22.12/src/model_repository_manager.h
3rdparty/core-r22.12/src/model_repository_manager.h
+345
-0
3rdparty/core-r22.12/src/numa_utils.cc
3rdparty/core-r22.12/src/numa_utils.cc
+237
-0
3rdparty/core-r22.12/src/numa_utils.h
3rdparty/core-r22.12/src/numa_utils.h
+57
-0
3rdparty/core-r22.12/src/payload.cc
3rdparty/core-r22.12/src/payload.cc
+215
-0
3rdparty/core-r22.12/src/payload.h
3rdparty/core-r22.12/src/payload.h
+102
-0
3rdparty/core-r22.12/src/pinned_memory_manager.cc
3rdparty/core-r22.12/src/pinned_memory_manager.cc
+378
-0
3rdparty/core-r22.12/src/pinned_memory_manager.h
3rdparty/core-r22.12/src/pinned_memory_manager.h
+108
-0
3rdparty/core-r22.12/src/rate_limiter.cc
3rdparty/core-r22.12/src/rate_limiter.cc
+943
-0
3rdparty/core-r22.12/src/rate_limiter.h
3rdparty/core-r22.12/src/rate_limiter.h
+310
-0
3rdparty/core-r22.12/src/repo_agent.cc
3rdparty/core-r22.12/src/repo_agent.cc
+573
-0
3rdparty/core-r22.12/src/repo_agent.h
3rdparty/core-r22.12/src/repo_agent.h
+182
-0
3rdparty/core-r22.12/src/response_allocator.h
3rdparty/core-r22.12/src/response_allocator.h
+77
-0
3rdparty/core-r22.12/src/response_cache.cc
3rdparty/core-r22.12/src/response_cache.cc
+542
-0
3rdparty/core-r22.12/src/response_cache.h
3rdparty/core-r22.12/src/response_cache.h
+198
-0
3rdparty/core-r22.12/src/scheduler.h
3rdparty/core-r22.12/src/scheduler.h
+80
-0
No files found.
Too many changes to show.
To preserve performance only
158 of 158+
files are displayed.
Plain diff
Email patch
3rdparty/core-r22.12/src/model_lifecycle.cc
0 → 100644
View file @
0a21fff9
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_lifecycle.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "repo_agent.h"
#include "triton/common/logging.h"
#include "triton/common/thread_pool.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace
triton
{
namespace
core
{
// Return the human-readable name for a ModelReadyState value. Each name
// lives in a function-local static std::string so the returned reference
// remains valid for the lifetime of the process.
const std::string&
ModelReadyStateString(ModelReadyState state)
{
  static std::string unknown_name("UNKNOWN");
  static std::string ready_name("READY");
  static std::string unavailable_name("UNAVAILABLE");
  static std::string loading_name("LOADING");
  static std::string unloading_name("UNLOADING");
  static std::string unrecognized_name("<unknown>");

  switch (state) {
    case ModelReadyState::UNKNOWN:
      return unknown_name;
    case ModelReadyState::READY:
      return ready_name;
    case ModelReadyState::UNAVAILABLE:
      return unavailable_name;
    case ModelReadyState::LOADING:
      return loading_name;
    case ModelReadyState::UNLOADING:
      return unloading_name;
  }
  // Fall-through for enumerator values outside the defined set.
  return unrecognized_name;
}
namespace
{
// Determine which version directories under 'model_path' should be loaded
// for model 'name' according to 'model_config''s version policy.
//
// On success '*versions' holds the set of version numbers to load.
// Directories that are not valid version numbers (non-numeric, leading
// zeros, or out of int64_t range) are skipped with a warning; the special
// warmup / initial-state folders are ignored silently.
Status
VersionsToLoad(
    const std::string& model_path, const std::string& name,
    const inference::ModelConfig& model_config, std::set<int64_t>* versions)
{
  versions->clear();

  // Get integral number of the version directory
  std::set<std::string> subdirs;
  RETURN_IF_ERROR(GetDirectorySubdirs(model_path, &subdirs));
  // Sorted descending so the "latest" policy below can just iterate.
  std::set<int64_t, std::greater<int64_t>> existing_versions;
  for (const auto& subdir : subdirs) {
    if ((subdir == kWarmupDataFolder) || (subdir == kInitialStateFolder)) {
      continue;
    }
    if ((subdir.length() > 1) && (subdir.front() == '0')) {
      LOG_WARNING << "ignore version directory '" << subdir
                  << "' which contains leading zeros in its directory name";
      continue;
    }
    try {
      int64_t version = std::stoll(subdir);
      existing_versions.insert(version);
    }
    catch (const std::exception&) {
      // std::stoll throws std::invalid_argument for non-numeric names and
      // std::out_of_range for values exceeding int64_t; previously only
      // invalid_argument was caught, letting out_of_range escape. Both
      // simply mean the directory is not a usable version directory.
      LOG_WARNING << "ignore version directory '" << subdir
                  << "' which fails to convert to integral number";
    }
  }

  if (model_config.version_policy().has_specific()) {
    for (const auto& v : model_config.version_policy().specific().versions()) {
      // Only load the specific versions that are presented in model directory
      bool version_not_exist = existing_versions.insert(v).second;
      if (!version_not_exist) {
        versions->emplace(v);
      } else {
        LOG_ERROR << "version " << v << " is specified for model '" << name
                  << "', but the version directory is not present";
      }
    }
  } else {
    if (model_config.version_policy().has_latest()) {
      // std::set is sorted with std::greater
      for (const auto& v : existing_versions) {
        if (versions->size() >=
            model_config.version_policy().latest().num_versions()) {
          break;
        }
        versions->emplace(v);
      }
    } else {
      // all
      versions->insert(existing_versions.begin(), existing_versions.end());
    }
  }

  return Status::Success;
}
// Use smart pointer with custom deleter so that model state will be updated
// to UNAVAILABLE if all smart pointer copies are out of scope
struct ModelDeleter {
  // 'OnDestroyModel' is invoked after the Model object has been deleted,
  // letting ModelLifeCycle update its bookkeeping for this version.
  ModelDeleter(std::function<void()> OnDestroyModel)
      : OnDestroyModel_(std::move(OnDestroyModel))
  {
  }

  void operator()(Model* model)
  {
    // The actual model object must be destroyed in a different
    // thread. This thread could have a callstack that includes the
    // model itself because this deleter could be triggered by
    // a request release or response send in the model. Following
    // delete will lead to the model destructor which may wait on this
    // same thread... so deadlock if we don't use a different thread
    // here.
    //
    // Copy the callback so the detached thread owns its own
    // std::function even after this deleter object is gone.
    std::function<void()> destroy_fn = OnDestroyModel_;
    std::thread dthd([model, destroy_fn]() {
      delete model;
      destroy_fn();
    });
    dthd.detach();
  }

  // Use to inform the ModelLifeCycle that the model handle is destroyed
  std::function<void()> OnDestroyModel_;
};
}
// namespace
// Factory: construct a ModelLifeCycle for 'server' with the given options
// and hand ownership to the caller through '*life_cycle'.
Status
ModelLifeCycle::Create(
    InferenceServer* server, const ModelLifeCycleOptions& options,
    std::unique_ptr<ModelLifeCycle>* life_cycle)
{
  // Allocate directly into the caller-provided unique_ptr.
  life_cycle->reset(new ModelLifeCycle(server, options));
  return Status::Success;
}
// Return the (state, reason) map for every model that has at least one
// "live" version. A version is live when its state is neither UNKNOWN nor
// UNAVAILABLE. When 'strict_readiness' is true, only READY versions are
// reported (and only models with at least one READY version appear).
const ModelStateMap
ModelLifeCycle::LiveModelStates(bool strict_readiness)
{
  LOG_VERBOSE(2) << "LiveModelStates()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  ModelStateMap live_model_states;
  for (auto& model_version : map_) {
    bool live = false;
    VersionStateMap version_map;

    for (auto& version_model : model_version.second) {
      // Each version's state is guarded by its own mutex.
      std::lock_guard<std::mutex> lock(version_model.second->mtx_);
      if (strict_readiness &&
          version_model.second->state_ != ModelReadyState::READY) {
        continue;
      }

      // At least one version is live (ready / loading / unloading)
      if ((version_model.second->state_ != ModelReadyState::UNKNOWN) &&
          (version_model.second->state_ != ModelReadyState::UNAVAILABLE)) {
        live = true;
        version_map[version_model.first] = std::make_pair(
            version_model.second->state_,
            version_model.second->state_reason_);
      }
    }

    if (live) {
      live_model_states[model_version.first] = std::move(version_map);
    }
  }
  return live_model_states;
}
// Ask every loaded model version to stop accepting new work by calling
// Model::Stop() on each non-null model handle. Always returns Success.
Status
ModelLifeCycle::StopAllModels()
{
  LOG_VERBOSE(2) << "StopAllModels()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  for (auto& model_entry : map_) {
    for (auto& version_entry : model_entry.second) {
      auto& info = version_entry.second;
      if (info == nullptr) {
        continue;
      }
      // Per-version mutex guards the model handle.
      std::lock_guard<std::mutex> info_lock(info->mtx_);
      if (info->model_ != nullptr) {
        info->model_->Stop();
      }
    }
  }
  return Status::Success;
}
// Return the set of (model name, version, inflight count) tuples for every
// model version that currently has a non-zero number of inflight inference
// requests. Versions with no model handle or zero inflight count are
// omitted.
const std::set<std::tuple<std::string, int64_t, size_t>>
ModelLifeCycle::InflightStatus()
{
  LOG_VERBOSE(2) << "InflightStatus()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  std::set<std::tuple<std::string, int64_t, size_t>> inflight_status;
  for (auto& model_version : map_) {
    for (auto& version_model : model_version.second) {
      if (version_model.second != nullptr) {
        // Per-version mutex guards the model handle.
        std::lock_guard<std::mutex> lock(version_model.second->mtx_);
        if (version_model.second->model_ != nullptr) {
          const auto cnt =
              version_model.second->model_->InflightInferenceCount();
          if (cnt != 0) {
            inflight_status.emplace(
                model_version.first, version_model.first, cnt);
          }
        }
      }
    }
  }
  return inflight_status;
}
// Snapshot the (state, reason) of every known model version, regardless of
// whether the version is live. Unlike LiveModelStates(), no filtering is
// applied.
const ModelStateMap
ModelLifeCycle::ModelStates()
{
  LOG_VERBOSE(2) << "ModelStates()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  ModelStateMap model_states;
  for (auto& model_entry : map_) {
    VersionStateMap version_states;
    for (auto& version_entry : model_entry.second) {
      auto& info = version_entry.second;
      // Per-version mutex guards state_ / state_reason_.
      std::lock_guard<std::mutex> info_lock(info->mtx_);
      version_states[version_entry.first] =
          std::make_pair(info->state_, info->state_reason_);
    }
    model_states[model_entry.first] = std::move(version_states);
  }
  return model_states;
}
// Return the (state, reason) for every version of 'model_name'. Returns an
// empty map if the model is unknown.
const VersionStateMap
ModelLifeCycle::VersionStates(const std::string& model_name)
{
  LOG_VERBOSE(2) << "VersionStates() '" << model_name << "'";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  VersionStateMap version_map;
  auto mit = map_.find(model_name);
  if (mit != map_.end()) {
    for (auto& version_model : mit->second) {
      // Per-version mutex guards state_ / state_reason_.
      std::lock_guard<std::mutex> lock(version_model.second->mtx_);
      version_map[version_model.first] = std::make_pair(
          version_model.second->state_, version_model.second->state_reason_);
    }
  }
  return version_map;
}
// Look up the ready state of one specific (model_name, model_version) pair.
// On success writes the state into '*state'; returns NOT_FOUND if either
// the model or the version is unknown.
Status
ModelLifeCycle::ModelState(
    const std::string& model_name, const int64_t model_version,
    ModelReadyState* state)
{
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto mit = map_.find(model_name);
  if (mit != map_.end()) {
    auto vit = mit->second.find(model_version);
    if (vit != mit->second.end()) {
      std::lock_guard<std::mutex> lock(vit->second->mtx_);
      *state = vit->second->state_;
      return Status::Success;
    }
  }

  return Status(
      Status::Code::NOT_FOUND, "model '" + model_name + "', version " +
                                   std::to_string(model_version) +
                                   " is not found");
}
// Obtain a shared handle to a READY model version.
//
// 'version' == -1 means "latest": the highest-numbered READY version is
// selected. Returns NOT_FOUND if the model/version does not exist or no
// version is available, UNAVAILABLE if the requested version exists but is
// not READY. On success '*model' keeps the model alive via shared
// ownership even if the version is unloaded afterwards.
Status
ModelLifeCycle::GetModel(
    const std::string& model_name, const int64_t version,
    std::shared_ptr<Model>* model)
{
  LOG_VERBOSE(2) << "GetModel() '" << model_name << "' version " << version;
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto mit = map_.find(model_name);
  if (mit == map_.end()) {
    return Status(
        Status::Code::NOT_FOUND, "'" + model_name + "' is not found");
  }

  auto vit = mit->second.find(version);
  if (vit == mit->second.end()) {
    if (version != -1) {
      return Status(
          Status::Code::NOT_FOUND, "'" + model_name + "' version " +
                                       std::to_string(version) +
                                       " is not found");
    }
    // The case where the request is asking for latest version
    int64_t latest = -1;
    for (auto& version_model : mit->second) {
      if (version_model.first > latest) {
        std::lock_guard<std::mutex> lock(version_model.second->mtx_);
        if (version_model.second->state_ == ModelReadyState::READY) {
          latest = version_model.first;
          // Tedious, but have to set handle for any "latest" version
          // at the moment to avoid edge case like the following:
          // "versions : 1 3 2", version 3 is latest but is requested
          // to be unloaded when the iterator is examining version 2,
          // then 'model' will ensure version 3 is still valid
          *model = version_model.second->model_;
        }
      }
    }
    if (latest == -1) {
      return Status(
          Status::Code::NOT_FOUND,
          "'" + model_name + "' has no available versions");
    }
  } else {
    std::lock_guard<std::mutex> lock(vit->second->mtx_);
    if (vit->second->state_ == ModelReadyState::READY) {
      *model = vit->second->model_;
    } else {
      return Status(
          Status::Code::UNAVAILABLE, "'" + model_name + "' version " +
                                         std::to_string(version) +
                                         " is not at ready state");
    }
  }

  return Status::Success;
}
// Begin unloading every version of 'model_name'. READY versions are
// released (the actual teardown completes asynchronously via the model
// handle's deleter); versions still LOADING are aborted by bumping
// last_update_ns_, which the in-flight load recognizes as a newer
// operation. Returns INVALID_ARG if the model was never served.
Status
ModelLifeCycle::AsyncUnload(const std::string& model_name)
{
  LOG_VERBOSE(2) << "AsyncUnload() '" << model_name << "'";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto it = map_.find(model_name);
  if (it == map_.end()) {
    return Status(
        Status::Code::INVALID_ARG,
        "Model to be unloaded has not been served");
  }

  // Get the existing agent models and notify the unload action
  const uint64_t now_ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                              std::chrono::steady_clock::now()
                                  .time_since_epoch())
                              .count();
  for (auto& version : it->second) {
    auto& model_info = version.second;
    std::lock_guard<std::mutex> lock(model_info->mtx_);
    model_info->last_update_ns_ = now_ns;
    // Unload serving model, for model that is in LOADING state,
    // the updated timestamp will be recognized that there is newer update
    // on the model info and the load should be aborted
    if (model_info->state_ == ModelReadyState::READY) {
      if (model_info->agent_model_list_ != nullptr) {
        // Only log the error because the model should be unloaded regardless
        auto status = model_info->agent_model_list_->InvokeAgentModels(
            TRITONREPOAGENT_ACTION_UNLOAD);
        if (!status.IsOk()) {
          LOG_ERROR
              << "Agent model returns error on TRITONREPOAGENT_ACTION_UNLOAD: "
              << status.AsString();
        }
      }
      // unload
      model_info->Release();
    }
  }
  return Status::Success;
}
// Start loading the configured versions of 'model_name' from 'model_path'
// asynchronously on the load thread pool.
//
// For each version: a fresh ModelInfo is created in LOADING state. If the
// version is already being served (READY) the new info is parked in
// 'background_models_' so the serving version has no downtime; otherwise
// the new info replaces the map entry and any still-changing old info is
// parked in the background map to keep it alive for its in-flight
// operation. CreateModel()/OnLoadComplete() then run on the pool.
// 'OnComplete' (may be null) is invoked once all versions finish.
Status
ModelLifeCycle::AsyncLoad(
    const std::string& model_name, const std::string& model_path,
    const inference::ModelConfig& model_config, const bool is_config_provided,
    const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
    std::function<void(Status)>&& OnComplete)
{
  LOG_VERBOSE(2) << "AsyncLoad() '" << model_name << "'";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto it = map_.find(model_name);
  if (it == map_.end()) {
    it = map_.emplace(std::make_pair(model_name, VersionMap())).first;
  }

  std::set<int64_t> versions;
  RETURN_IF_ERROR(
      VersionsToLoad(model_path, model_name, model_config, &versions));
  if (versions.empty()) {
    return Status(
        Status::Code::INVALID_ARG,
        "at least one version must be available under the version policy of "
        "model '" +
            model_name + "'");
  }

  const uint64_t now_ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                              std::chrono::steady_clock::now()
                                  .time_since_epoch())
                              .count();
  // Tracks completion across all versions of this load request.
  std::shared_ptr<LoadTracker> load_tracker(
      new LoadTracker(versions.size(), now_ns));
  for (const auto& version : versions) {
    std::unique_ptr<ModelInfo> linfo(
        new ModelInfo(model_path, model_config, now_ns));
    ModelInfo* model_info = linfo.get();
    LOG_INFO << "loading: " << model_name << ":" << version;
    model_info->state_ = ModelReadyState::LOADING;
    model_info->state_reason_.clear();
    model_info->agent_model_list_ = agent_model_list;

    auto res = it->second.emplace(
        std::make_pair(version, std::unique_ptr<ModelInfo>()));
    if (res.second) {
      res.first->second = std::move(linfo);
    } else {
      // There is already a record of this model version. Check if the version
      // model is being served, if so, the re-load of the version
      // should be performed in background to avoid version downtime.
      // Otherwise, swap and monitor state for newly loading model.
      auto& serving_model = res.first->second;
      std::lock_guard<std::mutex> lock(serving_model->mtx_);
      if (serving_model->state_ == ModelReadyState::READY) {
        background_models_[(uintptr_t)model_info] = std::move(linfo);
      } else {
        // swap the monitoring model info
        serving_model.swap(linfo);
        // further check the state, put to 'background_models_' to keep
        // the object valid if the model is LOADING / UNLOADING, because
        // the model info will be accessed by a different thread once the
        // operation is completed
        if ((linfo->state_ == ModelReadyState::LOADING) ||
            (linfo->state_ == ModelReadyState::UNLOADING)) {
          ModelInfo* key = linfo.get();
          background_models_[(uintptr_t)key] = std::move(linfo);
        }
      }
    }

    // Load model asynchronously via thread pool
    load_pool_->Enqueue([this, model_name, version, model_info, OnComplete,
                         load_tracker, is_config_provided]() {
      CreateModel(model_name, version, model_info, is_config_provided);
      OnLoadComplete(model_name, version, model_info, OnComplete, load_tracker);
    });
  }

  return Status::Success;
}
// Construct the Model object for one (model_name, version) pair and store
// it into 'model_info->model_' with a ModelDeleter that marks the info
// UNAVAILABLE (and removes it from 'background_models_') once the last
// shared handle is released. On failure the info is marked UNAVAILABLE
// with the error as the state reason. Runs on a load-pool thread.
void
ModelLifeCycle::CreateModel(
    const std::string& model_name, const int64_t version,
    ModelInfo* model_info, const bool is_config_provided)
{
  LOG_VERBOSE(2) << "CreateModel() '" << model_name << "' version " << version;
  const auto& model_config = model_info->model_config_;

  // Create model
  Status status;
  std::unique_ptr<Model> is;

  // If 'backend' is specified in the config then use the new triton
  // backend.
  if (!model_config.backend().empty()) {
    std::unique_ptr<TritonModel> model;
    status = TritonModel::Create(
        server_, model_info->model_path_, cmdline_config_map_,
        host_policy_map_, model_name, version, model_config,
        is_config_provided, &model);
    is.reset(model.release());
  } else {
#ifdef TRITON_ENABLE_ENSEMBLE
    if (model_info->is_ensemble_) {
      status = EnsembleModel::Create(
          server_, model_info->model_path_, version, model_config,
          is_config_provided, min_compute_capability_, &is);
      // Complete label provider with label information from involved models
      // Must be done here because involved models may not be able to
      // obtained from server because this may happen during server
      // initialization.
      if (status.IsOk()) {
        std::set<std::string> no_label_outputs;
        const auto& label_provider = is->GetLabelProvider();
        for (const auto& output : model_config.output()) {
          if (label_provider->GetLabel(output.name(), 0).empty()) {
            no_label_outputs.emplace(output.name());
          }
        }
        for (const auto& element : model_config.ensemble_scheduling().step()) {
          for (const auto& pair : element.output_map()) {
            // Found model that produce one of the missing output
            if (no_label_outputs.find(pair.second) != no_label_outputs.end()) {
              std::shared_ptr<Model> model;
              // Safe to obtain model because the ensemble can't be loaded
              // until the involved models are ready
              GetModel(element.model_name(), element.model_version(), &model);
              label_provider->AddLabels(
                  pair.second,
                  model->GetLabelProvider()->GetLabels(pair.first));
            }
          }
        }
      }
    } else
#endif  // TRITON_ENABLE_ENSEMBLE
    {
      status = Status(
          Status::Code::INVALID_ARG,
          "unknown platform '" + model_config.platform() + "'");
    }
  }

  std::lock_guard<std::mutex> lock(model_info->mtx_);
  if (status.IsOk()) {
    // [FIXME] better way to manage agent model lifecycle
    // Let the deleter also holds a shared pointer copy of agent model list,
    // because the reference in ModelInfo can be cleared before the Model object
    // is destroyed, and we want agent model to be valid for receiving
    // UNLOAD_COMPLETE signal (see ~TritonRepoAgentModelList for detail)
    auto agent_model_list = model_info->agent_model_list_;
    model_info->model_.reset(
        is.release(),
        ModelDeleter([this, model_name, version, model_info,
                      agent_model_list]() mutable {
          LOG_VERBOSE(2) << "OnDestroy callback() '" << model_name
                         << "' version " << version;
          LOG_INFO << "successfully unloaded '" << model_name << "' version "
                   << version;
          // Update model state as it is fully unloaded
          {
            std::lock_guard<std::mutex> lock(model_info->mtx_);
            model_info->state_ = ModelReadyState::UNAVAILABLE;
            model_info->state_reason_ = "unloaded";
          }
          // Check if the model info is in background, if so, remove from the
          // map
          std::lock_guard<std::mutex> lk(this->map_mtx_);
          auto it = this->background_models_.find((uintptr_t)model_info);
          if (it != this->background_models_.end()) {
            this->background_models_.erase(it);
          }
        }));
  } else {
    LOG_ERROR << "failed to load '" << model_name << "' version " << version
              << ": " << status.AsString();
    model_info->state_ = ModelReadyState::UNAVAILABLE;
    model_info->state_reason_ = status.AsString();
  }
}
void
ModelLifeCycle
::
OnLoadComplete
(
const
std
::
string
&
model_name
,
const
int64_t
version
,
ModelInfo
*
model_info
,
std
::
function
<
void
(
Status
)
>
OnComplete
,
std
::
shared_ptr
<
LoadTracker
>
load_tracker
)
{
std
::
lock_guard
<
std
::
mutex
>
tracker_lock
(
load_tracker
->
mtx_
);
++
load_tracker
->
completed_version_cnt_
;
load_tracker
->
load_set_
[
version
]
=
model_info
;
// Version will not be marked ready until all versions are
// ready, this simplify the unloading when one version fails to load as
// all other versions won't have inflight requests
if
(
model_info
->
state_
!=
ModelReadyState
::
LOADING
)
{
load_tracker
->
load_failed_
=
true
;
load_tracker
->
reason_
+=
(
"version "
+
std
::
to_string
(
version
)
+
" is at "
+
ModelReadyStateString
(
model_info
->
state_
)
+
" state: "
+
model_info
->
state_reason_
+
";"
);
}
// Check if all versions are completed and finish the load
if
(
load_tracker
->
completed_version_cnt_
==
load_tracker
->
affected_version_cnt_
)
{
// hold 'map_mtx_' as there will be change onto the model info map
std
::
lock_guard
<
std
::
mutex
>
map_lock
(
map_mtx_
);
auto
it
=
map_
.
find
(
model_name
);
// Check if the load is the latest frontground action on the model
for
(
const
auto
&
version_info
:
it
->
second
)
{
if
(
version_info
.
second
->
last_update_ns_
>
load_tracker
->
last_update_ns_
)
{
load_tracker
->
load_failed_
=
true
;
load_tracker
->
reason_
=
"Newer operation has been applied to the model lifecycle, current "
"load operation is out-dated."
;
break
;
}
}
if
(
load_tracker
->
load_failed_
)
{
// Move agent list out of ModelInfo as it needs to be invoked
// after all ModelInfos are reset
std
::
shared_ptr
<
TritonRepoAgentModelList
>
lagent_list
;
if
(
model_info
->
agent_model_list_
)
{
lagent_list
=
std
::
move
(
model_info
->
agent_model_list_
);
}
// If any of the versions fails to load, abort the load and unload
// all newly loaded versions
for
(
auto
&
loaded
:
load_tracker
->
load_set_
)
{
// Unload directly, the object is being managed either in frontground
// or background
std
::
lock_guard
<
std
::
mutex
>
lock
(
loaded
.
second
->
mtx_
);
if
(
loaded
.
second
->
model_
!=
nullptr
)
{
loaded
.
second
->
Release
();
}
}
if
(
lagent_list
)
{
auto
status
=
lagent_list
->
InvokeAgentModels
(
TRITONREPOAGENT_ACTION_LOAD_FAIL
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Agent model returns error on "
"TRITONREPOAGENT_ACTION_LOAD_FAIL: "
<<
status
.
AsString
();
}
}
}
else
{
// Unload any previous loaded versions that are still available
for
(
auto
&
version_info
:
it
->
second
)
{
auto
&
mi
=
version_info
.
second
;
std
::
lock_guard
<
std
::
mutex
>
info_lk
(
mi
->
mtx_
);
if
((
mi
->
state_
==
ModelReadyState
::
READY
)
&&
(
mi
->
last_update_ns_
<
load_tracker
->
last_update_ns_
))
{
if
(
mi
->
agent_model_list_
!=
nullptr
)
{
auto
status
=
mi
->
agent_model_list_
->
InvokeAgentModels
(
TRITONREPOAGENT_ACTION_UNLOAD
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Agent model returns error on "
"TRITONREPOAGENT_ACTION_UNLOAD: "
<<
status
.
AsString
();
}
}
mi
->
Release
();
}
}
// Mark current versions ready and track info in foreground
for
(
auto
&
loaded
:
load_tracker
->
load_set_
)
{
std
::
lock_guard
<
std
::
mutex
>
curr_info_lk
(
loaded
.
second
->
mtx_
);
loaded
.
second
->
state_
=
ModelReadyState
::
READY
;
model_info
->
state_reason_
.
clear
();
LOG_INFO
<<
"successfully loaded '"
<<
model_name
<<
"' version "
<<
version
;
auto
bit
=
background_models_
.
find
((
uintptr_t
)
loaded
.
second
);
// Check if the version model is loaded in background, if so,
// replace and unload the current serving version
if
(
bit
!=
background_models_
.
end
())
{
auto
vit
=
it
->
second
.
find
(
loaded
.
first
);
// Need to lock the previous model info for in case the model is
// loading / unloading, this ensure the model state is consistent
// even when the load / unload is completed.
std
::
lock_guard
<
std
::
mutex
>
prev_info_lk
(
vit
->
second
->
mtx_
);
// swap previous info into local unique pointer
auto
linfo
=
std
::
move
(
bit
->
second
);
vit
->
second
.
swap
(
linfo
);
background_models_
.
erase
(
bit
);
// if previous info is under change, put into 'background_models_'
if
((
linfo
->
state_
==
ModelReadyState
::
LOADING
)
||
(
linfo
->
state_
==
ModelReadyState
::
UNLOADING
))
{
ModelInfo
*
key
=
linfo
.
get
();
background_models_
[(
uintptr_t
)
key
]
=
std
::
move
(
linfo
);
}
}
}
if
(
model_info
->
agent_model_list_
)
{
auto
status
=
model_info
->
agent_model_list_
->
InvokeAgentModels
(
TRITONREPOAGENT_ACTION_LOAD_COMPLETE
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Agent model returns error on "
"TRITONREPOAGENT_ACTION_LOAD_COMPLETE: "
<<
status
.
AsString
();
}
}
}
if
(
OnComplete
!=
nullptr
)
{
OnComplete
(
load_tracker
->
load_failed_
?
Status
(
Status
::
Code
::
INVALID_ARG
,
load_tracker
->
reason_
)
:
Status
::
Success
);
}
}
}
}}
// namespace triton::core
3rdparty/core-r22.12/src/model_lifecycle.h
0 → 100644
View file @
0a21fff9
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <functional>
#include <map>
#include <mutex>
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "repo_agent.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "triton/common/thread_pool.h"
namespace
triton
{
namespace
core
{
// Aggregates the configuration needed to construct a ModelLifeCycle.
// NOTE(review): the two map members are held by reference — the caller
// must keep the referenced maps alive for the lifetime of this options
// object (and of any ModelLifeCycle created from it); confirm at call
// sites.
struct ModelLifeCycleOptions {
  explicit ModelLifeCycleOptions(
      const double min_compute_capability,
      const triton::common::BackendCmdlineConfigMap&
          backend_cmdline_config_map,
      const triton::common::HostPolicyCmdlineConfigMap& host_policy_map,
      const unsigned int model_load_thread_count)
      : min_compute_capability_(min_compute_capability),
        backend_cmdline_config_map_(backend_cmdline_config_map),
        host_policy_map_(host_policy_map),
        model_load_thread_count_(model_load_thread_count)
  {
  }
  // The minimum supported CUDA compute capability.
  const double min_compute_capability_;
  // The backend configuration settings specified on the command-line
  const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map_;
  // The host policy setting used when loading models.
  const triton::common::HostPolicyCmdlineConfigMap& host_policy_map_;
  // Number of the threads to use for concurrently loading models
  const unsigned int model_load_thread_count_;
};
/// Readiness status for models.
enum
class
ModelReadyState
{
// The model is in an unknown state. The model is not available for
// inferencing.
UNKNOWN
,
// The model is ready and available for inferencing.
READY
,
// The model is unavailable, indicating that the model failed to
// load or has been implicitly or explicitly unloaded. The model is
// not available for inferencing.
UNAVAILABLE
,
// The model is being loaded by the inference server. The model is
// not available for inferencing.
LOADING
,
// The model is being unloaded by the inference server. The model is
// not available for inferencing.
UNLOADING
};
/// Get the string representation for a ModelReadyState
const std::string& ModelReadyStateString(ModelReadyState state);

// Maps a model version number to its (state, state reason) pair.
using VersionStateMap =
    std::map<int64_t, std::pair<ModelReadyState, std::string>>;
// Maps a model name to the per-version state map of that model.
using ModelStateMap = std::map<std::string, VersionStateMap>;
// Helper class to manage the lifecycle of a list of associated agent models
class TritonRepoAgentModelList {
 public:
  TritonRepoAgentModelList()
      : last_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE){};
  ~TritonRepoAgentModelList()
  {
    // Using destructor to finish the unload lifecycle without
    // explicitly managing the last step in ModelLifecycle.
    if (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD) {
      InvokeAgentModels(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE);
    }
  }

  // Append an agent model; the list takes ownership.
  Status AddAgentModel(std::unique_ptr<TritonRepoAgentModel>&& agent_model)
  {
    agent_models_.emplace_back(std::move(agent_model));
    return Status::Success;
  }

  size_t Size() { return agent_models_.size(); }

  TritonRepoAgentModel* Back() { return agent_models_.back().get(); }

  // Invoke every agent model with 'action_type'. LOAD/UNLOAD are delivered
  // in insertion order; the *_COMPLETE / *_FAIL actions are delivered in
  // reverse order. Stops and returns the first agent error encountered.
  Status InvokeAgentModels(const TRITONREPOAGENT_ActionType action_type)
  {
    // Special handling for the current model lifecycle implementation,
    // the repo agent may be asked to perform UNLOAD action multiple times,
    // and the requests after the first should be ignored.
    //
    // BUGFIX: the previous logic returned early for every action that was
    // not the *first* UNLOAD, so LOAD / LOAD_COMPLETE / LOAD_FAIL /
    // UNLOAD_COMPLETE never reached the agents. Only a *duplicate* UNLOAD
    // should be ignored.
    const bool duplicate_unload =
        (action_type == TRITONREPOAGENT_ACTION_UNLOAD) &&
        (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD);
    if (duplicate_unload) {
      return Status::Success;
    }
    last_action_type_ = action_type;
    switch (action_type) {
      case TRITONREPOAGENT_ACTION_LOAD:
      case TRITONREPOAGENT_ACTION_UNLOAD: {
        for (size_t idx = 0; idx < agent_models_.size(); ++idx) {
          RETURN_IF_ERROR(agent_models_[idx]->InvokeAgent(action_type));
        }
        break;
      }
      case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
      case TRITONREPOAGENT_ACTION_LOAD_FAIL:
      case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: {
        // reverse order
        for (size_t one_pass_idx = agent_models_.size(); one_pass_idx > 0;
             --one_pass_idx) {
          RETURN_IF_ERROR(
              agent_models_[one_pass_idx - 1]->InvokeAgent(action_type));
        }
        break;
      }
    }
    return Status::Success;
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModelList);
  std::vector<std::unique_ptr<TritonRepoAgentModel>> agent_models_;
  // Last action delivered; used to suppress duplicate UNLOADs and to emit
  // UNLOAD_COMPLETE from the destructor when needed.
  TRITONREPOAGENT_ActionType last_action_type_;
};
class
InferenceServer
;
class
Model
;
// Manages the load/unload state machine for every model version known to
// the server. Loads run asynchronously on a fixed-size thread pool.
class ModelLifeCycle {
 public:
  // Factory: create a ModelLifeCycle owned by 'life_cycle'.
  static Status Create(
      InferenceServer* server, const ModelLifeCycleOptions& options,
      std::unique_ptr<ModelLifeCycle>* life_cycle);

  ~ModelLifeCycle()
  {
    // Explicitly clean up thread pool first to clean up any pending callbacks
    // that may modify model lifecycle members
    load_pool_.reset();
    map_.clear();
  }

  // Start loading model with specified versions asynchronously.
  // All versions that are being served will be unloaded only after
  // the load is finished sucessfully.
  Status AsyncLoad(
      const std::string& model_name, const std::string& model_path,
      const inference::ModelConfig& model_config,
      const bool is_config_provided,
      const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
      std::function<void(Status)>&& OnComplete);

  // Unload model asynchronously.
  Status AsyncUnload(const std::string& model_name);

  // Get specified version of the model. Latest ready version will
  // be retrieved if 'version' is -1. Return error if the version specified is
  // not found or it is not ready.
  Status GetModel(
      const std::string& model_name, const int64_t version,
      std::shared_ptr<Model>* model);

  // Get the ModelStateMap representation of the live models. A model is
  // live if at least one of the versions is not unknown nor unavailable.
  // If 'strict_readiness' is true, a model is only live if
  // at least one of the versions is ready.
  const ModelStateMap LiveModelStates(bool strict_readiness = false);

  // Get the ModelStateMap representation of the models.
  const ModelStateMap ModelStates();

  // Get the VersionStateMap representation of the specified model.
  const VersionStateMap VersionStates(const std::string& model_name);

  // Get the state of a specific model version.
  Status ModelState(
      const std::string& model_name, const int64_t model_version,
      ModelReadyState* state);

  // Instruct the model to stop accepting new inference requests.
  Status StopAllModels();

  // Return the number of in-flight inference if any, model versions
  // that don't have in-flight inferences will not be included.
  const std::set<std::tuple<std::string, int64_t, size_t>> InflightStatus();

 private:
  // Per-(model, version) bookkeeping: config, path, ready state and the
  // "flyweight" handles to the loaded model and its repo-agent list.
  struct ModelInfo {
    ModelInfo(
        const std::string& model_path,
        const inference::ModelConfig& model_config,
        const uint64_t last_update_ns)
        : model_config_(model_config), model_path_(model_path),
#ifdef TRITON_ENABLE_ENSEMBLE
          is_ensemble_(model_config.platform() == kEnsemblePlatform),
#else
          is_ensemble_(false),
#endif  // TRITON_ENABLE_ENSEMBLE
          last_update_ns_(last_update_ns), state_(ModelReadyState::UNKNOWN)
    {
    }

    // Release the flyweight in ModelInfo object, reflect as 'UNLOADING' in
    // model state. Note that 'mtx_' should be acquired before invoking this
    // function to prevent possible data race.
    void Release()
    {
      state_ = ModelReadyState::UNLOADING;
      state_reason_.clear();
      agent_model_list_.reset();
      model_.reset();
    }

    const inference::ModelConfig model_config_;
    const std::string model_path_;
    const bool is_ensemble_;
    // Guards the mutable members below.
    std::mutex mtx_;
    uint64_t last_update_ns_;
    ModelReadyState state_;
    std::string state_reason_;
    // flyweight
    std::shared_ptr<TritonRepoAgentModelList> agent_model_list_;
    std::shared_ptr<Model> model_;
  };

  // Shared progress tracker for one AsyncLoad() request that may span
  // multiple versions; 'mtx_' guards the mutable counters.
  struct LoadTracker {
    LoadTracker(
        const size_t affected_version_cnt, const uint64_t last_update_ns)
        : last_update_ns_(last_update_ns),
          affected_version_cnt_(affected_version_cnt), load_failed_(false),
          completed_version_cnt_(0)
    {
    }
    const uint64_t last_update_ns_;
    const size_t affected_version_cnt_;
    std::mutex mtx_;
    bool load_failed_;
    std::string reason_;
    size_t completed_version_cnt_;
    std::map<int64_t, ModelInfo*> load_set_;
  };

  ModelLifeCycle(InferenceServer* server, const ModelLifeCycleOptions& options)
      : server_(server),
        min_compute_capability_(options.min_compute_capability_),
        cmdline_config_map_(options.backend_cmdline_config_map_),
        host_policy_map_(options.host_policy_map_)
  {
    // At least one loader thread even if the option is 0.
    load_pool_.reset(new triton::common::ThreadPool(
        std::max(1u, options.model_load_thread_count_)));
  }

  void CreateModel(
      const std::string& model_name, const int64_t version,
      ModelInfo* model_info, const bool is_config_provided);

  // Callback function template for model load.
  // 'OnComplete' needs to be passed by value for now as there can be
  // multiple versions to be loaded and each holds a copy of
  // the 'OnComplete' callback.
  void OnLoadComplete(
      const std::string& model_name, const int64_t version,
      ModelInfo* model_info, std::function<void(Status)> OnComplete,
      std::shared_ptr<LoadTracker> load_tracker);

  // Mutex for 'map_' and 'background_models_'
  std::mutex map_mtx_;

  using VersionMap = std::map<int64_t, std::unique_ptr<ModelInfo>>;
  using ModelMap = std::map<std::string, VersionMap>;
  ModelMap map_;
  // Models that are being loaded / unloaded in background
  std::map<uintptr_t, std::unique_ptr<ModelInfo>> background_models_;

  InferenceServer* server_;
  const double min_compute_capability_;
  const triton::common::BackendCmdlineConfigMap cmdline_config_map_;
  const triton::common::HostPolicyCmdlineConfigMap host_policy_map_;

  // Fixed-size thread pool to load models at specified concurrency
  std::unique_ptr<triton::common::ThreadPool> load_pool_;
};
}}
// namespace triton::core
3rdparty/core-r22.12/src/model_repository_manager.cc
0 → 100644
View file @
0a21fff9
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_repository_manager.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "ensemble_utils.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace
triton
{
namespace
core
{
namespace
{
// Prefix that marks a load parameter as a model-file override
// (e.g. "file:1/model.plan"). Read-only everywhere in this file,
// so declare it const to prevent accidental mutation.
static const std::string file_prefix = "file:";
// Internal repo agent used for model file override
class LocalizeRepoAgent : public TritonRepoAgent {
 public:
  LocalizeRepoAgent()
      : TritonRepoAgent("ModelRepositoryManager::LocalizeRepoAgent")
  {
    // Callbacks below interact with TritonRepoAgentModel directly knowing that
    // it is the internal implementation of TRITONREPOAGENT_AgentModel
    model_action_fn_ = [](TRITONREPOAGENT_Agent* agent,
                          TRITONREPOAGENT_AgentModel* model,
                          const TRITONREPOAGENT_ActionType action_type)
        -> TRITONSERVER_Error* {
      auto agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
      switch (action_type) {
        case TRITONREPOAGENT_ACTION_LOAD: {
          // localize the override files for model loading,
          // as currently the model is expected to load from local directory
          const char* temp_dir_cstr = nullptr;
          RETURN_TRITONSERVER_ERROR_IF_ERROR(
              agent_model->AcquireMutableLocation(
                  TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &temp_dir_cstr));
          const std::string temp_dir = temp_dir_cstr;
          // The agent state carries the override parameters supplied with
          // the load request; one "config" entry plus "file:..." entries.
          const auto& files =
              *reinterpret_cast<std::vector<const InferenceParameter*>*>(
                  agent_model->State());
          bool found_config = false;
          for (const auto& file : files) {
            if (file->Name() == "config") {
              if (file->Type() != TRITONSERVER_PARAMETER_STRING) {
                return TRITONSERVER_ErrorNew(
                    TRITONSERVER_ERROR_INVALID_ARG,
                    "Config parameter 'config' must have string type for its "
                    "value");
              }
              // Parse the JSON config and persist it as config.pbtxt in the
              // temporary (localized) directory.
              inference::ModelConfig config;
              RETURN_TRITONSERVER_ERROR_IF_ERROR(JsonToModelConfig(
                  file->ValueString(), 1 /* config_version */, &config));
              RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteTextProto(
                  JoinPath({temp_dir, kModelConfigPbTxt}), config));
              found_config = true;
            } else if (file->Name().rfind(file_prefix, 0) == 0) {
              if (file->Type() != TRITONSERVER_PARAMETER_BYTES) {
                return TRITONSERVER_ErrorNew(
                    TRITONSERVER_ERROR_INVALID_ARG,
                    (std::string("File parameter '") + file->Name() +
                     "' must have bytes type for its value")
                        .c_str());
              }
              // Save model file to the instructed directory
              // mkdir
              const std::string file_path = JoinPath(
                  {temp_dir, file->Name().substr(file_prefix.size())});
              const std::string dir = DirName(file_path);
              bool dir_exist = false;
              RETURN_TRITONSERVER_ERROR_IF_ERROR(FileExists(dir, &dir_exist));
              if (dir_exist) {
                bool is_dir = false;
                RETURN_TRITONSERVER_ERROR_IF_ERROR(IsDirectory(dir, &is_dir));
                if (!is_dir) {
                  return TRITONSERVER_ErrorNew(
                      TRITONSERVER_ERROR_INVALID_ARG,
                      (std::string("Invalid file parameter '") +
                       file->Name() +
                       "', directory has been created as a file")
                          .c_str());
                }
              } else {
                RETURN_TRITONSERVER_ERROR_IF_ERROR(
                    MakeDirectory(dir, true /* recursive */));
              }
              // write
              RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteBinaryFile(
                  file_path,
                  reinterpret_cast<const char*>(file->ValuePointer()),
                  file->ValueByteSize()));
            }
          }
          if (!found_config) {
            return TRITONSERVER_ErrorNew(
                TRITONSERVER_ERROR_INVALID_ARG,
                "Load parameter 'config' must be specified for model file "
                "override");
          }
          // Commit the temporary directory
          RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->SetLocation(
              TRITONREPOAGENT_ARTIFACT_FILESYSTEM, temp_dir_cstr));
          break;
        }
        default:
          break;
      }
      return nullptr;  // success
    };

    // Cleanup hook: remove the temporary localized directory when the
    // agent model is finalized.
    model_fini_fn_ = [](TRITONREPOAGENT_Agent* agent,
                        TRITONREPOAGENT_AgentModel* model)
        -> TRITONSERVER_Error* {
      auto agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
      RETURN_TRITONSERVER_ERROR_IF_ERROR(
          agent_model->DeleteMutableLocation());
      return nullptr;  // success
    };
  }
};
// Build (or extend) the chain of repo-agent models declared in the model
// config and invoke TRITONREPOAGENT_ACTION_LOAD on each as it is created.
// Each agent in the chain observes the location/config produced by the
// previous agent. No-op when the config declares no repository agents.
Status
CreateAgentModelListWithLoadAction(
    const inference::ModelConfig& original_model_config,
    const std::string& original_model_path,
    std::shared_ptr<TritonRepoAgentModelList>* agent_model_list)
{
  if (original_model_config.has_model_repository_agents()) {
    // Trick to append user specified repo agent on top of internal ones
    std::shared_ptr<TritonRepoAgentModelList> lagent_model_list;
    if (*agent_model_list != nullptr) {
      lagent_model_list = std::move(*agent_model_list);
    } else {
      lagent_model_list.reset(new TritonRepoAgentModelList());
    }

    // Remote artifacts are flagged so agents know the location is not a
    // local filesystem path.
    FileSystemType filesystem_type;
    RETURN_IF_ERROR(GetFileSystemType(original_model_path, &filesystem_type));
    TRITONREPOAGENT_ArtifactType artifact_type =
        TRITONREPOAGENT_ARTIFACT_FILESYSTEM;
    if (filesystem_type != FileSystemType::LOCAL) {
      artifact_type = TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM;
    }
    const char* location = original_model_path.c_str();
    inference::ModelConfig model_config = original_model_config;
    for (const auto& agent_config :
         original_model_config.model_repository_agents().agents()) {
      std::shared_ptr<TritonRepoAgent> agent;
      RETURN_IF_ERROR(
          TritonRepoAgentManager::CreateAgent(agent_config.name(), &agent));
      TritonRepoAgent::Parameters agent_params;
      for (const auto& parameter : agent_config.parameters()) {
        agent_params.emplace_back(parameter.first, parameter.second);
      }
      std::unique_ptr<TritonRepoAgentModel> agent_model;
      if (lagent_model_list->Size() != 0) {
        // Chain: subsequent agents start from the previous agent's output
        // location; re-read the (possibly rewritten) config from there.
        lagent_model_list->Back()->Location(&artifact_type, &location);
        const auto config_path = JoinPath({location, kModelConfigPbTxt});
        if (!ReadTextProto(config_path, &model_config).IsOk()) {
          model_config.Clear();
        }
      }
      RETURN_IF_ERROR(TritonRepoAgentModel::Create(
          artifact_type, location, model_config, agent, agent_params,
          &agent_model));
      RETURN_IF_ERROR(agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD));
      lagent_model_list->AddAgentModel(std::move(agent_model));
    }
    *agent_model_list = std::move(lagent_model_list);
  }
  return Status::Success;
}
int64_t
GetModifiedTime
(
const
std
::
string
&
path
)
{
// If there is an error in any step the fall-back default
// modification time is 0. This means that in error cases 'path'
// will show as not modified. This is the safe fall-back to avoid
// assuming a model is constantly being modified.
bool
path_is_dir
;
Status
status
=
IsDirectory
(
path
,
&
path_is_dir
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Failed to determine modification time for '"
<<
path
<<
"': "
<<
status
.
AsString
();
return
0
;
}
// If 'path' is a file return its mtime. Otherwise, using the modification
// time of the directory as baseline in case of file deletion
int64_t
mtime
=
0
;
status
=
FileModificationTime
(
path
,
&
mtime
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Failed to determine modification time for '"
<<
path
<<
"': "
<<
status
.
AsString
();
return
0
;
}
if
(
!
path_is_dir
)
{
return
mtime
;
}
// 'path' is a directory. Return the most recent mtime of the
// contents of the directory.
std
::
set
<
std
::
string
>
contents
;
status
=
GetDirectoryContents
(
path
,
&
contents
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Failed to determine modification time for '"
<<
path
<<
"': "
<<
status
.
AsString
();
return
0
;
}
for
(
const
auto
&
child
:
contents
)
{
const
auto
full_path
=
JoinPath
({
path
,
child
});
mtime
=
std
::
max
(
mtime
,
GetModifiedTime
(
full_path
));
}
return
mtime
;
}
// Return true if any file in the subdirectory root at 'path' has been
// modified more recently than 'last'. Return the most-recent modified
// time in 'last'.
bool
IsModified(const std::string& path, int64_t* last_ns)
{
  // Compute the current most-recent mtime, compare it against the cached
  // value, and refresh the cache unconditionally.
  const int64_t current_ns = GetModifiedTime(path);
  const bool newer = (current_ns > *last_ns);
  *last_ns = current_ns;
  return newer;
}
}
// namespace
// Per-model bookkeeping kept by the repository manager between polls:
// modification timestamps, the parsed config, and the on-disk path.
struct ModelRepositoryManager::ModelInfo {
  ModelInfo(
      const int64_t mtime_nsec, const int64_t prev_mtime_ns,
      const std::string& model_path)
      : mtime_nsec_(mtime_nsec), prev_mtime_ns_(prev_mtime_ns),
        explicitly_load_(true), model_path_(model_path),
        is_config_provided_(false)
  {
  }
  ModelInfo()
      : mtime_nsec_(0), prev_mtime_ns_(0), explicitly_load_(true),
        is_config_provided_(false)
  {
  }
  // Most recent observed modification time (ns) of the model directory.
  int64_t mtime_nsec_;
  // Previous value of 'mtime_nsec_'; restored on load failure so the next
  // poll retries the load (see LoadModelByDependency).
  int64_t prev_mtime_ns_;
  // Whether this model was requested directly (vs pulled in as an
  // ensemble dependency) — see LoadUnloadModels.
  bool explicitly_load_;
  inference::ModelConfig model_config_;
  std::string model_path_;
  // Temporary location to hold agent model list before creating the model
  // the ownership must transfer to ModelLifeCycle to ensure
  // the agent model life cycle is handled properly.
  std::shared_ptr<TritonRepoAgentModelList> agent_model_list_;
  // True when the config came from the load request rather than the
  // repository.
  bool is_config_provided_;
};
// Construct a manager over the given repository paths. 'autofill' is the
// inverse of strict model config (see Create); ownership of the
// ModelLifeCycle transfers to this object.
ModelRepositoryManager::ModelRepositoryManager(
    const std::set<std::string>& repository_paths, const bool autofill,
    const bool polling_enabled, const bool model_control_enabled,
    const double min_compute_capability,
    std::unique_ptr<ModelLifeCycle> life_cycle)
    : repository_paths_(repository_paths), autofill_(autofill),
      polling_enabled_(polling_enabled),
      model_control_enabled_(model_control_enabled),
      min_compute_capability_(min_compute_capability),
      model_life_cycle_(std::move(life_cycle))
{
}
// Defaulted out-of-line: members (including the owned ModelLifeCycle)
// clean themselves up.
ModelRepositoryManager::~ModelRepositoryManager() = default;
// Factory for ModelRepositoryManager: validates the repository paths and
// mode flags, creates the ModelLifeCycle, then performs the initial model
// load (all models, or only 'startup_models' in explicit control mode).
// Fails if any polled model does not reach the READY state.
Status
ModelRepositoryManager::Create(
    InferenceServer* server, const std::string& server_version,
    const std::set<std::string>& repository_paths,
    const std::set<std::string>& startup_models,
    const bool strict_model_config, const bool polling_enabled,
    const bool model_control_enabled,
    const ModelLifeCycleOptions& life_cycle_options,
    std::unique_ptr<ModelRepositoryManager>* model_repository_manager)
{
  // The rest only matters if repository path is valid directory
  for (const auto& path : repository_paths) {
    bool path_is_dir;
    RETURN_IF_ERROR(IsDirectory(path, &path_is_dir));
    if (!path_is_dir) {
      return Status(
          Status::Code::INVALID_ARG,
          "repository path is not a valid directory");
    }
  }

  // Polling and explicit control are mutually exclusive modes.
  if (polling_enabled && model_control_enabled) {
    return Status(
        Status::Code::INVALID_ARG,
        "cannot enable both polling and explicit model control");
  }

  std::unique_ptr<ModelLifeCycle> life_cycle;
  RETURN_IF_ERROR(
      ModelLifeCycle::Create(server, life_cycle_options, &life_cycle));

  // Not setting the smart pointer directly to simplify clean up
  std::unique_ptr<ModelRepositoryManager> local_manager(
      new ModelRepositoryManager(
          repository_paths, !strict_model_config, polling_enabled,
          model_control_enabled, life_cycle_options.min_compute_capability_,
          std::move(life_cycle)));
  *model_repository_manager = std::move(local_manager);

  // Support loading all models on startup in explicit model control mode with
  // special startup_model name "*". This does not imply support for pattern
  // matching in model names.
  bool load_all_models_on_startup = false;
  if ((startup_models.find("*") != startup_models.end()) &&
      model_control_enabled) {
    if (startup_models.size() > 1) {
      return Status(
          Status::Code::INVALID_ARG,
          "Wildcard model name '*' must be the ONLY startup model "
          "if specified at all.");
    }
    load_all_models_on_startup = true;
  }

  bool all_models_polled = true;
  if (!model_control_enabled || load_all_models_on_startup) {
    // only error happens before model load / unload will be return
    // model loading / unloading error will be printed but ignored
    RETURN_IF_ERROR(
        (*model_repository_manager)
            ->PollAndUpdateInternal(&all_models_polled));
  } else {
    // Load each specified startup_model
    std::unordered_map<std::string, std::vector<const InferenceParameter*>>
        models;
    for (const auto& model_name : startup_models) {
      models[model_name];
    }
    RETURN_IF_ERROR(
        (*model_repository_manager)
            ->LoadUnloadModels(
                models, ActionType::LOAD, false, &all_models_polled));
  }

  if (!all_models_polled) {
    return Status(Status::Code::INTERNAL, "failed to load all models");
  }
  // Some models may failed to be loaded after model manager is created,
  // return proper error and let function caller decide whether to proceed.
  for (const auto& model : (*model_repository_manager)->infos_) {
    const auto version_states =
        (*model_repository_manager)
            ->model_life_cycle_->VersionStates(model.first);
    // Return general error message, detail of each model's loading state
    // is logged separately.
    if (version_states.empty()) {
      return Status(Status::Code::INTERNAL, "failed to load all models");
    }
    for (const auto& state : version_states) {
      if (state.second.first != ModelReadyState::READY) {
        return Status(Status::Code::INTERNAL, "failed to load all models");
      }
    }
  }

  return Status::Success;
}
// Public entry for a repository poll; only valid when the manager was
// created with polling enabled. The poll result flag is not surfaced here.
Status
ModelRepositoryManager::PollAndUpdate()
{
  if (polling_enabled_) {
    bool all_models_polled;
    return PollAndUpdateInternal(&all_models_polled);
  }
  return Status(Status::Code::UNAVAILABLE, "polling is disabled");
}
// Poll every repository path, diff against the current 'infos_', unload
// deleted models and (re)load added/modified ones. Sets
// '*all_models_polled' false when some model directory could not be polled.
Status
ModelRepositoryManager::PollAndUpdateInternal(bool* all_models_polled)
{
  // Serialize all operations that change model state
  std::lock_guard<std::mutex> lock(poll_mu_);

  std::set<std::string> added, deleted, modified, unmodified;

  // We don't modify 'infos_' in place to minimize how long we need to
  // hold the lock and also prevent any partial changes to do an error
  // during processing.
  ModelInfoMap new_infos;

  // Each subdirectory of repository path is a model directory from
  // which we read the model configuration.
  std::unordered_map<std::string, std::vector<const InferenceParameter*>>
      subdirs;
  RETURN_IF_ERROR(Poll(
      subdirs, &added, &deleted, &modified, &unmodified, &new_infos,
      all_models_polled));

  // Anything in 'infos_' that is not in "added", "modified", or
  // "unmodified" is deleted.
  for (const auto& pr : infos_) {
    if ((added.find(pr.first) == added.end()) &&
        (modified.find(pr.first) == modified.end()) &&
        (unmodified.find(pr.first) == unmodified.end())) {
      deleted.insert(pr.first);
    }
  }

  // Nothing to do if no model adds, deletes or modifies.
  if (added.empty() && deleted.empty() && modified.empty()) {
    return Status::Success;
  }

  infos_.swap(new_infos);

  UpdateDependencyGraph(added, deleted, modified);

  for (const auto& name : deleted) {
    model_life_cycle_->AsyncUnload(name);
  }

  // model loading / unloading error will be printed but ignored
  LoadModelByDependency();

  return Status::Success;
}
// Load / unload models in dependency order: repeatedly ask the dependency
// graph for the next batch of (valid, invalid) nodes, unload the invalid
// ones, load the valid ones concurrently and wait for all loads in the
// batch before moving on. Returns the per-model load status.
std::map<std::string, Status>
ModelRepositoryManager::LoadModelByDependency()
{
  std::map<std::string, Status> res;
  // Per-load synchronization record; 'ready_' is signaled from the
  // AsyncLoad completion callback.
  struct ModelState {
    ModelState(DependencyNode* node) : node_(node), status_(Status::Success)
    {
    }
    DependencyNode* node_;
    Status status_;
    std::promise<void> ready_;
  };
  NodeSet loaded_models;
  auto set_pair = ModelsToLoadUnload(loaded_models);
  // Loop until all model are loaded / unloaded
  while ((!set_pair.first.empty()) || (!set_pair.second.empty())) {
    loaded_models.clear();
    // Unload invalid models first
    for (auto& invalid_model : set_pair.second) {
      model_life_cycle_->AsyncUnload(invalid_model->model_name_);
      LOG_ERROR << invalid_model->status_.AsString();
      invalid_model->loaded_versions_ = std::set<int64_t>();
      loaded_models.emplace(invalid_model);
    }
    // load valid models and wait for load results
    std::vector<std::unique_ptr<ModelState>> model_states;
    for (auto& valid_model : set_pair.first) {
      model_states.emplace_back(new ModelState(valid_model));
      auto model_state = model_states.back().get();
      const auto itr = infos_.find(valid_model->model_name_);
      auto status = model_life_cycle_->AsyncLoad(
          valid_model->model_name_, itr->second->model_path_,
          valid_model->model_config_, itr->second->is_config_provided_,
          itr->second->agent_model_list_, [model_state](Status load_status) {
            model_state->status_ = load_status;
            model_state->ready_.set_value();
          });
      if (!status.IsOk()) {
        // AsyncLoad itself failed; fulfil the promise here since the
        // completion callback will never run.
        model_state->status_ = status;
        model_state->ready_.set_value();
        LOG_ERROR << "failed to load model '" << valid_model->model_name_
                  << "': " << status.Message();
      }
      loaded_models.emplace(valid_model);
    }
    for (auto& model_state : model_states) {
      model_state->ready_.get_future().wait();
      res[model_state->node_->model_name_] = model_state->status_;
      const auto version_state =
          model_life_cycle_->VersionStates(model_state->node_->model_name_);
      model_state->node_->loaded_versions_.clear();
      for (const auto& vs : version_state) {
        if (vs.second.first == ModelReadyState::READY) {
          model_state->node_->loaded_versions_.emplace(vs.first);
        }
      }
      // If the model failed to load, should revert the timestamp to
      // ensure the next load request will attempt to load the model again
      // for operation consistency.
      if (!model_state->status_.IsOk()) {
        auto& model_info =
            infos_.find(model_state->node_->model_name_)->second;
        model_info->mtime_nsec_ = model_info->prev_mtime_ns_;
      }
    }
    set_pair = ModelsToLoadUnload(loaded_models);
  }
  // Clear temporary stored agent model list after all loads are triggerred
  for (auto& info : infos_) {
    info.second->agent_model_list_.reset();
  }
  return res;
}
// Explicit single-model load / unload (model control mode only). After the
// operation, verifies the outcome: a LOAD must leave at least one version
// available, an UNLOAD must leave no READY versions.
Status
ModelRepositoryManager::LoadUnloadModel(
    const std::unordered_map<
        std::string, std::vector<const InferenceParameter*>>& models,
    const ActionType type, const bool unload_dependents)
{
  if (!model_control_enabled_) {
    return Status(
        Status::Code::UNAVAILABLE,
        "explicit model load / unload is not allowed if polling is enabled");
  }
  if (models.size() > 1) {
    return Status(
        Status::Code::UNSUPPORTED,
        "explicit load / unload multiple models is not currently supported");
  }
  // Serialize all operations that change model state
  std::lock_guard<std::mutex> lock(poll_mu_);

  bool polled = true;
  RETURN_IF_ERROR(LoadUnloadModels(models, type, unload_dependents, &polled));
  // Check if model is loaded / unloaded properly
  const auto& model_name = models.begin()->first;
  if (!polled) {
    return Status(
        Status::Code::INTERNAL, "failed to load '" + model_name +
                                    "', failed to poll from model repository");
  }

  const auto version_states = model_life_cycle_->VersionStates(model_name);
  if (type == ActionType::LOAD) {
    if (version_states.empty()) {
      return Status(
          Status::Code::INTERNAL,
          "failed to load '" + model_name + "', no version is available");
    }
    auto it = infos_.find(model_name);
    if (it == infos_.end()) {
      return Status(
          Status::Code::INTERNAL,
          "failed to load '" + model_name +
              "', failed to poll from model repository");
    }
  } else {
    // UNLOAD: collect versions that are still READY; any remaining READY
    // version means the unload did not fully take effect.
    std::string ready_version_str;
    for (const auto& version_state : version_states) {
      if (version_state.second.first == ModelReadyState::READY) {
        ready_version_str += std::to_string(version_state.first);
        ready_version_str += ",";
      }
    }
    if (!ready_version_str.empty()) {
      ready_version_str.pop_back();
      return Status(
          Status::Code::INTERNAL,
          "failed to unload '" + model_name +
              "', versions that are still available: " + ready_version_str);
    }
  }

  return Status::Success;
}
// Shared implementation for explicit load / unload of a set of models.
// For LOAD, polls the requested models (and, when ensembles are enabled,
// transitively polls the models their steps reference), updates 'infos_',
// refreshes the dependency graph, then triggers the actual loads. For
// UNLOAD, marks the models deleted. '*all_models_polled' reports whether
// every required model directory could be polled.
Status
ModelRepositoryManager::LoadUnloadModels(
    const std::unordered_map<
        std::string, std::vector<const InferenceParameter*>>& models,
    const ActionType type, const bool unload_dependents,
    bool* all_models_polled)
{
  auto status = Status::Success;
  *all_models_polled = true;
  // Update ModelInfo related to file system accordingly
  std::set<std::string> added, deleted, modified, unmodified;
  {
    if (type == ActionType::UNLOAD) {
      for (const auto& model : models) {
        deleted.insert(model.first);
      }
    }
    // ActionType::LOAD and in model control mode
    else {
      // 'checked_models' prevents polling the same model twice while
      // chasing ensemble dependencies.
      std::set<std::string> checked_models;
      auto current_models = models;
      for (const auto& model : models) {
        checked_models.emplace(model.first);
      }

      ModelInfoMap new_infos;
#ifdef TRITON_ENABLE_ENSEMBLE
      bool first_iteration = true;
#endif  // TRITON_ENABLE_ENSEMBLE
      while (!current_models.empty()) {
        bool polled = true;
        RETURN_IF_ERROR(Poll(
            current_models, &added, &deleted, &modified, &unmodified,
            &new_infos, &polled));
        *all_models_polled &= polled;
        // More models should be polled if the polled models are ensembles
        std::unordered_map<
            std::string, std::vector<const InferenceParameter*>>
            next_models;
#ifdef TRITON_ENABLE_ENSEMBLE
        for (const auto& model : current_models) {
          auto it = new_infos.find(model.first);
          // Some models may be marked as deleted and not in 'new_infos'
          if (it != new_infos.end()) {
            // Only the directly requested models (first iteration) count
            // as explicitly loaded; dependencies do not.
            it->second->explicitly_load_ = first_iteration;
            const auto& config = it->second->model_config_;
            if (config.has_ensemble_scheduling()) {
              for (const auto& step : config.ensemble_scheduling().step()) {
                bool need_poll =
                    checked_models.emplace(step.model_name()).second;
                if (need_poll) {
                  next_models[step.model_name()];
                }
              }
            }
          }
        }
        first_iteration = false;
#endif  // TRITON_ENABLE_ENSEMBLE
        current_models.swap(next_models);
      }

      // Only update the infos when all validation is completed
      for (const auto& model_name : added) {
        auto nitr = new_infos.find(model_name);
        infos_.emplace(model_name, std::move(nitr->second));
      }
      for (const auto& model_name : modified) {
        auto nitr = new_infos.find(model_name);
        auto itr = infos_.find(model_name);
        itr->second = std::move(nitr->second);
      }
    }
  }
  std::set<std::string> deleted_dependents;
  // Update dependency graph and load
  UpdateDependencyGraph(
      added, deleted, modified,
      unload_dependents ? &deleted_dependents : nullptr);
  // The models are in 'deleted' either when they are asked to be unloaded or
  // they are not found / are duplicated across all model repositories.
  // In all cases, should unload them and remove from 'infos_' explicitly.
  for (const auto& name :
       (unload_dependents ? deleted_dependents : deleted)) {
    infos_.erase(name);
    model_life_cycle_->AsyncUnload(name);
  }

  // load / unload the models affected, and check the load status of
  // the requested models
  const auto& load_status = LoadModelByDependency();
  if (status.IsOk() && (type == ActionType::LOAD)) {
    std::string load_error_message = "";
    for (const auto& model : models) {
      auto it = load_status.find(model.first);
      // If 'model.first' not in load status, it means the (re-)load is not
      // necessary because there is no change in the model's directory
      if ((it != load_status.end()) && !it->second.IsOk()) {
        load_error_message +=
            ("load failed for model '" + model.first +
             "': " + it->second.Message() + "\n");
      }
    }
    if (!load_error_message.empty()) {
      status = Status(Status::Code::INVALID_ARG, load_error_message);
    }
  }

  return status;
}
// Request an asynchronous unload of every model currently tracked in
// 'infos_'.
//
// \return The last unload error encountered while issuing the requests,
// or Status::Success when every AsyncUnload request was accepted.
Status
ModelRepositoryManager::UnloadAllModels()
{
  Status status;
  for (const auto& name_info : infos_) {
    Status unload_status = model_life_cycle_->AsyncUnload(name_info.first);
    if (!unload_status.IsOk()) {
      // Remember the failure but keep issuing unloads for the remaining
      // models; a single bad model must not block the others.
      status = Status(
          unload_status.ErrorCode(),
          "Failed to gracefully unload models: " + unload_status.Message());
    }
  }
  // BUGFIX: previously returned Status::Success unconditionally, silently
  // discarding the accumulated unload error built above.
  return status;
}
// Ask the lifecycle manager to stop accepting new work for all models;
// the lifecycle's status is forwarded to the caller unchanged.
Status
ModelRepositoryManager::StopAllModels()
{
  const Status stop_status = model_life_cycle_->StopAllModels();
  return stop_status;
}
// Report the in-flight inference bookkeeping owned by the lifecycle
// manager as a set of (model name, version, request count) tuples.
const std::set<std::tuple<std::string, int64_t, size_t>>
ModelRepositoryManager::InflightStatus()
{
  auto inflight = model_life_cycle_->InflightStatus();
  return inflight;
}
// Return the state map of live models, delegating to the lifecycle
// manager. 'strict_readiness' is forwarded as-is.
const ModelStateMap
ModelRepositoryManager::LiveModelStates(bool strict_readiness)
{
  auto live_states = model_life_cycle_->LiveModelStates(strict_readiness);
  return live_states;
}
// Return the state map for all models known to the lifecycle manager.
const ModelStateMap
ModelRepositoryManager::ModelStates()
{
  auto all_states = model_life_cycle_->ModelStates();
  return all_states;
}
// Return the per-version state map of a single model, delegating to the
// lifecycle manager.
const VersionStateMap
ModelRepositoryManager::VersionStates(const std::string& model_name)
{
  auto version_states = model_life_cycle_->VersionStates(model_name);
  return version_states;
}
// Query the readiness state of one (model, version) pair; 'state' is an
// output written by the lifecycle manager on success.
Status
ModelRepositoryManager::ModelState(
    const std::string& model_name, const int64_t model_version,
    ModelReadyState* state)
{
  const Status query_status =
      model_life_cycle_->ModelState(model_name, model_version, state);
  return query_status;
}
// Build an index of all models visible across the registered repositories.
// A model that appears in more than one repository is reported as
// UNAVAILABLE with a duplicate reason, since duplicates are not loadable.
// When 'ready_only' is true, only versions in READY state are included.
Status
ModelRepositoryManager::RepositoryIndex(
    const bool ready_only, std::vector<ModelIndex>* index)
{
  std::set<std::string> seen_models;
  std::set<std::string> duplicate_models;
  for (const auto& repository_path : repository_paths_) {
    // For any mapped models in this repository, save the mapping
    // from their subdirectory name to model name.
    std::map<std::string, std::string> models_in_repo;
    for (const auto& mapping_it : model_mappings_) {
      if (mapping_it.second.first == repository_path) {
        models_in_repo.emplace(
            BaseName(mapping_it.second.second), mapping_it.first);
      }
    }
    std::set<std::string> subdirs;
    RETURN_IF_ERROR(GetDirectorySubdirs(repository_path, &subdirs));
    for (const auto& subdir : subdirs) {
      // The effective model name is the mapped name when a mapping exists,
      // otherwise the subdirectory name itself.
      auto model = subdir;
      auto model_it = models_in_repo.find(subdir);
      if (model_it != models_in_repo.end()) {
        model = model_it->second;
      }
      // A second sighting of the same name marks it as a duplicate.
      if (seen_models.find(model) != seen_models.end()) {
        duplicate_models.insert(model);
      }
      seen_models.insert(model);
    }
  }
  ModelStateMap states = ModelStates();
  for (const auto& model : seen_models) {
    // If the same model appears in multiple repostories then show it
    // as unavailable since duplicate models are not allowed to load.
    if (duplicate_models.find(model) != duplicate_models.end()) {
      index->emplace_back(
          model, -1 /* version */, ModelReadyState::UNAVAILABLE,
          MODEL_READY_REASON_DUPLICATE);
      continue;
    }
    // If there is any version/state/reason associated with the model
    // then include that in the index.
    auto sitr = states.find(model);
    if (sitr == states.end()) {
      // No lifecycle state: emit a name-only entry unless the caller asked
      // for ready models exclusively.
      if (!ready_only) {
        index->emplace_back(model);
      }
    } else {
      for (const auto& pr : sitr->second) {
        if (!ready_only || (pr.second.first == ModelReadyState::READY)) {
          index->emplace_back(model, pr.first, pr.second.first, pr.second.second);
        }
      }
    }
  }
  return Status::Success;
}
// Fetch a shared handle to the requested (model, version). On failure the
// output handle is cleared and the lifecycle error is re-wrapped so callers
// see a consistent "unknown model" message with the original error code.
Status
ModelRepositoryManager::GetModel(
    const std::string& model_name, const int64_t model_version,
    std::shared_ptr<Model>* model)
{
  const Status lookup =
      model_life_cycle_->GetModel(model_name, model_version, model);
  if (lookup.IsOk()) {
    return lookup;
  }
  model->reset();
  return Status(
      lookup.ErrorCode(), "Request for unknown model: " + lookup.Message());
}
// Poll the model repositories (or only the explicitly requested 'models')
// and classify every discoverable model into 'added', 'deleted',
// 'modified', or 'unmodified' relative to the current 'infos_' snapshot,
// filling 'updated_infos' with the refreshed ModelInfo objects.
// 'all_models_polled' is cleared whenever any repository or model could
// not be inspected, so the caller knows the result may be partial.
Status
ModelRepositoryManager::Poll(
    const std::unordered_map<std::string, std::vector<const InferenceParameter*>>&
        models,
    std::set<std::string>* added, std::set<std::string>* deleted,
    std::set<std::string>* modified, std::set<std::string>* unmodified,
    ModelInfoMap* updated_infos, bool* all_models_polled)
{
  *all_models_polled = true;
  // empty path is the special case to indicate the model should be loaded
  // from override file content in 'models'.
  std::map<std::string, std::string> model_to_path;
  // If no model is specified, poll all models in all model repositories.
  // Otherwise, only poll the specified models
  if (models.empty()) {
    std::set<std::string> duplicated_models;
    for (const auto& repository_path : repository_paths_) {
      std::set<std::string> subdirs;
      Status status = GetDirectorySubdirs(repository_path, &subdirs);
      if (!status.IsOk()) {
        LOG_ERROR << "failed to poll model repository '" << repository_path
                  << "': " << status.Message();
        *all_models_polled = false;
      } else {
        for (const auto& subdir : subdirs) {
          // emplace fails when another repository already claimed this
          // subdirectory name -> the model is duplicated.
          if (!model_to_path.emplace(subdir, JoinPath({repository_path, subdir}))
                   .second) {
            duplicated_models.insert(subdir);
            *all_models_polled = false;
          }
        }
      }
    }
    // If the model is not unique, mark as deleted to unload it
    for (const auto& model : duplicated_models) {
      model_to_path.erase(model);
      deleted->insert(model);
      LOG_ERROR << "failed to poll model '" << model
                << "': not unique across all model repositories";
    }
  }
  // If models are specified, this is explicit model control mode.
  else {
    for (const auto& model : models) {
      // Skip repository polling if override model files
      if (ModelDirectoryOverride(model.second)) {
        // Empty path signals "load from override content" downstream.
        model_to_path.emplace(model.first, "");
        continue;
      }
      // Check model mapping first to see if matching model to load.
      bool exists = false;
      auto model_it = model_mappings_.find(model.first);
      if (model_it != model_mappings_.end()) {
        bool exists_in_this_repo = false;
        auto full_path = model_it->second.second;
        Status status = FileExists(full_path, &exists_in_this_repo);
        if (!status.IsOk()) {
          LOG_ERROR << "failed to poll mapped path '" << full_path
                    << "' for model '" << model.first
                    << "': " << status.Message();
          *all_models_polled = false;
        }
        if (exists_in_this_repo) {
          model_to_path.emplace(model.first, model_it->second.second);
          exists = true;
        } else {
          LOG_ERROR << "mapped path '" << full_path
                    << "' does not exist for model '" << model.first << "'";
          exists = false;
        }
      } else {
        // No mapping: search every registered repository for a directory
        // with this model's name.
        // NOTE(review): 'const auto repository_path' copies the string each
        // iteration; harmless but could be 'const auto&'.
        for (const auto repository_path : repository_paths_) {
          bool exists_in_this_repo = false;
          const auto full_path = JoinPath({repository_path, model.first});
          Status status = FileExists(full_path, &exists_in_this_repo);
          if (!status.IsOk()) {
            LOG_ERROR << "failed to poll model repository '" << repository_path
                      << "' for model '" << model.first
                      << "': " << status.Message();
            *all_models_polled = false;
          } else if (exists_in_this_repo) {
            // Check to make sure this directory is not mapped.
            // If mapped, continue to next repository path.
            bool mapped = false;
            for (auto const& mapping : model_mappings_) {
              if (mapping.second.second == full_path) {
                mapped = true;
                break;
              }
            }
            if (mapped) {
              continue;
            }
            auto res = model_to_path.emplace(
                model.first, JoinPath({repository_path, model.first}));
            if (res.second) {
              exists = true;
            } else {
              // Found in a second repository as well: not unique, so drop
              // the earlier entry and stop searching.
              exists = false;
              model_to_path.erase(res.first);
              LOG_ERROR << "failed to poll model '" << model.first
                        << "': not unique across all model repositories";
              break;
            }
          }
        }
      }
      // For an explicitly specified model that doesn't exist, we don't mark it
      // as deleted, we simply mark that we couldn't poll all models.
      if (!exists) {
        *all_models_polled = false;
      }
    }
  }
  // Poll each of the models. If error happens during polling the model,
  // its state will fallback to the state before the polling.
  for (const auto& pair : model_to_path) {
    std::unique_ptr<ModelInfo> model_info;
    const auto& mit = models.find(pair.first);
    // Shared empty parameter list for models polled without load params.
    static std::vector<const InferenceParameter*> empty_params;
    auto status = InitializeModelInfo(
        pair.first, pair.second,
        ((mit == models.end()) ? empty_params : mit->second), &model_info);
    const auto& iitr = infos_.find(pair.first);
    // A failed init for a model we have never seen before is not recorded.
    const bool invalid_add = (!status.IsOk()) && (iitr == infos_.end());
    if (!invalid_add) {
      const auto& ret = updated_infos->emplace(pair.first, nullptr);
      if (!ret.second) {
        return Status(
            Status::Code::ALREADY_EXISTS,
            "unexpected model info for model '" + pair.first + "'");
      }
      // Classify load state and set updated info
      if (model_info == nullptr) {
        // nullptr info means "unchanged": keep a copy of the prior info.
        ret.first->second.reset(new ModelInfo(*iitr->second));
        unmodified->insert(pair.first);
      } else {
        ret.first->second = std::move(model_info);
        if (iitr != infos_.end()) {
          modified->insert(pair.first);
        } else {
          added->insert(pair.first);
        }
      }
    }
    if (!status.IsOk()) {
      LOG_ERROR << "Poll failed for model directory '" << pair.first
                << "': " << status.Message();
      *all_models_polled = false;
    }
  }
  return Status::Success;
}
// Return true when any load parameter's name begins with the override-file
// prefix, i.e. the caller supplied model files inline rather than on disk.
bool
ModelRepositoryManager::ModelDirectoryOverride(
    const std::vector<const InferenceParameter*>& model_params)
{
  bool has_override = false;
  for (const auto* param : model_params) {
    const bool starts_with_prefix = (param->Name().rfind(file_prefix, 0) == 0);
    if (starts_with_prefix) {
      has_override = true;
      break;
    }
  }
  return has_override;
}
// Build a fresh ModelInfo for model 'name' located at 'path' (empty path
// means the files come from the override content in 'params'). On success
// '*info' holds the new info; '*info' is left as nullptr-equivalent (the
// function returns Success without assigning) when the model is detected
// as unmodified, which callers use as the "no reload needed" signal.
Status
ModelRepositoryManager::InitializeModelInfo(
    const std::string& name, const std::string& path,
    const std::vector<const InferenceParameter*>& params,
    std::unique_ptr<ModelInfo>* info)
{
  std::unique_ptr<ModelInfo> linfo(new ModelInfo());
  linfo->model_path_ = path;
  bool unmodified = false;
  const auto iitr = infos_.find(name);
  // Set 'prev_mtime_ns_' if there is existing ModelInfo
  if (iitr != infos_.end()) {
    linfo->prev_mtime_ns_ = iitr->second->mtime_nsec_;
  } else {
    linfo->prev_mtime_ns_ = 0;
  }
  // Set 'mtime_nsec_' and override 'model_path_' if current path is empty
  // (file override is specified)
  if (linfo->model_path_.empty()) {
    // Need to localize the override files, use repo agent to manage
    // the lifecycle of the localized files
    std::shared_ptr<TritonRepoAgent> localize_agent(new LocalizeRepoAgent());
    std::unique_ptr<TritonRepoAgentModel> localize_agent_model;
    RETURN_IF_ERROR(TritonRepoAgentModel::Create(
        TRITONREPOAGENT_ARTIFACT_FILESYSTEM, "", inference::ModelConfig(),
        localize_agent, {}, &localize_agent_model));
    // Set agent model state so the repo agent can access the encoded files
    // Using const_cast here but we are safe as the RepoAgent will not
    // modify the state
    localize_agent_model->SetState(
        const_cast<void*>(reinterpret_cast<const void*>(&params)));
    RETURN_IF_ERROR(
        localize_agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD));
    const char* location;
    TRITONREPOAGENT_ArtifactType type;
    RETURN_IF_ERROR(localize_agent_model->Location(&type, &location));
    // For file override, set 'mtime_nsec_' to minimum value so that
    // the next load without override will trigger re-load to undo
    // the override while the local files may still be unchanged.
    linfo->mtime_nsec_ = 0;
    linfo->model_path_ = location;
    linfo->agent_model_list_.reset(new TritonRepoAgentModelList());
    linfo->agent_model_list_->AddAgentModel(std::move(localize_agent_model));
  } else {
    if (iitr == infos_.end()) {
      // First sighting of this model: record its current timestamp.
      linfo->mtime_nsec_ = GetModifiedTime(std::string(linfo->model_path_));
    } else {
      // Check the current timestamps to determine if model actually has been
      // modified
      linfo->mtime_nsec_ = linfo->prev_mtime_ns_;
      unmodified =
          !IsModified(std::string(linfo->model_path_), &linfo->mtime_nsec_);
    }
  }
  // Set 'model_config_'
  bool parsed_config = false;
  // Check if there is config override
  for (const auto& override_parameter : params) {
    if ((override_parameter->Name() == "config") &&
        (override_parameter->Type() == TRITONSERVER_PARAMETER_STRING)) {
      // When override happens, set 'mtime_nsec_' to minimum value so that
      // the next load without override will trigger re-load to undo
      // the override while the local files may still be unchanged.
      linfo->mtime_nsec_ = 0;
      unmodified = false;
      const std::string& override_config = override_parameter->ValueString();
      auto err = JsonToModelConfig(
          override_config, 1 /* config_version */, &linfo->model_config_);
      if (!err.IsOk()) {
        return Status(
            Status::Code::INVALID_ARG,
            "Invalid config override: " + std::string(err.Message()));
      }
      parsed_config = true;
      break;
    } else if (override_parameter->Name().rfind(file_prefix, 0) != 0) {
      // Anything that is neither a config override nor an override file
      // is an unknown load parameter.
      return Status(
          Status::Code::INVALID_ARG,
          "Unrecognized load parameter '" + override_parameter->Name() +
              "' with type '" +
              TRITONSERVER_ParameterTypeString(override_parameter->Type()) +
              "'");
    }
  }
  // Polling model is considered unmodified by this point and can be returned
  // with info == nullptr
  if (unmodified) {
    return Status::Success;
  }
  // Create the associated repo agent models when a model is to be loaded,
  // this must be done before normalizing model config as agents might
  // redirect to use the model config at a different location
  if (!parsed_config) {
    const auto config_path = JoinPath({linfo->model_path_, kModelConfigPbTxt});
    bool model_config_exists = false;
    RETURN_IF_ERROR(FileExists(config_path, &model_config_exists));
    // model config can be missing if auto fill is set
    if (autofill_ && !model_config_exists) {
      linfo->model_config_.Clear();
    } else {
      RETURN_IF_ERROR(ReadTextProto(config_path, &linfo->model_config_));
      parsed_config = true;
    }
  }
  if (parsed_config) {
    RETURN_IF_ERROR(CreateAgentModelListWithLoadAction(
        linfo->model_config_, linfo->model_path_, &linfo->agent_model_list_));
    if (linfo->agent_model_list_ != nullptr) {
      // Get the latest repository path
      const char* location;
      TRITONREPOAGENT_ArtifactType artifact_type;
      RETURN_IF_ERROR(linfo->agent_model_list_->Back()->Location(
          &artifact_type, &location));
      auto latest_path = std::string(location);
      linfo->model_path_ = latest_path;
    }
  }
  linfo->is_config_provided_ = parsed_config;
  // Try to automatically generate missing parts of the model
  // configuration (autofill) that don't require model detail
  RETURN_IF_ERROR(GetNormalizedModelConfig(
      name, linfo->model_path_, min_compute_capability_,
      &linfo->model_config_));
  // Note that the model inputs and outputs are not validated until
  // the model is initialized as they may not be auto-completed
  // until the model is initialized.
  RETURN_IF_ERROR(
      ValidateModelConfig(linfo->model_config_, min_compute_capability_));
  if (!autofill_) {
    RETURN_IF_ERROR(ValidateModelIOConfig(linfo->model_config_));
  }
  // If the model is mapped, update its config name based on the
  // mapping.
  if (model_mappings_.find(name) != model_mappings_.end()) {
    linfo->model_config_.set_name(name);
  } else {
    // If there is no model mapping, make sure the name of the model
    // matches the name of the directory. This is a somewhat arbitrary
    // requirement but seems like good practice to require it of the user.
    // It also acts as a check to make sure we don't have two different
    // models with the same name.
    if (linfo->model_config_.name() != name) {
      return Status(
          Status::Code::INVALID_ARG,
          "unexpected directory name '" + name + "' for model '" +
              linfo->model_config_.name() +
              "', directory name must equal model name");
    }
  }
  *info = std::move(linfo);
  return Status::Success;
}
// Apply the added / deleted / modified model sets to the in-memory
// dependency graph. When 'deleted_dependents' is non-null, models that
// were only loaded to satisfy a now-deleted dependent are cascaded into
// the deletion set as well.
Status
ModelRepositoryManager::UpdateDependencyGraph(
    const std::set<std::string>& added, const std::set<std::string>& deleted,
    const std::set<std::string>& modified,
    std::set<std::string>* deleted_dependents)
{
  // update dependency graph, if the state of a node is changed, all its
  // downstreams will be affected
  // deleted, drop from dependency_graph, add to missing_nodes if downstreams is
  // not empty affected_nodes are all ensembles as only ensembles are depending
  // on other models
  std::set<DependencyNode*> affected_nodes;
  std::set<DependencyNode*> updated_nodes;
  std::set<std::string> current_deleted = deleted;
  // Deletion may cascade (an upstream with no remaining downstreams that
  // was not explicitly loaded is deleted too), so iterate to a fixpoint.
  while (!current_deleted.empty()) {
    std::set<std::string> next_deleted;
    for (const auto& model_name : current_deleted) {
      auto it = dependency_graph_.find(model_name);
      if (it != dependency_graph_.end()) {
        // remove this node from its upstreams
        for (auto& upstream : it->second->upstreams_) {
          upstream.first->downstreams_.erase(it->second.get());
          // Check if the upstream should be removed as well
          if ((deleted_dependents != nullptr) &&
              (upstream.first->downstreams_.empty()) &&
              (!upstream.first->explicitly_load_)) {
            next_deleted.emplace(upstream.first->model_name_);
          }
        }
        it->second->upstreams_.clear();
        if (!it->second->downstreams_.empty()) {
          UncheckDownstream(&it->second->downstreams_, &affected_nodes);
          // mark this node as missing upstream in its downstreams
          for (auto& downstream : it->second->downstreams_) {
            downstream->missing_upstreams_.emplace(it->second.get());
          }
          // Ownership of the node moves to 'missing_nodes_' so downstreams
          // keep a valid pointer.
          missing_nodes_.emplace(
              std::make_pair(model_name, std::move(it->second)));
        }
        // Make sure deleted node will not be in affected nodes
        // NOTE(review): if the node was just moved into 'missing_nodes_'
        // above, 'it->second' is now null and this erase is a no-op by
        // design — the node still exists and must stay reachable.
        affected_nodes.erase(it->second.get());
        dependency_graph_.erase(it);
      }
      if (deleted_dependents != nullptr) {
        deleted_dependents->emplace(model_name);
      }
    }
    current_deleted.swap(next_deleted);
  }
  // modified, invalidate (uncheck) all downstreams
  for (const auto& model_name : modified) {
    auto it = dependency_graph_.find(model_name);
    if (it != dependency_graph_.end()) {
      UncheckDownstream(&it->second->downstreams_, &affected_nodes);
      ModelInfo* info = nullptr;
      GetModelInfo(model_name, &info);
      it->second->model_config_ = info->model_config_;
      it->second->explicitly_load_ = info->explicitly_load_;
      // remove this node from its upstream node
      for (auto& upstream : it->second->upstreams_) {
        upstream.first->downstreams_.erase(it->second.get());
      }
      it->second->upstreams_.clear();
      it->second->checked_ = false;
      it->second->status_ = Status::Success;
      updated_nodes.emplace(it->second.get());
    }
  }
  // added, add to dependency_graph, if in missing_node, invalidate (uncheck)
  // and associate all downstreams, remove from missing_node
  for (const auto& model_name : added) {
    std::unique_ptr<DependencyNode> added_node;
    auto it = missing_nodes_.find(model_name);
    if (it != missing_nodes_.end()) {
      UncheckDownstream(&it->second->downstreams_, &affected_nodes);
      // remove this node from missing upstream node in its downstream nodes
      for (auto& downstream : it->second->downstreams_) {
        downstream->missing_upstreams_.erase(it->second.get());
      }
      it->second->checked_ = false;
      added_node = std::move(it->second);
      missing_nodes_.erase(it);
    } else {
      // Right now, nothing is going to be filled until validation
      added_node.reset(new DependencyNode(model_name));
    }
    ModelInfo* info = nullptr;
    GetModelInfo(model_name, &info);
    added_node->model_config_ = info->model_config_;
    added_node->explicitly_load_ = info->explicitly_load_;
    updated_nodes.emplace(added_node.get());
    dependency_graph_.emplace(
        std::make_pair(model_name, std::move(added_node)));
  }
  // Re-wire upstream/downstream edges for every node that changed; only
  // ensembles depend on other models, so only they land in the affected set.
  auto& affected_ensembles = affected_nodes;
  for (auto& updated_node : updated_nodes) {
    bool is_ensemble = ConnectDependencyGraph(updated_node);
    if (is_ensemble) {
      affected_ensembles.emplace(updated_node);
    }
  }
#ifdef TRITON_ENABLE_ENSEMBLE
  // After the dependency graph is updated, check ensemble dependencies
  for (auto& ensemble : affected_ensembles) {
    if (ensemble->status_.IsOk()) {
      if (!ensemble->missing_upstreams_.empty()) {
        // Build a comma-separated list of the missing upstream model names.
        std::string name_list;
        for (auto it = ensemble->missing_upstreams_.begin();
             it != ensemble->missing_upstreams_.end(); it++) {
          if (it != ensemble->missing_upstreams_.begin()) {
            name_list += ", ";
          }
          name_list += (*it)->model_name_;
        }
        ensemble->status_ = Status(
            Status::Code::INVALID_ARG,
            "ensemble " + ensemble->model_name_ +
                " contains models that are not available: " + name_list);
      } else {
        ensemble->status_ = CircularcyCheck(ensemble, ensemble);
      }
    }
  }
#endif  // TRITON_ENABLE_ENSEMBLE
  return Status::Success;
}
// Register an additional model repository at runtime with an optional
// model-name -> subdirectory mapping. Only permitted in EXPLICIT model
// control mode. Fails without side effects if the repository path does
// not exist, is already registered, or any mapping name conflicts.
Status
ModelRepositoryManager::RegisterModelRepository(
    const std::string& repository,
    const std::unordered_map<std::string, std::string>& model_mapping)
{
  if (!model_control_enabled_) {
    return Status(
        Status::Code::UNSUPPORTED,
        "repository registration is not allowed if model control mode is not "
        "EXPLICIT");
  }
  bool is_directory = false;
  auto status = IsDirectory(repository, &is_directory);
  if (!status.IsOk() || !is_directory) {
    return Status(
        Status::Code::INVALID_ARG,
        (std::string("failed to register '") + repository +
         "', repository not found")
            .c_str());
  }
  {
    // Serialize all operations that change model state
    std::lock_guard<std::mutex> lock(poll_mu_);
    // Check repository and mapped models do not yet exist.
    if (repository_paths_.find(repository) != repository_paths_.end()) {
      return Status(
          Status::Code::ALREADY_EXISTS,
          "model repository '" + repository + "' has already been registered");
    }
    // Validate all mappings before mutating any state so a failure leaves
    // the manager unchanged.
    for (const auto& mapping : model_mapping) {
      if (model_mappings_.find(mapping.first) != model_mappings_.end()) {
        return Status(
            Status::Code::ALREADY_EXISTS,
            (std::string("failed to register '") + mapping.first +
             "', there is a conflicting mapping for '" +
             std::string(mapping.first) + "'")
                .c_str());
      }
    }
    repository_paths_.emplace(repository);
    for (const auto& mapping : model_mapping) {
      // Store (repository, absolute model path) keyed by model name.
      model_mappings_.emplace(
          mapping.first,
          std::make_pair(repository, JoinPath({repository, mapping.second})));
    }
  }
  LOG_INFO << "Model repository registered: " << repository;
  return Status::Success;
}
// Remove a runtime-registered repository and every model mapping that
// pointed into it. Only permitted in EXPLICIT model control mode.
Status
ModelRepositoryManager::UnregisterModelRepository(const std::string& repository)
{
  if (!model_control_enabled_) {
    return Status(
        Status::Code::UNSUPPORTED,
        "repository unregistration is not allowed if model control mode is not "
        "EXPLICIT");
  }
  {
    // Serialize with other operations that change model state.
    std::lock_guard<std::mutex> lk(poll_mu_);
    const bool removed = (repository_paths_.erase(repository) == 1);
    if (!removed) {
      return Status(
          Status::Code::INVALID_ARG,
          "failed to unregister '" + repository + "', repository not found");
    }
    // Drop every mapping whose source repository is the one being removed.
    for (auto it = model_mappings_.begin(); it != model_mappings_.end();) {
      if (it->second.first == repository) {
        it = model_mappings_.erase(it);
      } else {
        ++it;
      }
    }
  }
  LOG_INFO << "Model repository unregistered: " << repository;
  return Status::Success;
}
// Depth-first walk of 'current_node's downstream edges looking for a path
// back to 'start_node'. Returns INVALID_ARG describing the cycle when one
// is found; intermediate nodes on a detected cycle also have their own
// 'status_' poisoned so they are not loaded.
// NOTE(review): worst case revisits shared subgraphs (no visited set);
// acceptable for typical ensemble depths but worth confirming for very
// large graphs.
Status
ModelRepositoryManager::CircularcyCheck(
    DependencyNode* current_node, const DependencyNode* start_node)
{
  for (auto& downstream : current_node->downstreams_) {
    if (downstream->model_name_ == start_node->model_name_) {
      // Reached the starting ensemble again: report the cycle endpoints.
      return Status(
          Status::Code::INVALID_ARG,
          "circular dependency between ensembles: " + start_node->model_name_ +
              " -> ... -> " + current_node->model_name_ + " -> " +
              start_node->model_name_);
    } else {
      const auto status = CircularcyCheck(downstream, start_node);
      // Record the first failure on this node and stop exploring further
      // siblings; an already-failed node keeps its original error.
      if (!status.IsOk() && current_node->status_.IsOk()) {
        current_node->status_ = status;
        return status;
      }
    }
  }
  return Status::Success;
}
// Recursively reset the 'checked_' flag (and clear the status) of every
// node reachable through 'downstreams', collecting each node that was
// actually reset into 'updated_nodes'.
void
ModelRepositoryManager::UncheckDownstream(
    NodeSet* downstreams, NodeSet* updated_nodes)
{
  for (auto& child : *downstreams) {
    if (!child->checked_) {
      // Skip nodes that are already unchecked.
      continue;
    }
    child->checked_ = false;
    child->status_ = Status::Success;
    UncheckDownstream(&child->downstreams_, updated_nodes);
    updated_nodes->emplace(child);
  }
}
// Rebuild the upstream edges of 'updated_node' from its (ensemble) model
// config and register it as a downstream of each upstream. Returns true
// when the node is an ensemble (i.e. has ensemble scheduling), false
// otherwise.
bool
ModelRepositoryManager::ConnectDependencyGraph(DependencyNode* updated_node)
{
  // Check the node's model config to determine if it depends on other models
  // and if those models are present
  updated_node->upstreams_.clear();
  updated_node->missing_upstreams_.clear();
  if (updated_node->model_config_.has_ensemble_scheduling()) {
    for (const auto& step :
         updated_node->model_config_.ensemble_scheduling().step()) {
      DependencyNode* upstream_node = nullptr;
      const auto& model_name = step.model_name();
      auto dit = dependency_graph_.find(model_name);
      if (dit == dependency_graph_.end()) {
        // Upstream is not (yet) in the graph: use / create a placeholder
        // in 'missing_nodes_'.
        auto mit = missing_nodes_.find(model_name);
        if (mit == missing_nodes_.end()) {
          std::unique_ptr<DependencyNode> node(new DependencyNode(model_name));
          updated_node->missing_upstreams_.emplace(node.get());
          mit = missing_nodes_.emplace(model_name, std::move(node)).first;
        }
        // NOTE(review): when the placeholder already existed, it is not
        // added to this node's 'missing_upstreams_' — confirm that is the
        // intended behavior for the missing-model error report.
        // Add the node to missing node's downstream so that when the missing
        // node is added, the downstreams can be found easily.
        mit->second->downstreams_.emplace(updated_node);
        upstream_node = mit->second.get();
      } else {
        dit->second->downstreams_.emplace(updated_node);
        upstream_node = dit->second.get();
      }
      auto res = updated_node->upstreams_.emplace(
          upstream_node, std::set<int64_t>({step.model_version()}));
      // If map insertion doesn't happen, the same model is required in
      // different step, insert the version to existing required version set.
      if (!res.second) {
        res.first->second.insert(step.model_version());
      }
    }
    return true;
  }
  return false;
}
// Look up the cached ModelInfo for 'name'. The map keeps ownership; on
// success '*model_info' receives only a borrowed pointer.
Status
ModelRepositoryManager::GetModelInfo(
    const std::string& name, ModelInfo** model_info)
{
  const auto entry = infos_.find(name);
  if (entry != infos_.end()) {
    *model_info = entry->second.get();
    return Status::Success;
  }
  return Status(
      Status::Code::NOT_FOUND, "no configuration for model '" + name + "'");
}
std
::
pair
<
ModelRepositoryManager
::
NodeSet
,
ModelRepositoryManager
::
NodeSet
>
ModelRepositoryManager
::
ModelsToLoadUnload
(
const
NodeSet
&
loaded_models
)
{
// <valid model set, invalid model set>
std
::
pair
<
NodeSet
,
NodeSet
>
res
;
// first call to this function
if
(
loaded_models
.
empty
())
{
for
(
auto
&
pair
:
dependency_graph_
)
{
auto
node
=
pair
.
second
.
get
();
// only care about nodes that are affected by the update
if
(
!
node
->
checked_
)
{
if
(
CheckNode
(
node
))
{
if
(
node
->
status_
.
IsOk
())
{
res
.
first
.
emplace
(
node
);
}
else
{
res
.
second
.
emplace
(
node
);
}
}
}
}
}
else
{
for
(
const
auto
&
model
:
loaded_models
)
{
for
(
auto
node
:
model
->
downstreams_
)
{
// only care about nodes that are affected by the update
if
(
!
node
->
checked_
)
{
if
(
CheckNode
(
node
))
{
if
(
node
->
status_
.
IsOk
())
{
res
.
first
.
emplace
(
node
);
}
else
{
res
.
second
.
emplace
(
node
);
}
}
}
}
}
}
for
(
auto
&
node
:
res
.
first
)
{
node
->
checked_
=
true
;
}
for
(
auto
&
node
:
res
.
second
)
{
node
->
checked_
=
true
;
}
return
res
;
}
// Decide whether 'node' can be acted on now: returns true when all its
// upstreams have been checked (or the node is already known invalid).
// As a side effect, sets 'node->status_' to INVALID_ARG when an upstream
// is invalid, has no loaded version, or lacks a specifically required
// version; for a ready, still-valid ensemble the full ensemble config
// validation is run.
bool
ModelRepositoryManager::CheckNode(DependencyNode* node)
{
  bool node_ready = true;
  // if the node is in invalid status, mark as ready as we know
  // it should not be loaded
  if (node->status_.IsOk()) {
    for (auto& upstream : node->upstreams_) {
      if (!upstream.first->checked_) {
        // An upstream has not been resolved yet; this node must wait.
        node_ready = false;
        break;
      }
      if (!upstream.first->status_.IsOk()) {
        node->status_ = Status(
            Status::Code::INVALID_ARG,
            "ensemble '" + node->model_name_ + "' depends on '" +
                upstream.first->model_name_ + "' which is not valid");
      } else if (upstream.first->loaded_versions_.empty()) {
        node->status_ = Status(
            Status::Code::INVALID_ARG,
            "ensemble '" + node->model_name_ + "' depends on '" +
                upstream.first->model_name_ + "' which has no loaded version");
      } else {
        for (const auto& required_version : upstream.second) {
          // -1 means "any version" and is always satisfied.
          if (required_version == -1) {
            continue;
          }
          auto it = upstream.first->loaded_versions_.find(required_version);
          if (it == upstream.first->loaded_versions_.end()) {
            node->status_ = Status(
                Status::Code::INVALID_ARG,
                "ensemble '" + node->model_name_ + "' depends on '" +
                    upstream.first->model_name_ + "' whose required version " +
                    std::to_string(required_version) + " is not loaded");
          }
        }
      }
      if (!node->status_.IsOk()) {
        break;
      }
    }
#ifdef TRITON_ENABLE_ENSEMBLE
    // Validate ensemble config if the node is ready. By this point, the
    // depending models are loaded and their configs are completed
    if (node_ready && node->status_.IsOk()) {
      node->status_ = ValidateEnsembleConfig(this, node);
    }
#endif  // TRITON_ENABLE_ENSEMBLE
  }
  return node_ready;
}
}}
// namespace triton::core
3rdparty/core-r22.12/src/model_repository_manager.h
0 → 100644
View file @
0a21fff9
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <functional>
#include <map>
#include <mutex>
#include <set>
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "model_lifecycle.h"
#include "status.h"
#include "triton/common/model_config.h"
namespace
triton
{
namespace
core
{
class
InferenceServer
;
class
Model
;
// [FIXME] should have separated load / unload functions for clarity
enum
ActionType
{
NO_ACTION
,
LOAD
,
UNLOAD
};
/// Predefined reason strings
#define MODEL_READY_REASON_DUPLICATE "model appears in two or more repositories"
/// An object to manage the model repository active in the server.
class
ModelRepositoryManager
{
public:
// Index information for a model.
struct ModelIndex {
  // Name-only entry: no version/state/reason information is available.
  ModelIndex(const std::string& n)
      : name_only_(true), name_(n), version_(-1),
        state_(ModelReadyState::UNKNOWN)
  {
  }
  // Full entry describing the readiness of one specific model version.
  ModelIndex(
      const std::string& n, const int64_t v, const ModelReadyState s,
      const std::string& r)
      : name_only_(false), name_(n), version_(v), state_(s), reason_(r)
  {
  }
  // True when only 'name_' is meaningful (constructed via the name-only
  // constructor).
  const bool name_only_;
  // Model name.
  const std::string name_;
  // Model version; -1 for a name-only entry.
  const int64_t version_;
  // Readiness state of this version; UNKNOWN for a name-only entry.
  const ModelReadyState state_;
  // Human-readable explanation for the state; empty for a name-only entry.
  const std::string reason_;
};
/// A basic unit in dependency graph that records the models seen by the model
/// repository manager.
struct
DependencyNode
{
DependencyNode
(
const
std
::
string
&
model_name
)
:
model_name_
(
model_name
),
status_
(
Status
::
Success
),
checked_
(
false
)
{
}
std
::
string
model_name_
;
Status
status_
;
bool
checked_
;
bool
explicitly_load_
;
inference
::
ModelConfig
model_config_
;
std
::
set
<
int64_t
>
loaded_versions_
;
std
::
set
<
DependencyNode
*>
missing_upstreams_
;
std
::
unordered_map
<
DependencyNode
*
,
std
::
set
<
int64_t
>>
upstreams_
;
std
::
set
<
DependencyNode
*>
downstreams_
;
};
~
ModelRepositoryManager
();
/// Create a manager for a repository.
/// \param server The pointer to the inference server.
/// \param server_version The version of the inference server.
/// \param repository_paths A set of file-system paths of the repositories.
/// \param startup_models A set of models to be loaded at startup
/// if model control is enabled.
/// \param strict_model_config If false attempt to autofill missing required
/// information in each model configuration.
/// \param polling_enabled If true, then PollAndUpdate() is allowed.
/// Otherwise, it is not allowed.
/// \param model_control_enabled If true, then LoadUnloadModel() is allowed
/// and the models in the model repository will not be loaded at startup.
/// Otherwise, LoadUnloadModel() is not allowed and the models will be loaded.
/// Cannot be set to true if polling_enabled is true.
/// \param life_cycle_options The options to configure ModelLifeCycle.
/// \param model_repository_manager Return the model repository manager.
/// \return The error status.
static
Status
Create
(
InferenceServer
*
server
,
const
std
::
string
&
server_version
,
const
std
::
set
<
std
::
string
>&
repository_paths
,
const
std
::
set
<
std
::
string
>&
startup_models
,
const
bool
strict_model_config
,
const
bool
polling_enabled
,
const
bool
model_control_enabled
,
const
ModelLifeCycleOptions
&
life_cycle_options
,
std
::
unique_ptr
<
ModelRepositoryManager
>*
model_repository_manager
);
/// Poll the model repository to determine the new set of models and
/// compare with the current set. And serve the new set of models based
/// on their version policy.
Status
PollAndUpdate
();
/// Load or unload a specified model.
/// \param models The models and the parameters to be loaded or unloaded
/// \param type The type action to be performed. If the action is LOAD and
/// the model has been loaded, the model will be re-loaded.
/// \return error status. Return "NOT_FOUND" if it tries to load
/// a non-existing model or if it tries to unload a model that hasn't been
/// loaded.
Status
LoadUnloadModel
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>&
models
,
const
ActionType
type
,
const
bool
unload_dependents
);
/// Unload all models. This function should be called before shutting down
/// the model repository manager.
/// \return error status.
Status
UnloadAllModels
();
/// Instruct all models to stop accepting new inference requests. However,
/// the models are still capable of processing inference requests
/// if the model considers them as part of the in-flight inference.
/// \return error status.
Status
StopAllModels
();
/// \return the number of in-flight inferences for the all versions of all
/// models. The set element will be a tuple of <model_name, model_version,
/// in-flight inference count>. Note that a model version will not be included
/// if it doesn't have in-flight inferences.
const
std
::
set
<
std
::
tuple
<
std
::
string
,
int64_t
,
size_t
>>
InflightStatus
();
/// \param strict_readiness If true, only models that have at least one
/// ready version will be considered as live. Otherwise, the models that
/// have loading / unloading versions will also be live.
/// \return the state of all versions of all live models.
const
ModelStateMap
LiveModelStates
(
bool
strict_readiness
=
false
);
/// \return the state of all versions of all models that have every
/// been (attempted) loaded over the lifetime of the server.
const
ModelStateMap
ModelStates
();
/// \return the states of all versions of a specific model.
const
VersionStateMap
VersionStates
(
const
std
::
string
&
model_name
);
/// \return the ready-state of a specific model version.
Status
ModelState
(
const
std
::
string
&
model_name
,
const
int64_t
model_version
,
ModelReadyState
*
state
);
/// Get the index of all models in all repositories.
/// \param ready_only If true return only index of models that are ready.
/// \param index Returns the index.
/// \return error status.
Status
RepositoryIndex
(
const
bool
ready_only
,
std
::
vector
<
ModelIndex
>*
index
);
/// Obtain the specified model.
/// \param model_name The name of the model.
/// \param model_version The version of the model.
/// \param model Return the model object.
/// \return error status.
Status
GetModel
(
const
std
::
string
&
model_name
,
const
int64_t
model_version
,
std
::
shared_ptr
<
Model
>*
model
);
// Register model repository path.
/// \param repository Path to model repository.
/// \param model_mapping Mapping with (overridden) model name as key, subdir
/// name as value.
/// \return error status
Status
RegisterModelRepository
(
const
std
::
string
&
repository
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
model_mapping
);
// Unregister model repository path.
/// \param repository Path to model repository.
/// \return error status
Status
UnregisterModelRepository
(
const
std
::
string
&
repository
);
private:
struct
ModelInfo
;
// Map from model name to information about the model.
using
ModelInfoMap
=
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
ModelInfo
>>
;
// Set of DependencyNode
using
NodeSet
=
std
::
set
<
DependencyNode
*>
;
ModelRepositoryManager
(
const
std
::
set
<
std
::
string
>&
repository_paths
,
const
bool
autofill
,
const
bool
polling_enabled
,
const
bool
model_control_enabled
,
const
double
min_compute_capability
,
std
::
unique_ptr
<
ModelLifeCycle
>
life_cycle
);
/// The internal function that are called in Create() and PollAndUpdate().
Status
PollAndUpdateInternal
(
bool
*
all_models_polled
);
/// The internal function that load or unload a set of models.
Status
LoadUnloadModels
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>&
models
,
const
ActionType
type
,
const
bool
unload_dependents
,
bool
*
all_models_polled
);
/// Poll the requested models in the model repository and
/// compare with the current set. Return the additions, deletions,
/// and modifications that have occurred. This function will not updated
/// the current model info, it is caller's responsibility to do so.
/// \param models The map from models to be polled to their associated
/// parameters.
/// \param added The names of the models added to the repository.
/// \param deleted The names of the models removed from the repository.
/// \param modified The names of the models remaining in the
/// repository that have been changed.
/// \param unmodified The names of the models remaining in the
/// repository that have not changed.
/// \param updated_infos The model infos retrieved from the poll.
/// \param all_models_polled Return true if all models are polled and
/// their model configuration are validated successfully. Instead of aborting
/// the polling, the models that fail will be ignored and their model infos
/// will stay in the previous state.
/// \return The error status.
Status
Poll
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>&
models
,
std
::
set
<
std
::
string
>*
added
,
std
::
set
<
std
::
string
>*
deleted
,
std
::
set
<
std
::
string
>*
modified
,
std
::
set
<
std
::
string
>*
unmodified
,
ModelInfoMap
*
updated_infos
,
bool
*
all_models_polled
);
/// Helper function for Poll() to initialize ModelInfo for the model.
/// \param name The name of the model.
/// \param path The model path. Empty path means the model is provided via
/// 'params'
/// \param params The model parameters provided for polling model.
/// \param info Return the updated ModelInfo. 'nullptr' will be returned if
/// existing ModelInfo for the model should be reused.
/// \return The error status.
Status
InitializeModelInfo
(
const
std
::
string
&
name
,
const
std
::
string
&
path
,
const
std
::
vector
<
const
InferenceParameter
*>&
params
,
std
::
unique_ptr
<
ModelInfo
>*
info
);
/// Load models based on the dependency graph. The function will iteratively
/// load models that all the models they depend on has been loaded, and unload
/// models if their dependencies are no longer satisfied.
/// \return The status of the model loads.
std
::
map
<
std
::
string
,
Status
>
LoadModelByDependency
();
/// Helper function to update the dependency graph based on the poll result
/// \param added The names of the models added to the repository.
/// \param deleted The names of the models removed from the repository.
/// \param modified The names of the models remaining in the
/// repository that have been changed.
/// \param deleted_dependents The names of dependent models to be removed
/// from the repository.
/// \return The error status.
Status
UpdateDependencyGraph
(
const
std
::
set
<
std
::
string
>&
added
,
const
std
::
set
<
std
::
string
>&
deleted
,
const
std
::
set
<
std
::
string
>&
modified
,
std
::
set
<
std
::
string
>*
deleted_dependents
=
nullptr
);
/// Helper function to uncheck the nodes because the model that they depends
/// on has changed. The unchecked nodes will be validated again.
/// The function will be call recursively to uncheck all downstreams.
/// \param downstreams The nodes to be unchecked.
/// \param updated_nodes Return the nodes that have been unchecked
void
UncheckDownstream
(
NodeSet
*
downstreams
,
NodeSet
*
updated_nodes
);
/// Helper function to construct the edges between nodes in dependency graph.
/// \param updated_node The node that is newly added or modified.
/// \return True if the node represents an ensemble model. False otherwise.
bool
ConnectDependencyGraph
(
DependencyNode
*
updated_node
);
/// Get the model info for a named model.
/// \param name The model name.
/// \param model_info Returns the model information.
/// \return OK if found, NOT_FOUND otherwise.
Status
GetModelInfo
(
const
std
::
string
&
name
,
ModelInfo
**
model_info
);
/// Get the models to be loaded / unloaded based on the model loaded in
/// previous iteration.
/// \param loaded_models The models loaded / unloaded in previous iteration.
/// Unloaded models will be represented as models with no loaded versions.
/// \return A pair of node set containing models to be loaded and models to be
/// unloaded for the next iteration.
std
::
pair
<
NodeSet
,
NodeSet
>
ModelsToLoadUnload
(
const
NodeSet
&
loaded_models
);
/// Check if the node is ready for the next iteration. A node is ready if the
/// node is invalid (containing invalid model config or its depdencies failed
/// to load) or all of its dependencies are satisfied.
/// \param node The node to be checked.
/// \return True if the node is ready. False otherwise.
bool
CheckNode
(
DependencyNode
*
node
);
Status
CircularcyCheck
(
DependencyNode
*
current_node
,
const
DependencyNode
*
start_node
);
bool
ModelDirectoryOverride
(
const
std
::
vector
<
const
InferenceParameter
*>&
model_params
);
std
::
set
<
std
::
string
>
repository_paths_
;
const
bool
autofill_
;
const
bool
polling_enabled_
;
const
bool
model_control_enabled_
;
const
double
min_compute_capability_
;
std
::
mutex
poll_mu_
;
ModelInfoMap
infos_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
DependencyNode
>>
dependency_graph_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
DependencyNode
>>
missing_nodes_
;
// Mappings from (overridden) model names to a pair of their repository and
// absolute path
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
std
::
string
,
std
::
string
>>
model_mappings_
;
std
::
unique_ptr
<
ModelLifeCycle
>
model_life_cycle_
;
};
}}
// namespace triton::core
3rdparty/core-r22.12/src/numa_utils.cc
0 → 100644
View file @
0a21fff9
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "numa_utils.h"
#ifndef _WIN32
#include <numa.h>
#include <numaif.h>
#endif
#include "triton/common/logging.h"
namespace
triton
{
namespace
core
{
namespace
{
// Render 'vec' as "[e0,e1,...,]" for logging. Note that every element,
// including the last, is followed by a comma — this matches the
// established log format.
std::string
VectorToString(const std::vector<int>& vec)
{
  std::string result("[");
  for (size_t idx = 0; idx < vec.size(); ++idx) {
    result += std::to_string(vec[idx]);
    result += ",";
  }
  result += "]";
  return result;
}
// Parse 'arg' as a base-10 integer into '*value'. Returns INVALID_ARG with
// 'msg' as the prefix when 'arg' is not parseable as an int.
Status
ParseIntOption(const std::string& msg, const std::string& arg, int* value)
{
  try {
    *value = std::stoi(arg);
  }
  catch (const std::invalid_argument& ia) {
    return Status(
        Status::Code::INVALID_ARG,
        msg + ": Can't parse '" + arg + "' to integer");
  }
  catch (const std::out_of_range& oor) {
    // FIX: std::stoi also throws std::out_of_range when the value does not
    // fit in 'int'; previously this exception escaped the function.
    return Status(
        Status::Code::INVALID_ARG,
        msg + ": Can't parse '" + arg + "' to integer");
  }
  return Status::Success;
}
}
// namespace
// NUMA setting will be ignored on Windows platform
#ifdef _WIN32
// On Windows the NUMA host-policy settings are ignored: every entry point
// below is a no-op that reports success so callers behave uniformly across
// platforms.

Status
SetNumaConfigOnThread(
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  return Status::Success;
}

Status
SetNumaMemoryPolicy(
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  return Status::Success;
}

Status
GetNumaMemoryPolicyNodeMask(unsigned long* node_mask)
{
  // No policy is ever set on Windows, so the mask is always empty.
  *node_mask = 0;
  return Status::Success;
}

Status
ResetNumaMemoryPolicy()
{
  return Status::Success;
}

Status
SetNumaThreadAffinity(
    std::thread::native_handle_type thread,
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  return Status::Success;
}
#else
// Use variable to make sure no NUMA related function is actually called
// if Triton is not running with NUMA awareness. i.e. Extra docker permission
// is needed to call the NUMA functions and this ensures backward compatibility.
thread_local
bool
numa_set
=
false
;
// Apply the full NUMA host policy to the calling thread: first pin the
// thread to the configured CPU cores, then bind its memory allocations to
// the configured NUMA node. Returns the first error encountered.
Status
SetNumaConfigOnThread(
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  // Set thread affinity
  RETURN_IF_ERROR(SetNumaThreadAffinity(pthread_self(), host_policy));

  // Set memory policy
  RETURN_IF_ERROR(SetNumaMemoryPolicy(host_policy));

  return Status::Success;
}
// Bind the calling thread's memory allocations to the NUMA node named by the
// "numa-node" host-policy entry. A no-op (Success) if the policy has no
// "numa-node" key.
Status
SetNumaMemoryPolicy(
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  const auto it = host_policy.find("numa-node");
  if (it != host_policy.end()) {
    int node_id;
    RETURN_IF_ERROR(
        ParseIntOption("Parsing 'numa-node' value", it->second, &node_id));
    LOG_VERBOSE(1) << "Thread is binding to NUMA node " << it->second
                   << ". Max NUMA node count: " << (numa_max_node() + 1);
    numa_set = true;
    // Single-bit mask selecting only 'node_id'.
    unsigned long node_mask = 1UL << node_id;
    // NOTE(review): set_mempolicy(2)'s third argument is a bit count; the
    // extra "+ 1" on top of (numa_max_node() + 1) appears intentional to
    // satisfy the kernel's rounding of maxnode — confirm against
    // set_mempolicy(2) before changing.
    if (set_mempolicy(MPOL_BIND, &node_mask, (numa_max_node() + 1) + 1) != 0) {
      return Status(
          Status::Code::INTERNAL,
          std::string("Unable to set NUMA memory policy: ") + strerror(errno));
    }
  }
  return Status::Success;
}
// Retrieve the node mask currently installed by SetNumaMemoryPolicy() for
// this thread. Returns a zero mask (and Success) if no policy was ever set
// on this thread — get_mempolicy() is only consulted when 'numa_set' is true.
Status
GetNumaMemoryPolicyNodeMask(unsigned long* node_mask)
{
  *node_mask = 0;
  int mode;
  if (numa_set &&
      get_mempolicy(&mode, node_mask, numa_max_node() + 1, NULL, 0) != 0) {
    return Status(
        Status::Code::INTERNAL,
        std::string("Unable to get NUMA node for current thread: ") +
            strerror(errno));
  }
  return Status::Success;
}
// Restore the default (unbound) memory policy for this thread. Skips the
// syscall entirely when this thread never set a policy, then clears the
// per-thread flag either way.
Status
ResetNumaMemoryPolicy()
{
  if (numa_set && (set_mempolicy(MPOL_DEFAULT, nullptr, 0) != 0)) {
    return Status(
        Status::Code::INTERNAL,
        std::string("Unable to reset NUMA memory policy: ") + strerror(errno));
  }
  numa_set = false;
  return Status::Success;
}
// Pin 'thread' to the CPU cores described by the "cpu-cores" host-policy
// entry. The value is a comma-separated list of inclusive ranges, e.g.
// "0-3,8-11"; every range MUST contain a '-' (a bare core id is rejected).
// A no-op (Success) if the policy has no "cpu-cores" key.
Status
SetNumaThreadAffinity(
    std::thread::native_handle_type thread,
    const triton::common::HostPolicyCmdlineConfig& host_policy)
{
  const auto it = host_policy.find("cpu-cores");
  if (it != host_policy.end()) {
    // Parse CPUs
    std::vector<int> cpus;
    {
      const auto& cpu_str = it->second;
      // 'delim_cpus' marks the end of the range currently being parsed
      // (position of the next ','), npos for the last range.
      auto delim_cpus = cpu_str.find(",");
      int current_pos = 0;
      while (true) {
        auto delim_range = cpu_str.find("-", current_pos);
        if (delim_range == std::string::npos) {
          // No '-' in the remaining text: malformed range. The substr bound
          // reports only the offending segment (up to the next ',' or the
          // end of the string).
          return Status(
              Status::Code::INVALID_ARG,
              std::string(
                  "host policy setting 'cpu-cores' format is "
                  "'<lower_cpu_core_id>-<upper_cpu_core_id>'. Got ") +
                  cpu_str.substr(
                      current_pos, ((delim_cpus == std::string::npos)
                                        ? (cpu_str.length() + 1)
                                        : delim_cpus) -
                                       current_pos));
        }
        int lower, upper;
        // Lower bound: text between the current position and the '-'.
        RETURN_IF_ERROR(ParseIntOption(
            "Parsing 'cpu-cores' value",
            cpu_str.substr(current_pos, delim_range - current_pos), &lower));
        // Upper bound: text between the '-' and the next ',' (or the end of
        // the string for the final range).
        RETURN_IF_ERROR(ParseIntOption(
            "Parsing 'cpu-cores' value",
            (delim_cpus == std::string::npos)
                ? cpu_str.substr(delim_range + 1)
                : cpu_str.substr(
                      delim_range + 1, delim_cpus - (delim_range + 1)),
            &upper));
        // Expand the inclusive range into individual core ids.
        for (; lower <= upper; ++lower) {
          cpus.push_back(lower);
        }
        // break if the processed range is the last specified range
        if (delim_cpus != std::string::npos) {
          current_pos = delim_cpus + 1;
          delim_cpus = cpu_str.find(",", current_pos);
        } else {
          break;
        }
      }
    }
    LOG_VERBOSE(1) << "Thread is binding to one of the CPUs: "
                   << VectorToString(cpus);
    numa_set = true;
    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    for (int cpu : cpus) {
      CPU_SET(cpu, &cpuset);
    }
    if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset) != 0) {
      return Status(
          Status::Code::INTERNAL,
          std::string("Unable to set NUMA thread affinity: ") +
              strerror(errno));
    }
  }
  return Status::Success;
}
#endif
}}
// namespace triton::core
3rdparty/core-r22.12/src/numa_utils.h
0 → 100644
View file @
0a21fff9
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <map>
#include <thread>
#include <vector>
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
namespace
triton
{
namespace
core
{
// Helper function to set memory policy and thread affinity on current thread.
// Applies both SetNumaThreadAffinity() and SetNumaMemoryPolicy() using the
// given host policy; all of these are no-ops on Windows builds.
Status SetNumaConfigOnThread(
    const triton::common::HostPolicyCmdlineConfig& host_policy);

// Restrict the memory allocation to specific NUMA node.
Status SetNumaMemoryPolicy(
    const triton::common::HostPolicyCmdlineConfig& host_policy);

// Retrieve the node mask used to set memory policy for the current thread
Status GetNumaMemoryPolicyNodeMask(unsigned long* node_mask);

// Reset the memory allocation setting.
Status ResetNumaMemoryPolicy();

// Set a thread affinity to be on specific cpus.
Status SetNumaThreadAffinity(
    std::thread::native_handle_type thread,
    const triton::common::HostPolicyCmdlineConfig& host_policy);
}}
// namespace triton::core
3rdparty/core-r22.12/src/payload.cc
0 → 100644
View file @
0a21fff9
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "payload.h"
namespace
triton
{
namespace
core
{
// Construct an empty, unassigned payload: INFER_RUN op type, no requests,
// a no-op completion callback, no instance, UNINITIALIZED state. The
// execution mutex is heap-allocated so the Payload itself stays movable.
Payload::Payload()
    : op_type_(Operation::INFER_RUN),
      requests_(std::vector<std::unique_ptr<InferenceRequest>>()),
      OnCallback_([]() {}), instance_(nullptr), state_(State::UNINITIALIZED),
      batcher_start_ns_(0), saturated_(false)
{
  exec_mu_.reset(new std::mutex());
}
// Move the requests of 'payload' into this payload and fire its callback.
// Both payloads must be INFER_RUN, target the same model instance, and be in
// EXECUTING state; the merged requests must also pass the required-equal-
// inputs check. Error returns reference function-local static Status objects
// so a const reference can be returned without dangling.
const Status&
Payload::MergePayload(std::shared_ptr<Payload>& payload)
{
  if ((payload->GetOpType() != Operation::INFER_RUN) ||
      (op_type_ != Operation::INFER_RUN)) {
    static Status op_type_error(
        Status::Code::INTERNAL,
        "Attempted to merge payloads of type that are not INFER_RUN");
    return op_type_error;
  }
  if (payload->GetInstance() != instance_) {
    static Status instance_error(
        Status::Code::INTERNAL,
        "Attempted to merge payloads of mismatching instance");
    return instance_error;
  }
  if ((payload->GetState() != State::EXECUTING) ||
      (state_ != State::EXECUTING)) {
    static Status state_error(
        Status::Code::INTERNAL,
        "Attempted to merge payloads that are not in executing state");
    return state_error;
  }
  // Skip comparison if not initialized (required), here assume either all
  // payloads are initialized or otherwise.
  // NOTE(review): only the FIRST request of the incoming payload is checked
  // for equal inputs — presumably all requests within a payload already
  // share equal inputs; confirm with the batcher logic.
  if (required_equal_inputs_.Initialized() &&
      !required_equal_inputs_.HasEqualInputs(*payload->Requests().begin())) {
    static Status shape_error(
        Status::Code::INVALID_ARG,
        "Attempted to merge payloads that has non-equal inputs");
    return shape_error;
  }
  // Steal the requests (move, not copy), then signal the donor payload's
  // completion callback.
  requests_.insert(
      requests_.end(), std::make_move_iterator(payload->Requests().begin()),
      std::make_move_iterator(payload->Requests().end()));
  payload->Callback();
  return Status::Success;
}
// Re-initialize this payload for reuse with a new operation and (optional)
// instance. Clears all queued requests/callbacks and installs a fresh
// promise so Wait() can be used again (a std::promise is single-shot).
void
Payload::Reset(const Operation op_type, TritonModelInstance* instance)
{
  op_type_ = op_type;
  requests_.clear();
  OnCallback_ = []() {};
  release_callbacks_.clear();
  instance_ = instance;
  state_ = State::UNINITIALIZED;
  status_.reset(new std::promise<Status>());
  required_equal_inputs_ = RequiredEqualInputs();
  batcher_start_ns_ = 0;
  saturated_ = false;
}
// Return this payload to its released state so it can go back to a pool.
// Like Reset() but leaves the promise alone and marks the state RELEASED.
void
Payload::Release()
{
  op_type_ = Operation::INFER_RUN;
  requests_.clear();
  OnCallback_ = []() {};
  release_callbacks_.clear();
  instance_ = nullptr;
  state_ = State::RELEASED;
  required_equal_inputs_ = RequiredEqualInputs();
  batcher_start_ns_ = 0;
  saturated_ = false;
}
size_t
Payload
::
BatchSize
()
{
size_t
batch_size
=
0
;
for
(
const
auto
&
request
:
requests_
)
{
batch_size
+=
std
::
max
(
1U
,
request
->
BatchSize
());
}
return
batch_size
;
}
// Pre-allocate capacity for 'size' requests to avoid reallocation while the
// batcher appends.
void
Payload::ReserveRequests(size_t size)
{
  requests_.reserve(size);
}
// Append a request to the payload, taking ownership. Also tracks the
// earliest batcher-entry timestamp across all queued requests
// (batcher_start_ns_ == 0 means "not set yet").
void
Payload::AddRequest(std::unique_ptr<InferenceRequest> request)
{
  const auto request_start_ns = request->BatcherStartNs();
  if ((batcher_start_ns_ == 0) || (request_start_ns < batcher_start_ns_)) {
    batcher_start_ns_ = request_start_ns;
  }
  requests_.push_back(std::move(request));
}
// Install the completion callback fired when the payload finishes executing
// (or is merged into another payload).
void
Payload::SetCallback(std::function<void()> OnCallback)
{
  OnCallback_ = OnCallback;
}

// Assign the model instance that will execute this payload.
void
Payload::SetInstance(TritonModelInstance* model_instance)
{
  instance_ = model_instance;
}

// Queue an internal cleanup callback; these run (in reverse order of
// addition) from OnRelease() before the user-provided release callback.
void
Payload::AddInternalReleaseCallback(std::function<void()>&& callback)
{
  release_callbacks_.emplace_back(std::move(callback));
}

// Flag that the payload was formed because the batcher hit a limit
// (one-way: there is no way to clear it except Reset()/Release()).
void
Payload::MarkSaturated()
{
  saturated_ = true;
}

// Advance the payload's lifecycle state (UNINITIALIZED -> ... -> RELEASED).
void
Payload::SetState(Payload::State state)
{
  state_ = state;
}
// Block until Execute() publishes a status via the promise, and return it.
// Single-shot: get_future() may be called only once per promise, so Wait()
// must not be called twice without an intervening Reset().
Status
Payload::Wait()
{
  return status_->get_future().get();
}

// Invoke the user-provided completion callback.
void
Payload::Callback()
{
  OnCallback_();
}
// Run the internal release callbacks, then drop them.
void
Payload::OnRelease()
{
  // Invoke the release callbacks added internally before releasing the
  // request to user provided callback.
  // Callbacks run in reverse order of registration (LIFO), mirroring
  // construction/teardown ordering.
  for (auto it = release_callbacks_.rbegin(); it != release_callbacks_.rend();
       it++) {
    (*it)();
  }
  release_callbacks_.clear();
}
// Dispatch the payload's operation on its model instance. Sets
// '*should_exit' to true only for the EXIT operation (telling the worker
// thread to stop). Always publishes a status to the promise so Wait()
// unblocks; for INFER_RUN the published status is the default Success —
// per-request results are delivered through the scheduled callbacks instead.
void
Payload::Execute(bool* should_exit)
{
  *should_exit = false;
  Status status;
  switch (op_type_) {
    case Operation::INFER_RUN:
      // Hands the queued requests to the instance; requests_ is emptied by
      // the move.
      instance_->Schedule(std::move(requests_), OnCallback_);
      break;
    case Operation::INIT:
      status = instance_->Initialize();
      break;
    case Operation::WARM_UP:
      status = instance_->WarmUp();
      break;
    case Operation::EXIT:
      *should_exit = true;
  }
  status_->set_value(status);
}
}}
// namespace triton::core
3rdparty/core-r22.12/src/payload.h
0 → 100644
View file @
0a21fff9
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <vector>
#include "backend_model_instance.h"
#include "infer_request.h"
#include "scheduler_utils.h"
#include "status.h"
namespace
triton
{
namespace
core
{
// A unit of work scheduled onto a model instance: a batch of inference
// requests (or a control operation) plus the callbacks and bookkeeping
// needed to execute it and report completion.
class Payload {
 public:
  // The kind of work this payload carries.
  enum Operation { INFER_RUN = 0, INIT = 1, WARM_UP = 2, EXIT = 3 };

  // Lifecycle states; advanced externally via SetState().
  enum State {
    UNINITIALIZED = 0,
    READY = 1,
    REQUESTED = 2,
    SCHEDULED = 3,
    EXECUTING = 4,
    RELEASED = 5
  };

  Payload();

  // Re-initialize for reuse with a new operation / instance.
  void Reset(const Operation op_type, TritonModelInstance* instance = nullptr);

  // Move the requests of 'payload' into this one; both must be INFER_RUN,
  // EXECUTING, and on the same instance.
  const Status& MergePayload(std::shared_ptr<Payload>& payload);

  Operation GetOpType() { return op_type_; }

  // Mutex used to serialize execution of this payload.
  std::mutex* GetExecMutex() { return exec_mu_.get(); }

  size_t RequestCount() { return requests_.size(); }

  // Sum of the batch sizes of all queued requests (each counts at least 1).
  size_t BatchSize();
  void ReserveRequests(size_t size);
  void AddRequest(std::unique_ptr<InferenceRequest> request);
  std::vector<std::unique_ptr<InferenceRequest>>& Requests()
  {
    return requests_;
  }

  // Earliest batcher-entry timestamp among queued requests (0 if unset).
  uint64_t BatcherStartNs() { return batcher_start_ns_; }

  void SetCallback(std::function<void()> OnCallback);
  void Callback();

  // Internal release callbacks run in reverse order from OnRelease().
  void AddInternalReleaseCallback(std::function<void()>&& callback);
  void OnRelease();

  void SetInstance(TritonModelInstance* model_instance);
  TritonModelInstance* GetInstance() { return instance_; }

  // Saturation flag: set once when the batcher fills the payload.
  void MarkSaturated();
  bool IsSaturated() { return saturated_; }

  RequiredEqualInputs* MutableRequiredEqualInputs()
  {
    return &required_equal_inputs_;
  }

  State GetState() { return state_; }
  void SetState(State state);

  // Run the operation on the instance; '*should_exit' is set for EXIT.
  void Execute(bool* should_exit);

  // Block until Execute() publishes a status. Single-shot per Reset().
  Status Wait();
  void Release();

 private:
  Operation op_type_;
  std::vector<std::unique_ptr<InferenceRequest>> requests_;
  std::function<void()> OnCallback_;
  std::vector<std::function<void()>> release_callbacks_;
  TritonModelInstance* instance_;
  State state_;
  // Promise fulfilled by Execute(); consumed by Wait().
  std::unique_ptr<std::promise<Status>> status_;
  // Heap-allocated so the Payload object remains movable.
  std::unique_ptr<std::mutex> exec_mu_;
  uint64_t batcher_start_ns_;
  RequiredEqualInputs required_equal_inputs_;
  bool saturated_;
};
}}
// namespace triton::core
3rdparty/core-r22.12/src/pinned_memory_manager.cc
0 → 100644
View file @
0a21fff9
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "pinned_memory_manager.h"

#include <sstream>
#include <stdexcept>

#include "numa_utils.h"
#include "triton/common/logging.h"

#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#endif  // TRITON_ENABLE_GPU
namespace
triton
{
namespace
core
{
namespace
{
// Render a pointer value as its textual (hex) representation, e.g. for
// log and error messages.
std::string
PointerToString(void* ptr)
{
  std::ostringstream out;
  out << ptr;
  return out.str();
}
// Parse 'arg' as a base-10 integer into '*value'. Returns an INVALID_ARG
// status (prefixed with 'msg') on any parse failure instead of letting a
// std::stoi exception escape.
Status
ParseIntOption(const std::string& msg, const std::string& arg, int* value)
{
  try {
    *value = std::stoi(arg);
  }
  catch (const std::invalid_argument& ia) {
    return Status(
        Status::Code::INVALID_ARG,
        msg + ": Can't parse '" + arg + "' to integer");
  }
  catch (const std::out_of_range& oor) {
    // std::stoi also throws std::out_of_range for values that do not fit
    // in an int; previously this exception escaped the function and could
    // terminate the process.
    return Status(
        Status::Code::INVALID_ARG,
        msg + ": Can't parse '" + arg + "' to integer");
  }
  return Status::Success;
}
}
// namespace
// Definitions of the singleton instance and the byte size of the pool it
// was created with (set once in Create()).
std::unique_ptr<PinnedMemoryManager> PinnedMemoryManager::instance_;
uint64_t PinnedMemoryManager::pinned_memory_byte_size_;
// Wrap a pre-allocated (cudaHostAlloc'ed) buffer with a boost managed
// external buffer so sub-allocations can be carved out of it. A null
// 'pinned_memory_buffer' leaves the managed buffer uninitialized, which
// AllocInternal treats as "no pinned memory pool".
PinnedMemoryManager::PinnedMemory::PinnedMemory(
    void* pinned_memory_buffer, uint64_t size)
    : pinned_memory_buffer_(pinned_memory_buffer)
{
  if (pinned_memory_buffer_ != nullptr) {
    managed_pinned_memory_ = boost::interprocess::managed_external_buffer(
        boost::interprocess::create_only_t{}, pinned_memory_buffer_, size);
  }
}
// Release the underlying CUDA pinned host buffer. In non-GPU builds the
// buffer is never allocated (it stays nullptr), so there is nothing to
// free.
PinnedMemoryManager::PinnedMemory::~PinnedMemory()
{
#ifdef TRITON_ENABLE_GPU
  if (pinned_memory_buffer_ != nullptr) {
    cudaFreeHost(pinned_memory_buffer_);
  }
#endif  // TRITON_ENABLE_GPU
}
// Free any still-tracked non-pinned (malloc) fallback allocations. Pinned
// allocations live inside a PinnedMemory pool whose storage is released
// by that pool's own destructor, so they are intentionally skipped here.
PinnedMemoryManager::~PinnedMemoryManager()
{
  for (const auto& [addr, info] : memory_info_) {
    const bool allocated_pinned = info.first;
    if (!allocated_pinned) {
      free(addr);
    }
  }
}
// Register a pinned memory pool for the given NUMA node mask. If a pool
// was already registered for that mask, it is replaced.
void
PinnedMemoryManager::AddPinnedMemoryBuffer(
    const std::shared_ptr<PinnedMemory>& pinned_memory_buffer,
    unsigned long node_mask)
{
  pinned_memory_buffers_.insert_or_assign(node_mask, pinned_memory_buffer);
}
// Allocate 'size' bytes, preferring the given pinned memory pool and
// optionally falling back to malloc. On success '*ptr' and
// '*allocated_type' are set and the allocation is recorded in
// 'memory_info_' so FreeInternal can route the deallocation correctly.
Status
PinnedMemoryManager::AllocInternal(
    void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
    bool allow_nonpinned_fallback, PinnedMemory* pinned_memory_buffer)
{
  auto status = Status::Success;
  if (pinned_memory_buffer->pinned_memory_buffer_ != nullptr) {
    // Sub-allocate from the pool under its own mutex; nothrow so a full
    // pool yields nullptr instead of an exception.
    std::lock_guard<std::mutex> lk(pinned_memory_buffer->buffer_mtx_);
    *ptr = pinned_memory_buffer->managed_pinned_memory_.allocate(
        size, std::nothrow_t{});
    *allocated_type = TRITONSERVER_MEMORY_CPU_PINNED;
    if (*ptr == nullptr) {
      status = Status(
          Status::Code::INTERNAL, "failed to allocate pinned system memory");
    }
  } else {
    status = Status(
        Status::Code::INTERNAL,
        "failed to allocate pinned system memory: no pinned memory pool");
  }

  bool is_pinned = true;
  if ((!status.IsOk()) && allow_nonpinned_fallback) {
    // Warn only once per process about falling back to regular memory.
    static bool warning_logged = false;
    if (!warning_logged) {
      LOG_WARNING << status.Message()
                  << ", falling back to non-pinned system memory";
      warning_logged = true;
    }
    *ptr = malloc(size);
    *allocated_type = TRITONSERVER_MEMORY_CPU;
    is_pinned = false;
    if (*ptr == nullptr) {
      status = Status(
          Status::Code::INTERNAL,
          "failed to allocate non-pinned system memory");
    } else {
      status = Status::Success;
    }
  }

  // keep track of allocated buffer or clean up
  {
    std::lock_guard<std::mutex> lk(info_mtx_);
    if (status.IsOk()) {
      auto res = memory_info_.emplace(
          *ptr, std::make_pair(is_pinned, pinned_memory_buffer));
      if (!res.second) {
        // NOTE(review): on collision the existing map entry is left in
        // place and the cleanup below deallocates '*ptr' anyway —
        // presumably collisions never happen in practice; confirm.
        status = Status(
            Status::Code::INTERNAL,
            "unexpected memory address collision, '" + PointerToString(*ptr) +
                "' has been managed");
      }
      LOG_VERBOSE(1) << (is_pinned ? "" : "non-")
                     << "pinned memory allocation: "
                     << "size " << size << ", addr " << *ptr;
    }
  }

  // On failure, release whatever was allocated along the way.
  if ((!status.IsOk()) && (*ptr != nullptr)) {
    if (is_pinned) {
      std::lock_guard<std::mutex> lk(pinned_memory_buffer->buffer_mtx_);
      pinned_memory_buffer->managed_pinned_memory_.deallocate(*ptr);
    } else {
      free(*ptr);
    }
  }

  return status;
}
// Free a buffer previously returned by AllocInternal. The bookkeeping
// entry determines whether 'ptr' belongs to a pinned pool or was a
// malloc fallback; unknown addresses produce an INTERNAL error.
Status
PinnedMemoryManager::FreeInternal(void* ptr)
{
  bool is_pinned = true;
  PinnedMemory* pinned_memory_buffer = nullptr;
  {
    // Look up and erase the record under the bookkeeping mutex; the
    // actual deallocation happens after the lock is dropped.
    std::lock_guard<std::mutex> lk(info_mtx_);
    auto it = memory_info_.find(ptr);
    if (it != memory_info_.end()) {
      is_pinned = it->second.first;
      pinned_memory_buffer = it->second.second;
      LOG_VERBOSE(1) << (is_pinned ? "" : "non-")
                     << "pinned memory deallocation: "
                     << "addr " << ptr;
      memory_info_.erase(it);
    } else {
      return Status(
          Status::Code::INTERNAL, "unexpected memory address '" +
                                      PointerToString(ptr) +
                                      "' is not being managed");
    }
  }

  if (is_pinned) {
    std::lock_guard<std::mutex> lk(pinned_memory_buffer->buffer_mtx_);
    pinned_memory_buffer->managed_pinned_memory_.deallocate(ptr);
  } else {
    free(ptr);
  }

  return Status::Success;
}
// Destroy the singleton manager (test-only hook; see header).
void
PinnedMemoryManager::Reset()
{
  instance_ = nullptr;
}
// Create the singleton manager and its pinned memory pool(s). With no
// host policies a single pool is created under node mask 0; otherwise one
// pool is created per NUMA node named by a "numa-node" host policy
// setting. Calling Create again is a no-op that keeps the first pool.
// Fix: the two range-for loops below previously iterated by value,
// copying every map element per iteration; they now bind const references.
Status
PinnedMemoryManager::Create(const Options& options)
{
  if (instance_ != nullptr) {
    LOG_WARNING << "New pinned memory pool of size "
                << options.pinned_memory_pool_byte_size_
                << " could not be created since one already exists"
                << " of size " << pinned_memory_byte_size_;
    return Status::Success;
  }

  instance_.reset(new PinnedMemoryManager());
  if (options.host_policy_map_.empty()) {
    void* buffer = nullptr;
#ifdef TRITON_ENABLE_GPU
    auto err = cudaHostAlloc(
        &buffer, options.pinned_memory_pool_byte_size_, cudaHostAllocPortable);
    if (err != cudaSuccess) {
      buffer = nullptr;
      LOG_WARNING << "Unable to allocate pinned system memory, pinned memory "
                     "pool will not be available: "
                  << std::string(cudaGetErrorString(err));
    } else if (options.pinned_memory_pool_byte_size_ != 0) {
      LOG_INFO << "Pinned memory pool is created at '"
               << PointerToString(buffer) << "' with size "
               << options.pinned_memory_pool_byte_size_;
    } else {
      LOG_INFO << "Pinned memory pool disabled";
    }
#endif  // TRITON_ENABLE_GPU
    try {
      instance_->AddPinnedMemoryBuffer(
          std::shared_ptr<PinnedMemory>(
              new PinnedMemory(buffer, options.pinned_memory_pool_byte_size_)),
          0 /* node_mask */);
    }
    catch (const std::exception& ex) {
      return Status(
          Status::Code::INTERNAL,
          "Failed to add Pinned Memory buffer: " + std::string(ex.what()));
    }
  } else {
    // Create only one buffer / manager should be created for one node,
    // and all associated devices should request memory from the shared
    // manager. Collect NUMA-node -> policy-name mapping first.
    std::map<int32_t, std::string> numa_map;
    for (const auto& host_policy : options.host_policy_map_) {
      const auto numa_it = host_policy.second.find("numa-node");
      if (numa_it != host_policy.second.end()) {
        int32_t numa_id;
        if (ParseIntOption("Parsing NUMA node", numa_it->second, &numa_id)
                .IsOk()) {
          numa_map.emplace(numa_id, host_policy.first);
        }
      }
    }
    for (const auto& node_policy : numa_map) {
      // Bind this thread's memory policy to the node so cudaHostAlloc
      // places the pool on the right NUMA node; always reset afterwards.
      auto status =
          SetNumaMemoryPolicy(options.host_policy_map_.at(node_policy.second));
      if (!status.IsOk()) {
        LOG_WARNING << "Unable to allocate pinned system memory for NUMA node "
                    << node_policy.first << ": " << status.AsString();
        continue;
      }
      unsigned long node_mask;
      status = GetNumaMemoryPolicyNodeMask(&node_mask);
      if (!status.IsOk()) {
        LOG_WARNING << "Unable to get NUMA node set for current thread: "
                    << status.AsString();
        continue;
      }
      void* buffer = nullptr;
#ifdef TRITON_ENABLE_GPU
      auto err = cudaHostAlloc(
          &buffer, options.pinned_memory_pool_byte_size_,
          cudaHostAllocPortable);
      if (err != cudaSuccess) {
        buffer = nullptr;
        LOG_WARNING << "Unable to allocate pinned system memory, pinned memory "
                       "pool will not be available: "
                    << std::string(cudaGetErrorString(err));
      } else if (options.pinned_memory_pool_byte_size_ != 0) {
        LOG_INFO << "Pinned memory pool is created at '"
                 << PointerToString(buffer) << "' with size "
                 << options.pinned_memory_pool_byte_size_;
      } else {
        LOG_INFO << "Pinned memory pool disabled";
      }
#endif  // TRITON_ENABLE_GPU
      ResetNumaMemoryPolicy();
      try {
        instance_->AddPinnedMemoryBuffer(
            std::shared_ptr<PinnedMemory>(new PinnedMemory(
                buffer, options.pinned_memory_pool_byte_size_)),
            node_mask);
      }
      catch (const std::exception& ex) {
        return Status(
            Status::Code::INTERNAL,
            "Failed to add Pinned Memory buffer with host policy: " +
                std::string(ex.what()));
      }
    }
    // If no pinned memory is allocated, add an empty entry where all
    // allocation will be on normal system memory
    if (instance_->pinned_memory_buffers_.empty()) {
      try {
        instance_->AddPinnedMemoryBuffer(
            std::shared_ptr<PinnedMemory>(new PinnedMemory(
                nullptr, options.pinned_memory_pool_byte_size_)),
            0 /* node_mask */);
      }
      catch (const std::exception& ex) {
        return Status(
            Status::Code::INTERNAL,
            "Failed to add empty Pinned Memory entry: " +
                std::string(ex.what()));
      }
    }
  }
  pinned_memory_byte_size_ = options.pinned_memory_pool_byte_size_;
  return Status::Success;
}
// Public allocation entry point. Picks the pinned pool matching the
// calling thread's NUMA memory-policy node mask when multiple pools
// exist; otherwise uses the single (first) pool.
Status
PinnedMemoryManager::Alloc(
    void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
    bool allow_nonpinned_fallback)
{
  if (instance_ == nullptr) {
    return Status(
        Status::Code::UNAVAILABLE, "PinnedMemoryManager has not been created");
  }

  // Default to the first registered pool. Create() always registers at
  // least one entry (possibly an empty one), so begin() is valid.
  auto pinned_memory_buffer =
      instance_->pinned_memory_buffers_.begin()->second.get();
  if (instance_->pinned_memory_buffers_.size() > 1) {
    unsigned long node_mask;
    if (GetNumaMemoryPolicyNodeMask(&node_mask).IsOk()) {
      auto it = instance_->pinned_memory_buffers_.find(node_mask);
      if (it != instance_->pinned_memory_buffers_.end()) {
        pinned_memory_buffer = it->second.get();
      }
    }
  }

  return instance_->AllocInternal(
      ptr, size, allocated_type, allow_nonpinned_fallback,
      pinned_memory_buffer);
}
// Public deallocation entry point; delegates to FreeInternal which
// routes the buffer back to its pinned pool or to free().
Status
PinnedMemoryManager::Free(void* ptr)
{
  if (instance_ == nullptr) {
    return Status(
        Status::Code::UNAVAILABLE, "PinnedMemoryManager has not been created");
  }

  return instance_->FreeInternal(ptr);
}
}}
// namespace triton::core
3rdparty/core-r22.12/src/pinned_memory_manager.h
0 → 100644
View file @
0a21fff9
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <boost/interprocess/managed_external_buffer.hpp>
#include <map>
#include <memory>
#include <mutex>
#include "status.h"
#include "triton/common/model_config.h"
namespace
triton
{
namespace
core
{
// This is a singleton class responsible for maintaining pinned memory pool
// used by the inference server. Pinned memory allocations and deallocations
// must be requested via functions provided by this class.
class PinnedMemoryManager {
 public:
  // Options to configure pinned memory manager.
  struct Options {
    Options(
        uint64_t b = 0,
        const triton::common::HostPolicyCmdlineConfigMap& host_policy_map = {})
        : pinned_memory_pool_byte_size_(b), host_policy_map_(host_policy_map)
    {
    }

    // Total byte size of each pinned memory pool.
    uint64_t pinned_memory_pool_byte_size_;
    // Host policies; entries with a "numa-node" setting produce one pool
    // per NUMA node.
    triton::common::HostPolicyCmdlineConfigMap host_policy_map_;
  };

  ~PinnedMemoryManager();

  // Create the pinned memory manager based on 'options' specified.
  // Return Status object indicating success or failure.
  static Status Create(const Options& options);

  // Allocate pinned memory with the requested 'size' and return the pointer
  // in 'ptr'. If 'allow_nonpinned_fallback' is true, regular system memory
  // will be allocated as fallback in the case where pinned memory fails to
  // be allocated.
  // Return Status object indicating success or failure.
  static Status Alloc(
      void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
      bool allow_nonpinned_fallback);

  // Free the memory allocated by the pinned memory manager.
  // Return Status object indicating success or failure.
  static Status Free(void* ptr);

 protected:
  // Provide explicit control on the lifecycle of the CUDA memory manager,
  // for testing only.
  static void Reset();

 private:
  // One pinned memory pool: a raw cudaHostAlloc'ed buffer plus a boost
  // managed external buffer that sub-allocates from it. A nullptr buffer
  // represents "no pinned memory available".
  class PinnedMemory {
   public:
    PinnedMemory(void* pinned_memory_buffer, uint64_t size);
    ~PinnedMemory();
    void* pinned_memory_buffer_;
    // Guards sub-allocation/deallocation within 'managed_pinned_memory_'.
    std::mutex buffer_mtx_;
    boost::interprocess::managed_external_buffer managed_pinned_memory_;
  };

  PinnedMemoryManager() = default;

  Status AllocInternal(
      void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
      bool allow_nonpinned_fallback, PinnedMemory* pinned_memory_buffer);
  Status FreeInternal(void* ptr);
  void AddPinnedMemoryBuffer(
      const std::shared_ptr<PinnedMemory>& pinned_memory_buffer,
      unsigned long node_mask);

  // Singleton instance and the pool size it was created with.
  static std::unique_ptr<PinnedMemoryManager> instance_;
  static uint64_t pinned_memory_byte_size_;

  // Guards 'memory_info_'.
  std::mutex info_mtx_;
  // Maps each live allocation to (is_pinned, owning pool).
  std::map<void*, std::pair<bool, PinnedMemory*>> memory_info_;
  // Pools keyed by NUMA node mask (0 when NUMA is not configured).
  std::map<unsigned long, std::shared_ptr<PinnedMemory>>
      pinned_memory_buffers_;
};
}}
// namespace triton::core
3rdparty/core-r22.12/src/rate_limiter.cc
0 → 100644
View file @
0a21fff9
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "rate_limiter.h"
#include <limits>
#include "triton/common/logging.h"
namespace
triton
{
namespace
core
{
// Upper bound on pooled Payload objects (bucket + in-use list) kept for
// reuse by GetPayload/PayloadRelease.
constexpr size_t MAX_PAYLOAD_BUCKET_COUNT = 1000;
//=========================================================================
// Core Implementation
//=========================================================================
// Factory for RateLimiter; the constructor is private/controlled so
// callers receive the instance through 'rate_limiter'.
Status
RateLimiter::Create(
    const bool ignore_resources_and_priority,
    const RateLimiter::ResourceMap& resource_map,
    std::unique_ptr<RateLimiter>* rate_limiter)
{
  rate_limiter->reset(
      new RateLimiter(ignore_resources_and_priority, resource_map));
  return Status::Success;
}
// Register a model instance: create its ModelInstanceContext (wired to
// the limiter's stage/release hooks), mark it available, add a per-index
// request queue, and — unless resources are ignored — register it with
// the resource manager. Finally ensure payload queues exist for it.
Status
RateLimiter::RegisterModelInstance(
    TritonModelInstance* triton_model_instance,
    const RateLimiterConfig& rate_limiter_config)
{
  {
    // Both maps are touched; take the locks in the same order as
    // UnregisterModel to avoid deadlock.
    std::lock_guard<std::mutex> lk1(model_ctx_mtx_);
    std::lock_guard<std::mutex> lk2(model_instance_ctx_mtx_);
    // operator[] creates the contexts on first registration.
    auto& model_context = model_contexts_[triton_model_instance->Model()];
    auto& model_instances =
        model_instance_ctxs_[triton_model_instance->Model()];
    model_instances.push_back(std::shared_ptr<ModelInstanceContext>(
        new ModelInstanceContext(
            triton_model_instance, &model_context, rate_limiter_config,
            [this](ModelInstanceContext* instance) { OnStage(instance); },
            [this](ModelInstanceContext* instance) { OnRelease(instance); })));
    model_context.AddAvailableInstance(model_instances.back().get());
    model_context.AddSpecificRequestQueue();
    if (!ignore_resources_and_priority_) {
      resource_manager_->AddModelInstance(model_instances.back().get());
      RETURN_IF_ERROR(resource_manager_->UpdateResourceLimits());
    }
  }
  InitializePayloadQueues(triton_model_instance);

  return Status::Success;
}
// Unregister a model: mark its context for removal, wait for each of its
// instances to quiesce, drop them from the resource manager, then erase
// all bookkeeping (contexts and payload queues).
Status
RateLimiter::UnregisterModel(const TritonModel* model)
{
  {
    std::lock_guard<std::mutex> lk1(model_ctx_mtx_);
    std::lock_guard<std::mutex> lk2(model_instance_ctx_mtx_);
    auto& model_context = model_contexts_[model];
    model_context.RequestRemoval();
    for (const auto& instance : model_instance_ctxs_[model]) {
      // Blocks until the instance is no longer in flight.
      instance->WaitForRemoval();
      if (!ignore_resources_and_priority_) {
        resource_manager_->RemoveModelInstance(instance.get());
      }
    }
    model_instance_ctxs_.erase(model);
    model_contexts_.erase(model);
  }

  if (!ignore_resources_and_priority_) {
    RETURN_IF_ERROR(resource_manager_->UpdateResourceLimits());
  }

  {
    std::lock_guard<std::mutex> lk(payload_queues_mu_);
    if (payload_queues_.find(model) != payload_queues_.end()) {
      payload_queues_.erase(model);
    }
  }

  return Status::Success;
}
// Return true if the model's generic payload queue has room: fewer
// queued payloads than twice the number of registered instances.
// NOTE(review): 'payload_queues_[model]' is read without holding
// 'payload_queues_mu_' — presumably callers guarantee the model is
// registered and not being concurrently unregistered; confirm.
bool
RateLimiter::PayloadSlotAvailable(const TritonModel* model)
{
  bool result;
  PayloadQueue* payload_queue = payload_queues_[model].get();
  {
    std::lock_guard<std::mutex> lk(payload_queue->mu_);
    result = payload_queue->queue_->Size() <
             2 * payload_queue->specific_queues_.size();
  }
  return result;
}
// Enqueue a payload for execution. When resources/priority are ignored
// the payload is scheduled immediately onto the model's payload queue;
// otherwise scheduling is deferred to the rate-limiting machinery via
// DeferPayloadSchedule, which will invoke the lambda once an instance
// has been granted resources.
Status
RateLimiter::EnqueuePayload(
    const TritonModel* model, std::shared_ptr<Payload> payload)
{
  auto pinstance = payload->GetInstance();

  // Sanity check: the model should have been registered already.
  // NOTE(review): leftover debug logging — the condition indicates a
  // missing registration but execution continues regardless.
  if (payload_queues_.find(model) == payload_queues_.end()) {
    LOG_INFO << "Should not print this ";
  }
  PayloadQueue* payload_queue = payload_queues_[model].get();
  {
    std::lock_guard<std::mutex> lk(payload_queue->mu_);
    payload->SetState(Payload::State::REQUESTED);
    if (ignore_resources_and_priority_) {
      SchedulePayload(pinstance, payload_queue, payload);
    }
  }
  if (ignore_resources_and_priority_) {
    // Generic payloads can be served by any waiter; instance-specific
    // ones must wake every waiter so the right one can claim it.
    if (pinstance == nullptr) {
      payload_queue->cv_.notify_one();
    } else {
      payload_queue->cv_.notify_all();
    }
  } else {
    // Deferred path: runs when the rate limiter allocates an instance.
    StandardScheduleFunc sched_func = [this, payload_queue,
                                       payload](ModelInstanceContext* mi) {
      {
        std::lock_guard<std::mutex> lk(payload_queue->mu_);
        this->SchedulePayload(mi->RawInstance(), payload_queue, payload);
      }
      // Return the instance's resources when the payload completes.
      auto cb = [mi]() { mi->Release(); };
      payload->AddInternalReleaseCallback(cb);
      if (mi->RawInstance() == nullptr) {
        payload_queue->cv_.notify_one();
      } else {
        payload_queue->cv_.notify_all();
      }
    };
    DeferPayloadSchedule(sched_func, model, payload->GetInstance());
  }
  return Status::Success;
}
// Block until a payload is available for one of the given instances and
// return it in '*payload'. Checks the generic queue first; if it is
// empty, scans each instance's specific queue ('instance_index' records
// which specific queue, if any, had work). The instance that will run
// the payload is removed from 'instances'.
void
RateLimiter::DequeuePayload(
    std::deque<TritonModelInstance*>& instances,
    std::shared_ptr<Payload>* payload)
{
  payload->reset();
  // NOTE(review): leftover debug logging; the lookup is also done
  // without holding 'payload_queues_mu_' — presumably safe because the
  // model stays registered while its scheduler threads run; confirm.
  if (payload_queues_.find(instances[0]->Model()) == payload_queues_.end()) {
    LOG_INFO << "Should not print this ";
  }
  PayloadQueue* payload_queue = payload_queues_[instances[0]->Model()].get();
  std::vector<std::shared_ptr<Payload>> merged_payloads;
  size_t instance_index = std::numeric_limits<std::size_t>::max();
  {
    std::unique_lock<std::mutex> lk(payload_queue->mu_);
    // Wait until either the generic queue or one of the instances'
    // specific queues is non-empty. On exit, instance_index <
    // instances.size() iff a specific queue had work.
    payload_queue->cv_.wait(
        lk, [&instances, &instance_index, payload_queue]() {
          bool empty = payload_queue->queue_->Empty();
          if (empty) {
            instance_index = 0;
            for (const auto instance : instances) {
              empty = payload_queue->specific_queues_[instance]->Empty();
              if (empty) {
                instance_index++;
              } else {
                break;
              }
            }
          }
          return !empty;
        });
    if (instance_index < instances.size()) {
      TritonModelInstance* instance = instances[instance_index];
      if (!payload_queue->specific_queues_[instance]->Empty()) {
        payload_queue->specific_queues_[instance]->Dequeue(
            payload, &merged_payloads);
      }
    } else {
      payload_queue->queue_->Dequeue(payload, &merged_payloads);
    }
  }
  // Payloads merged (batched) into the dequeued one are returned to the
  // pool here.
  for (auto& merge_payload : merged_payloads) {
    PayloadRelease(merge_payload);
  }
  (*payload)->Callback();
  if ((*payload)->GetInstance() == nullptr) {
    // Generic payload: bind it to the first waiting instance.
    (*payload)->SetInstance(instances.front());
    instances.pop_front();
  } else {
    // Instance-specific payload: remove that instance from the waiters.
    // NOTE(review): this assumes a payload with a bound instance always
    // came from a specific queue (so instance_index is valid); confirm.
    instances.erase(instances.begin() + instance_index);
  }
}
// Obtain a Payload object, reusing one from the free bucket or the
// in-use list when possible to avoid an allocation, and reset it for
// the given operation/instance.
std::shared_ptr<Payload>
RateLimiter::GetPayload(
    const Payload::Operation op_type, TritonModelInstance* instance)
{
  std::shared_ptr<Payload> payload;
  if (max_payload_bucket_count_ > 0) {
    std::lock_guard<std::mutex> lock(payload_mu_);
    if (!payload_bucket_.empty()) {
      payload = payload_bucket_.back();
      payload_bucket_.pop_back();
    }
    if (payload.get() == nullptr && (!payloads_in_use_.empty())) {
      // Just checking the front of the queue instead the entire queue for
      // an available payload to save time.
      if (payloads_in_use_.front().use_count() == 1) {
        payload = payloads_in_use_.front();
        payloads_in_use_.pop_front();
      }
    }
  }
  // Fall back to a fresh allocation when nothing could be reused.
  if (payload.get() == nullptr) {
    payload.reset(new Payload());
  }
  payload->Reset(op_type, instance);
  return payload;
}
// Return a payload to the reuse pool. Uniquely-held payloads go to the
// free bucket (after Release()); payloads still referenced elsewhere are
// parked on the in-use list until their other owners drop them. When the
// pool is full the payload is simply left to be destroyed by its
// shared_ptr.
void
RateLimiter::PayloadRelease(std::shared_ptr<Payload>& payload)
{
  payload->OnRelease();

  if (max_payload_bucket_count_ > 0) {
    std::lock_guard<std::mutex> lock(payload_mu_);
    if (payloads_in_use_.size() + payload_bucket_.size() <
        max_payload_bucket_count_) {
      // Release iff the payload shared_ptr is uniquely held.
      if (payload.use_count() == 1) {
        payload->Release();
        payload_bucket_.push_back(std::move(payload));
        return;
      } else {
        payloads_in_use_.push_back(std::move(payload));
      }
    }
  }
}
// Construct the rate limiter and its resource manager. When
// 'ignore_resources_and_priority' is true, payloads bypass resource
// accounting and are scheduled directly.
RateLimiter::RateLimiter(
    const bool ignore_resources_and_priority, const ResourceMap& resource_map)
    : ignore_resources_and_priority_(ignore_resources_and_priority),
      max_payload_bucket_count_(MAX_PAYLOAD_BUCKET_COUNT)
{
  ResourceManager::Create(resource_map, &resource_manager_);
}
// Ensure the model-level payload queue and the per-instance specific
// queue exist for 'instance'. The queue delay comes from the model's
// batching config (oldest-sequence or dynamic batcher), defaulting to 0.
void
RateLimiter::InitializePayloadQueues(const TritonModelInstance* instance)
{
  auto& config = instance->Model()->Config();
  uint64_t max_queue_delay_microseconds;
  if (config.has_sequence_batching()) {
    const auto& batcher_config = config.sequence_batching();
    if (batcher_config.has_oldest()) {
      max_queue_delay_microseconds =
          batcher_config.oldest().max_queue_delay_microseconds();
    } else {
      max_queue_delay_microseconds = 0;
    }
  } else if (config.has_dynamic_batching()) {
    max_queue_delay_microseconds =
        config.dynamic_batching().max_queue_delay_microseconds();
  } else {
    max_queue_delay_microseconds = 0;
  }
  {
    std::lock_guard<std::mutex> lk(payload_queues_mu_);
    if (payload_queues_.find(instance->Model()) == payload_queues_.end()) {
      // Delay is converted from microseconds to nanoseconds (* 1000).
      payload_queues_.emplace(
          instance->Model(), new PayloadQueue(
                                 config.max_batch_size(),
                                 max_queue_delay_microseconds * 1000));
    }
  }
  PayloadQueue* payload_queue = payload_queues_[instance->Model()].get();
  if (payload_queue->specific_queues_.find(instance) ==
      payload_queue->specific_queues_.end()) {
    payload_queue->specific_queues_.emplace(
        instance, new InstanceQueue(
                      config.max_batch_size(),
                      max_queue_delay_microseconds * 1000));
  }
}
// Queue 'OnSchedule' to run once a (matching) model instance becomes
// available, then try to stage an instance immediately. Fails if the
// model is unknown or is being removed.
Status
RateLimiter::DeferPayloadSchedule(
    const StandardScheduleFunc& OnSchedule, const TritonModel* model,
    TritonModelInstance* triton_model_instance)
{
  std::lock_guard<std::mutex> lk(model_ctx_mtx_);
  auto itr = model_contexts_.find(model);
  if (itr == model_contexts_.end()) {
    return Status(
        Status::Code::INTERNAL,
        "Requested model is not yet registered with rate limiter");
  }
  if (itr->second.isRemovalInProgress()) {
    return Status(
        Status::Code::INTERNAL,
        "New model requests can not be made to a model that is being "
        "removed");
  }
  itr->second.EnqueueModelInstanceRequest(OnSchedule, triton_model_instance);
  itr->second.StageInstanceIfAvailable(triton_model_instance);

  return Status::Success;
}
// Place 'payload' on the appropriate queue — the instance-specific queue
// when an instance is bound, the generic queue otherwise — and mark it
// SCHEDULED. Caller must hold payload_queue->mu_.
void
RateLimiter::SchedulePayload(
    TritonModelInstance* tmi, PayloadQueue* payload_queue,
    const std::shared_ptr<Payload>& payload)
{
  if (tmi != nullptr) {
    payload_queue->specific_queues_[tmi]->Enqueue(payload);
  } else {
    payload_queue->queue_->Enqueue(payload);
  }
  payload->SetState(Payload::State::SCHEDULED);
}
// Hook invoked when a model instance has been staged: record it in the
// staging priority queue, then try to grant resources to the highest
// priority staged instance.
void
RateLimiter::OnStage(ModelInstanceContext* instance)
{
  {
    std::lock_guard<std::recursive_mutex> lk(staged_instances_mtx_);
    staged_instances_.push(instance);
  }
  AttemptAllocation();
}
// Hook invoked when a model instance finishes execution: return it to
// the available set, give back its resources, re-stage it if work is
// pending for its model, and retry allocation for staged instances.
// NOTE(review): 'model_contexts_' is accessed without 'model_ctx_mtx_' —
// presumably safe because UnregisterModel waits for instance removal
// before erasing; confirm.
void
RateLimiter::OnRelease(ModelInstanceContext* instance)
{
  auto& model_context = model_contexts_[instance->RawInstance()->Model()];
  model_context.AddAvailableInstance(instance);
  resource_manager_->ReleaseResources(instance);
  if (model_context.ContainsPendingRequests(
          instance->RawInstance()->Index())) {
    model_context.StageInstanceIfAvailable(instance->RawInstance());
  }
  AttemptAllocation();
}
// Try to grant resources to the highest-priority staged instance; on
// success it is removed from staging and allocated (which triggers its
// scheduled work). Only the top of the priority queue is considered.
void
RateLimiter::AttemptAllocation()
{
  std::lock_guard<std::recursive_mutex> lk(staged_instances_mtx_);
  if (!staged_instances_.empty()) {
    ModelInstanceContext* instance = staged_instances_.top();
    if (resource_manager_->AllocateResources(instance)) {
      staged_instances_.pop();
      instance->Allocate();
    }
  }
}
//=========================================================================
// ModelContext Implementation
//=========================================================================
// A model starts out active; removal is requested explicitly via
// RequestRemoval().
RateLimiter::ModelContext::ModelContext() : removal_in_progress_(false) {}
// Queue a schedule callback: onto the generic queue when no instance is
// requested, otherwise onto the queue for that instance's index. An
// index outside the registered range is an INTERNAL error.
Status
RateLimiter::ModelContext::EnqueueModelInstanceRequest(
    const StandardScheduleFunc& OnSchedule,
    TritonModelInstance* triton_model_instance)
{
  std::lock_guard<std::recursive_mutex> lk(sched_request_queue_mtx_);
  if (triton_model_instance == nullptr) {
    generic_sched_request_queue_.push(OnSchedule);
  } else if (
      (uint32_t)triton_model_instance->Index() <
      specific_sched_request_queues_.size()) {
    specific_sched_request_queues_[triton_model_instance->Index()].push(
        OnSchedule);
  } else {
    return Status(
        Status::Code::INTERNAL,
        "expected instance index between 0 and " +
            std::to_string(specific_sched_request_queues_.size()) + ", got " +
            std::to_string(triton_model_instance->Index()));
  }
  return Status::Success;
}
// Put the instance back into the availability priority queue and flip
// its state to AVAILABLE.
void
RateLimiter::ModelContext::AddAvailableInstance(ModelInstanceContext* instance)
{
  std::lock_guard<std::recursive_mutex> lk(avbl_instances_mtx_);
  avbl_instances_.push(instance);
  instance->MarkAvailable();
}
// Match available instances against pending schedule requests and stage
// them for resource allocation. When 'req_instance' is non-null, only
// that specific instance may be staged; every other available instance
// is preserved via the backup queue.
void
RateLimiter::ModelContext::StageInstanceIfAvailable(
    TritonModelInstance* req_instance)
{
  std::lock_guard<std::recursive_mutex> lk1(sched_request_queue_mtx_);
  std::lock_guard<std::recursive_mutex> lk2(avbl_instances_mtx_);
  PriorityQueue backup_queue;
  while (!avbl_instances_.empty()) {
    ModelInstanceContext* instance = avbl_instances_.top();
    // Skip (but keep) instances that don't match the requested one.
    if ((req_instance != nullptr) &&
        (instance->RawInstance() != req_instance)) {
      backup_queue.push(instance);
      avbl_instances_.pop();
      continue;
    }
    if (!specific_sched_request_queues_[instance->RawInstance()->Index()]
             .empty()) {
      // Prioritize the specific requests for the available model
      // instance highest priority.
      const StandardScheduleFunc func =
          specific_sched_request_queues_[instance->RawInstance()->Index()]
              .front();
      specific_sched_request_queues_[instance->RawInstance()->Index()].pop();
      instance->Stage(func);
    } else if (!generic_sched_request_queue_.empty()) {
      // If request is for generic model instance then use the
      // instance with the highest priority.
      const StandardScheduleFunc func = generic_sched_request_queue_.front();
      generic_sched_request_queue_.pop();
      instance->Stage(func);
    } else {
      // If there are requests for a specific model instance then backup
      // the model instance and keep searching through the available
      // model instances. The prioritization will be taken care of in the
      // staging priority queue.
      backup_queue.push(instance);
    }
    avbl_instances_.pop();
  }
  // Restore the backup queue
  if (!backup_queue.empty()) {
    avbl_instances_.swap(backup_queue);
  }
}
// Like StageInstanceIfAvailable but bypasses the staging step: matched
// instances are allocated directly (used when resource arbitration is
// not required). Unmatched instances are preserved via the backup queue.
void
RateLimiter::ModelContext::AllocateInstanceIfAvailable()
{
  std::lock_guard<std::recursive_mutex> lk1(sched_request_queue_mtx_);
  std::lock_guard<std::recursive_mutex> lk2(avbl_instances_mtx_);
  PriorityQueue backup_queue;
  while (!avbl_instances_.empty()) {
    ModelInstanceContext* instance = avbl_instances_.top();
    if (!specific_sched_request_queues_[instance->RawInstance()->Index()]
             .empty()) {
      // Prioritize the specific requests for the available model
      // instance highest priority.
      const StandardScheduleFunc func =
          specific_sched_request_queues_[instance->RawInstance()->Index()]
              .front();
      specific_sched_request_queues_[instance->RawInstance()->Index()].pop();
      instance->DirectAllocate(func);
    } else if (!generic_sched_request_queue_.empty()) {
      // If request is for generic model instance then use the
      // instance with the highest priority.
      const StandardScheduleFunc func = generic_sched_request_queue_.front();
      generic_sched_request_queue_.pop();
      instance->DirectAllocate(func);
    } else {
      // If there are requests for a specific model instance then backup
      // the model instance and keep searching through the available
      // model instances. The prioritization will be taken care of in the
      // staging priority queue.
      backup_queue.push(instance);
    }
    avbl_instances_.pop();
  }
  // Restore the backup queue
  if (!backup_queue.empty()) {
    avbl_instances_.swap(backup_queue);
  }
}
void
RateLimiter
::
ModelContext
::
AddSpecificRequestQueue
()
{
std
::
lock_guard
<
std
::
recursive_mutex
>
lk
(
sched_request_queue_mtx_
);
specific_sched_request_queues_
.
emplace_back
();
}
// Returns true when scheduling work remains that the instance at `index`
// could service: either a generic request or one targeted at that instance.
bool
RateLimiter::ModelContext::ContainsPendingRequests(int index)
{
  std::lock_guard<std::recursive_mutex> queue_lock(sched_request_queue_mtx_);
  if (!generic_sched_request_queue_.empty()) {
    return true;
  }
  return !specific_sched_request_queues_[index].empty();
}
// Marks the whole model context for removal; instances consult this flag
// (via isRemovalInProgress) when deciding to transition to REMOVED.
// NOTE(review): removal_in_progress_ is a plain bool written here without a
// lock while other threads read it through isRemovalInProgress() — formally a
// data race; std::atomic<bool> would be safer. TODO confirm intent.
void
RateLimiter::ModelContext::RequestRemoval()
{
  removal_in_progress_ = true;
}
//=========================================================================
// ModelInstanceContext Implementation
//=========================================================================
// Constructs the rate-limiter-side state for one TritonModelInstance.
// \param triton_model_instance The instance being tracked (index cached).
// \param model_context The owning per-model context.
// \param rate_limiter_config Copied priority/resource configuration.
// \param OnStage Callback invoked when the instance enters STAGED.
// \param OnRelease Callback invoked when an execution completes.
RateLimiter::ModelInstanceContext::ModelInstanceContext(
    TritonModelInstance* triton_model_instance,
    RateLimiter::ModelContext* model_context,
    const RateLimiter::RateLimiterConfig& rate_limiter_config,
    RateLimiter::StandardStageFunc OnStage,
    RateLimiter::StandardReleaseFunc OnRelease)
    : triton_model_instance_(triton_model_instance),
      index_(triton_model_instance->Index()), model_context_(model_context),
      rate_limiter_config_(rate_limiter_config), OnStage_(OnStage),
      OnRelease_(OnRelease), exec_count_(0), state_(AVAILABLE),
      // Fix: removal_in_progress_ was previously left uninitialized,
      // yielding an indeterminate value until someone wrote to it.
      removal_in_progress_(false)
{
}
// Returns the instance to the AVAILABLE state so it may be staged or
// allocated again.
void
RateLimiter::ModelInstanceContext::MarkAvailable()
{
  const std::lock_guard<std::mutex> state_lock(state_mtx_);
  state_ = AVAILABLE;
}
// Transitions AVAILABLE -> STAGED and records the schedule callback to run
// once resources are granted (see Allocate()).
// \param OnSchedule Callback stored for later invocation by Allocate().
// \return INTERNAL error if the instance is not currently AVAILABLE.
Status
RateLimiter::ModelInstanceContext::Stage(StandardScheduleFunc OnSchedule)
{
  {
    std::lock_guard<std::mutex> lk(state_mtx_);
    if (state_ != AVAILABLE) {
      return Status(
          Status::Code::INTERNAL,
          "Can not stage a model instance that is not yet available");
    }
    state_ = STAGED;
    OnSchedule_ = OnSchedule;
  }
  // Invoke the stage callback outside the state lock to avoid re-entrant
  // deadlock if the callback touches this instance again.
  OnStage_(this);
  return Status::Success;
}
// Transitions STAGED -> ALLOCATED and fires the schedule callback captured
// by Stage(). Called once the resource manager has granted resources.
// \return INTERNAL error if the instance is not currently STAGED.
Status
RateLimiter::ModelInstanceContext::Allocate()
{
  {
    std::lock_guard<std::mutex> lk(state_mtx_);
    if (state_ != STAGED) {
      return Status(
          Status::Code::INTERNAL,
          "Can not allocate a model instance that is not yet staged");
    }
    state_ = ALLOCATED;
  }
  // Callback runs outside the lock; OnSchedule_ was set while STAGED.
  OnSchedule_(this);
  return Status::Success;
}
// Transitions AVAILABLE -> ALLOCATED directly, skipping STAGED. Used when
// resource constraints and priority are ignored.
// \param OnSchedule Callback invoked immediately after the transition.
// \return INTERNAL error if the instance is not currently AVAILABLE.
Status
RateLimiter::ModelInstanceContext::DirectAllocate(StandardScheduleFunc OnSchedule)
{
  {
    std::lock_guard<std::mutex> lk(state_mtx_);
    if (state_ != AVAILABLE) {
      return Status(
          Status::Code::INTERNAL,
          "Can not allocate a model instance that is not yet available");
    }
    state_ = ALLOCATED;
  }
  // Note: invokes the parameter, not the stored OnSchedule_ member.
  OnSchedule(this);
  return Status::Success;
}
// Called when an execution on this instance finishes: bumps the execution
// count (which lowers the instance's future scheduling preference), runs the
// release callback, and — if the model is being torn down and nothing is
// pending — transitions to REMOVED and wakes WaitForRemoval() waiters.
void
RateLimiter::ModelInstanceContext::Release()
{
  exec_count_++;
  OnRelease_(this);
  // Fix: capture the REMOVED decision while still holding state_mtx_. The
  // original re-read state_ after releasing the lock, a data race with
  // RequestRemoval()/MarkAvailable() on other threads.
  bool removed = false;
  {
    std::lock_guard<std::mutex> lk(state_mtx_);
    if ((model_context_->isRemovalInProgress()) && (state_ == AVAILABLE) &&
        (!model_context_->ContainsPendingRequests(index_))) {
      state_ = REMOVED;
    }
    removed = (state_ == REMOVED);
  }
  if (removed) {
    cv_.notify_all();
  }
}
// Transitions this instance to REMOVED, but only when it is idle (AVAILABLE)
// and no queued request in the model context could still run on it.
void
RateLimiter::ModelInstanceContext::RequestRemoval()
{
  const std::lock_guard<std::mutex> state_lock(state_mtx_);
  if (state_ != AVAILABLE) {
    return;
  }
  if (model_context_->ContainsPendingRequests(index_)) {
    return;
  }
  state_ = REMOVED;
}
// Blocks until this instance reaches the REMOVED state. Ensures the model
// context is marked for removal first so Release() on other threads can
// complete the transition if the instance is currently busy.
void
RateLimiter::ModelInstanceContext::WaitForRemoval()
{
  // NOTE(review): the check-then-set on the model context is not atomic;
  // harmless here since RequestRemoval() is idempotent.
  if (!model_context_->isRemovalInProgress()) {
    model_context_->RequestRemoval();
  }
  // Try the transition immediately in case the instance is already idle.
  RequestRemoval();
  // Wait for the instance to be removed; cv_ is notified by Release().
  {
    std::unique_lock<std::mutex> lk(state_mtx_);
    cv_.wait(lk, [this] { return state_ == REMOVED; });
  }
}
// TODO: Different schemes for the prioritization of
// model instance can be added here.
// Computes the ordering key used by ScaledPriorityComparator: execution
// count weighted by the configured priority, so heavily-used (or
// high-priority-number) instances sort as less preferred.
double
RateLimiter::ModelInstanceContext::ScaledPriority()
{
  // A configured priority of 0 is still treated as 1 so every instance ages.
  const auto effective_priority = std::max(rate_limiter_config_.priority(), 1u);
  return (exec_count_ * effective_priority);
}
//=========================================================================
// ResourceManager Implementation
//=========================================================================
// Factory for ResourceManager (its constructor is private).
// \param resource_map Explicit per-device resource limits, may be empty.
// \param resource_manager Receives the newly built manager.
// \return Status::Success always.
Status
RateLimiter::ResourceManager::Create(
    const ResourceMap& resource_map,
    std::unique_ptr<ResourceManager>* resource_manager)
{
  resource_manager->reset(new ResourceManager(resource_map));
  return Status::Success;
}
// Records the resource demands of a model instance, keyed by
// GLOBAL_RESOURCE_KEY for global resources or by device id otherwise.
// Re-adding an existing instance keeps the original entry (emplace no-op)
// but still overwrites the per-resource counts through the returned iterator.
void
RateLimiter::ResourceManager::AddModelInstance(
    const ModelInstanceContext* instance)
{
  std::lock_guard<std::mutex> lk(model_resources_mtx_);
  auto pr = model_resources_.emplace(std::make_pair(instance, ResourceMap()));
  for (const auto& resource : instance->GetRateLimiterConfig()->resources()) {
    if (resource.global()) {
      // Global resources are pooled under the reserved pseudo-device key.
      (pr.first->second[GLOBAL_RESOURCE_KEY])[resource.name()] =
          resource.count();
    } else {
      // Device-local resources are charged against the instance's device.
      (pr.first->second[instance->RawInstance()->DeviceId()])[resource.name()] =
          resource.count();
    }
  }
}
// Forgets the resource demands recorded for `instance`.
// \param instance The instance whose entry should be dropped.
// \return INTERNAL error if the instance was never added.
Status
RateLimiter::ResourceManager::RemoveModelInstance(
    const ModelInstanceContext* instance)
{
  std::lock_guard<std::mutex> lk(model_resources_mtx_);
  const auto itr = model_resources_.find(instance);
  if (itr == model_resources_.end()) {
    return Status(
        Status::Code::INTERNAL, "Can not find the instance to remove");
  }
  // Erase via the iterator we already have instead of re-searching by key
  // (the original performed a second O(log n) lookup).
  model_resources_.erase(itr);
  return Status::Success;
}
// Recomputes max_resources_: for every device key, take the elementwise
// maximum of each resource demand across all registered instances. When
// explicit limits were provided, they are validated against (and replace)
// these derived maxima. Finally sanity-checks the result and optionally logs
// the full resource map at verbose level 1.
Status
RateLimiter::ResourceManager::UpdateResourceLimits()
{
  std::lock_guard<std::mutex> lk1(max_resources_mtx_);
  std::lock_guard<std::mutex> lk2(model_resources_mtx_);
  max_resources_.clear();
  // Obtain the maximum resource across all the instances
  // and use it as the default available.
  for (const auto& instance_resources : model_resources_) {
    for (const auto& resource_device_map : instance_resources.second) {
      auto ditr = max_resources_.find(resource_device_map.first);
      if (ditr == max_resources_.end()) {
        // First instance seen for this device: take its map wholesale.
        ditr = max_resources_
                   .emplace(
                       resource_device_map.first, resource_device_map.second)
                   .first;
      } else {
        // Merge: keep the larger count per resource name.
        // NOTE(review): `const auto resource` copies each pair; `const auto&`
        // would avoid the copy without changing behavior.
        for (const auto resource : resource_device_map.second) {
          auto ritr = ditr->second.find(resource.first);
          if (ritr == ditr->second.end()) {
            ritr = ditr->second.emplace(resource.first, resource.second).first;
          } else {
            if (ritr->second < resource.second) {
              ritr->second = resource.second;
            }
          }
        }
      }
    }
  }
  if (!explicit_max_resources_.empty()) {
    RETURN_IF_ERROR(ParseAndValidateExplicitResources());
  }
  RETURN_IF_ERROR(ValidateMaxResources());
  if (LOG_VERBOSE_IS_ON(1)) {
    std::string resource_map_str{"\nMax Resource Map===>\n"};
    for (const auto& ditr : max_resources_) {
      if (!ditr.second.empty()) {
        std::string device_str{
            (ditr.first == GLOBAL_RESOURCE_KEY) ? "GLOBAL"
                                                : std::to_string(ditr.first)};
        resource_map_str += "\tDevice: " + device_str + "\n";
        for (const auto& ritr : ditr.second) {
          resource_map_str += "\t\tResource: " + ritr.first +
                              "\tCount: " + std::to_string(ritr.second) + "\n";
        }
      }
    }
    LOG_VERBOSE(1) << resource_map_str;
  }
  return Status::Success;
}
// Rejects configurations in which the same resource name appears both as a
// global resource and as a device-specific resource.
// NOTE(review): callers are expected to hold max_resources_mtx_ (as
// UpdateResourceLimits does); also `max_resources_[GLOBAL_RESOURCE_KEY]`
// default-inserts an empty map if no global resources exist — harmless but
// a side effect worth knowing about.
Status
RateLimiter::ResourceManager::ValidateMaxResources()
{
  for (const auto& global_resource : max_resources_[GLOBAL_RESOURCE_KEY]) {
    for (const auto& ditr : max_resources_) {
      if (ditr.first != GLOBAL_RESOURCE_KEY) {
        for (const auto& ritr : ditr.second) {
          if (global_resource.first.compare(ritr.first) == 0) {
            return Status(
                Status::Code::INVALID_ARG,
                (std::string("Resource \"") + ritr.first +
                 "\" is present as both global and device-specific resource in "
                 "the model configuration.")
                    .c_str());
          }
        }
      }
    }
  }
  return Status::Success;
}
// Overwrites each derived maximum in max_resources_ with the explicitly
// configured limit, after checking the limit is at least the largest demand
// of any single instance (otherwise that instance could never be scheduled).
// For the global pseudo-device, the limit is the max across all explicit
// entries of that resource name; for real devices both the device-specific
// and PER_DEVICE_RESOURCE_KEY entries are consulted.
Status
RateLimiter::ResourceManager::ParseAndValidateExplicitResources()
{
  for (auto& ditr : max_resources_) {
    for (auto& ritr : ditr.second) {
      // If not specified explicitly, consider the resource to be unavailable.
      size_t resource_count = 0;
      if (ditr.first == GLOBAL_RESOURCE_KEY) {
        // Ignore the device specification... will search for all resources in
        // the map...
        for (const auto& exp_ditr : explicit_max_resources_) {
          for (const auto& exp_ritr : exp_ditr.second) {
            if (ritr.first.compare(exp_ritr.first) == 0) {
              if (resource_count < exp_ritr.second) {
                resource_count = exp_ritr.second;
              }
            }
          }
        }
      } else {
        // Search only for the device specific or per-device resources...
        // device-specific
        for (const auto& exp_ritr : explicit_max_resources_[ditr.first]) {
          if (ritr.first.compare(exp_ritr.first) == 0) {
            if (resource_count < exp_ritr.second) {
              resource_count = exp_ritr.second;
            }
          }
        }
        // per-device
        for (const auto& exp_ritr :
             explicit_max_resources_[PER_DEVICE_RESOURCE_KEY]) {
          if (ritr.first.compare(exp_ritr.first) == 0) {
            if (resource_count < exp_ritr.second) {
              resource_count = exp_ritr.second;
            }
          }
        }
      }
      if (resource_count < ritr.second) {
        // The explicit limit is below what some instance requires.
        return Status(
            Status::Code::INVALID_ARG,
            (std::string("Resource count for \"") + ritr.first +
             "\" is limited to " + std::to_string(resource_count) +
             " which will prevent scheduling of one or more model "
             "instances, the minimum required count is " +
             std::to_string(ritr.second))
                .c_str());
      } else {
        ritr.second = resource_count;
      }
    }
  }
  return Status::Success;
}
// Attempts to reserve the resources demanded by `instance`.
// Two passes: the first verifies every demand fits under max_resources_
// (default-inserting zero-count bookkeeping entries as it goes — those
// zero entries are harmless if the check then fails); the second commits
// the allocation. Returns false when the instance is unknown or when any
// single resource would be over-committed.
bool
RateLimiter::ResourceManager::AllocateResources(
    const ModelInstanceContext* instance)
{
  std::lock_guard<std::mutex> lk1(model_resources_mtx_);
  std::lock_guard<std::mutex> lk2(allocated_resources_mtx_);
  const auto& itr = model_resources_.find(instance);
  if (itr == model_resources_.end()) {
    return false;
  } else {
    // First pass to verify if resources are available
    {
      std::lock_guard<std::mutex> lk3(max_resources_mtx_);
      for (const auto& ditr : itr->second) {
        auto allocated_ditr = allocated_resources_.find(ditr.first);
        if (allocated_ditr == allocated_resources_.end()) {
          allocated_ditr =
              allocated_resources_
                  .emplace(ditr.first, std::map<std::string, size_t>())
                  .first;
        }
        for (const auto& ritr : ditr.second) {
          auto allocated_ritr = allocated_ditr->second.find(ritr.first);
          if (allocated_ritr == allocated_ditr->second.end()) {
            allocated_ritr =
                allocated_ditr->second.emplace(ritr.first, 0).first;
          }
          // Would granting this demand exceed the configured maximum?
          if ((allocated_ritr->second + ritr.second) >
              (max_resources_[ditr.first])[ritr.first]) {
            return false;
          }
        }
      }
    }
    // Second pass to actually allocate the resources
    for (const auto& ditr : itr->second) {
      for (const auto& ritr : ditr.second) {
        (allocated_resources_[ditr.first])[ritr.first] += ritr.second;
      }
    }
  }
  return true;
}
// Returns the resources previously reserved for `instance` by
// AllocateResources.
// \param instance The instance whose reservation is being released.
// \return INTERNAL error if the instance has no recorded demands.
Status
RateLimiter::ResourceManager::ReleaseResources(
    const ModelInstanceContext* instance)
{
  std::lock_guard<std::mutex> lk1(model_resources_mtx_);
  std::lock_guard<std::mutex> lk2(allocated_resources_mtx_);
  const auto itr = model_resources_.find(instance);
  if (itr == model_resources_.end()) {
    // Fix: original message read "Unable find ..." (missing "to").
    return Status(
        Status::Code::INTERNAL,
        "Unable to find the instance resources to release");
  }
  for (const auto& ditr : itr->second) {
    // Hoist the per-device map lookup out of the inner loop.
    auto& allocated_device_map = allocated_resources_[ditr.first];
    for (const auto& ritr : ditr.second) {
      allocated_device_map[ritr.first] -= ritr.second;
    }
  }
  return Status::Success;
}
// Stores the explicitly configured resource limits; derived limits are
// computed later by UpdateResourceLimits().
RateLimiter::ResourceManager::ResourceManager(const ResourceMap& resource_map)
    : explicit_max_resources_(resource_map)
{
}
}}
// namespace triton::core
3rdparty/core-r22.12/src/rate_limiter.h
0 → 100644
View file @
0a21fff9
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <vector>
#include "backend_model.h"
#include "backend_model_instance.h"
#include "instance_queue.h"
#include "model_config.pb.h"
#include "payload.h"
#include "status.h"
namespace
triton
{
namespace
core
{
// Limits the rate at which requests are dispatched to the model instances.
class RateLimiter {
 public:
  using RateLimiterConfig = inference::ModelRateLimiter;
  // Resource counts keyed first by device id (or a RESOURCE_KIND_KEY below),
  // then by resource name.
  using ResourceMap = std::map<int, std::map<std::string, size_t>>;
  // Reserved pseudo-device keys used inside ResourceMap.
  enum RESOURCE_KIND_KEY {
    // Key for holding global resources
    GLOBAL_RESOURCE_KEY = -2,
    // Key for holding resources per each device
    PER_DEVICE_RESOURCE_KEY = -1
  };

  /// Creates a rate limiter object which will funnel the requests to
  /// the model instances. A typical lifetime of the model instance within
  /// RateLimiter transition from available -> staged -> allocated -> available.
  /// The transition from available to staged occurs when a request is
  /// registered for the model. Depending upon the resource availability and
  /// priority, the RateLimiter will transition an instance to allocated state
  /// at some point in the future. The staged state is skipped when
  /// configured to ignore the resource constraints. The cycle in this case
  /// will be available -> allocated -> available.
  /// \param ignore_resources_and_priority Whether or not to ignore resource
  /// constraints and cross-model priority. An available instance is directly
  /// allocated when true.
  /// \param resource_map The map to the available resource count provided
  /// explicitly.
  /// \return Status object indicating success or failure.
  static Status Create(
      const bool ignore_resources_and_priority, const ResourceMap& resource_map,
      std::unique_ptr<RateLimiter>* rate_limiter);

  /// Registers the model instance with the rate limiter.
  /// \param instance The pointer to the TritonModelInstance object to register
  /// with the rate limiter.
  /// \param rate_limiter_config The rate limiter configuration associated with
  /// the model instance.
  /// \return Status object indicating success or failure.
  Status RegisterModelInstance(
      TritonModelInstance* instance,
      const RateLimiterConfig& rate_limiter_config);

  /// Remove model from the set of models being managed by the rate limiter.
  /// \param model The pointer to TritonModel object to be removed.
  /// \return Status object indicating success or failure.
  Status UnregisterModel(const TritonModel* model);

  /// Returns true if there is a payload slot available for the given model.
  /// \param model The pointer to the TritonModel object to query.
  /// \return slot availability in boolean.
  bool PayloadSlotAvailable(const TritonModel* model);

  /// Enqueues the payload to rate limiter for scheduling on the given model.
  /// \param model The pointer to the TritonModel object to schedule on.
  /// \param payload The shared pointer to the payload object.
  /// \return Status object indicating success or failure.
  Status EnqueuePayload(
      const TritonModel* model, std::shared_ptr<Payload> payload);

  /// Returns the payload that has been scheduled for the given set of model
  /// instances. Note that this call is blocking and depends upon the
  /// availability of payloads in the rate limiter for the triton model
  /// instance.
  /// \param instance The pointers to TritonModelInstance objects whose
  /// payload is being requested.
  /// \param payload The shared pointer to the payload object.
  void DequeuePayload(
      std::deque<TritonModelInstance*>& instance,
      std::shared_ptr<Payload>* payload);

  /// Returns a new payload object.
  /// \param op_type The operation type for the payload.
  /// \param instance Optional field that provides the model instance that must
  /// be used for the execution of the payload. Default is nullptr which allows
  /// any model instance to execute the payload.
  /// \return The shared pointer to a new payload object.
  std::shared_ptr<Payload> GetPayload(
      const Payload::Operation op_type,
      TritonModelInstance* instance = nullptr);

  /// Releases the given payload object back to the rate limiter.
  /// \param payload The payload to release.
  void PayloadRelease(std::shared_ptr<Payload>& payload);

 private:
  class ModelInstanceContext;
  class ModelContext;
  struct PayloadQueue;

  // Callback types invoked on instance release / schedule / stage events.
  using StandardReleaseFunc = std::function<void(ModelInstanceContext*)>;
  using StandardScheduleFunc = std::function<void(ModelInstanceContext*)>;
  using StandardStageFunc = std::function<void(ModelInstanceContext*)>;

  // Holds the state of the model instance.
  class ModelInstanceContext {
   public:
    friend class RateLimiter;
    friend class ResourceManager;
    // Lifecycle states; see Create() documentation for the transitions.
    enum State { AVAILABLE, STAGED, ALLOCATED, REMOVED };
    void Release();
    TritonModelInstance* RawInstance() const { return triton_model_instance_; }

   private:
    ModelInstanceContext(
        TritonModelInstance* triton_model_instance,
        ModelContext* model_context,
        const RateLimiterConfig& rate_limiter_config, StandardStageFunc OnStage,
        StandardReleaseFunc OnRelease);
    const RateLimiterConfig* GetRateLimiterConfig() const
    {
      return &rate_limiter_config_;
    }
    void MarkAvailable();
    double ScaledPriority();
    Status Stage(StandardScheduleFunc OnSchedule);
    Status Allocate();
    Status DirectAllocate(StandardScheduleFunc OnSchedule);
    void RequestRemoval();
    void WaitForRemoval();

    TritonModelInstance* triton_model_instance_;
    // Cached TritonModelInstance::Index(); indexes the per-instance
    // request queues in ModelContext.
    size_t index_;
    ModelContext* model_context_;
    RateLimiterConfig rate_limiter_config_;
    StandardStageFunc OnStage_;
    StandardReleaseFunc OnRelease_;
    // Number of completed executions; feeds ScaledPriority().
    std::atomic<uint64_t> exec_count_;
    // Guarded by state_mtx_.
    State state_;
    bool removal_in_progress_;
    std::mutex state_mtx_;
    // Set while STAGED; invoked by Allocate().
    StandardScheduleFunc OnSchedule_;
    // Signalled when state_ becomes REMOVED (see WaitForRemoval()).
    std::condition_variable cv_;
  };

  // Orders instances so the smallest ScaledPriority() is on top of a
  // std::priority_queue (min-heap behavior via operator>).
  class ScaledPriorityComparator {
   public:
    bool operator()(ModelInstanceContext* a, ModelInstanceContext* b)
    {
      return a->ScaledPriority() > b->ScaledPriority();
    }
  };

  using PriorityQueue = std::priority_queue<
      ModelInstanceContext*, std::vector<ModelInstanceContext*>,
      ScaledPriorityComparator>;

  // Holds the active context to a model
  class ModelContext {
   public:
    ModelContext();
    Status EnqueueModelInstanceRequest(
        const StandardScheduleFunc& OnSchedule,
        TritonModelInstance* triton_model_instance);
    void AddAvailableInstance(ModelInstanceContext* instance);
    void StageInstanceIfAvailable(TritonModelInstance* triton_model_instance);
    void AllocateInstanceIfAvailable();
    void AddSpecificRequestQueue();
    bool ContainsPendingRequests(int32_t index);
    void RequestRemoval();
    bool isRemovalInProgress() { return removal_in_progress_; }

   private:
    bool removal_in_progress_;
    // Queue holding pending scheduling request
    std::queue<StandardScheduleFunc> generic_sched_request_queue_;
    // One queue per instance index for requests pinned to that instance.
    std::vector<std::queue<StandardScheduleFunc>>
        specific_sched_request_queues_;
    std::recursive_mutex sched_request_queue_mtx_;
    // The set of instances that are available at the moment
    PriorityQueue avbl_instances_;
    std::recursive_mutex avbl_instances_mtx_;
  };

  // Manages and keep track of resource allocation to the model instances.
  class ResourceManager {
   public:
    static Status Create(
        const ResourceMap& resource_map,
        std::unique_ptr<ResourceManager>* resource_manager);
    void AddModelInstance(const ModelInstanceContext* instance);
    Status RemoveModelInstance(const ModelInstanceContext* instance);
    Status UpdateResourceLimits();
    bool AllocateResources(const ModelInstanceContext* instance);
    Status ReleaseResources(const ModelInstanceContext* instance);

   private:
    ResourceManager(const ResourceMap& resource_map);
    Status ValidateMaxResources();
    Status ParseAndValidateExplicitResources();

    // Limits supplied at construction; validated/applied by
    // ParseAndValidateExplicitResources().
    ResourceMap explicit_max_resources_;
    // Per-instance resource demands.
    std::map<const ModelInstanceContext*, ResourceMap> model_resources_;
    std::mutex model_resources_mtx_;
    // Effective limits derived by UpdateResourceLimits().
    ResourceMap max_resources_;
    std::mutex max_resources_mtx_;
    // Currently reserved counts.
    ResourceMap allocated_resources_;
    std::mutex allocated_resources_mtx_;
  };

  RateLimiter(
      const bool ignore_resources_and_priority,
      const ResourceMap& resource_map);
  void InitializePayloadQueues(const TritonModelInstance* instance);
  Status DeferPayloadSchedule(
      const StandardScheduleFunc& OnSchedule, const TritonModel* model,
      TritonModelInstance* instance = nullptr);
  void OnStage(ModelInstanceContext* instance_ptr);
  void OnRelease(ModelInstanceContext* instance_ptr);
  void AttemptAllocation();
  void SchedulePayload(
      TritonModelInstance* tmi, PayloadQueue* payload_queue,
      const std::shared_ptr<Payload>& payload);

  bool ignore_resources_and_priority_;

  // Instance context for the models
  std::map<
      const TritonModel*, std::vector<std::shared_ptr<ModelInstanceContext>>>
      model_instance_ctxs_;
  std::mutex model_instance_ctx_mtx_;

  // Running context of the models
  std::map<const TritonModel*, ModelContext> model_contexts_;
  std::mutex model_ctx_mtx_;

  // Holds the model instances that have been staged
  PriorityQueue staged_instances_;
  std::recursive_mutex staged_instances_mtx_;

  // Manager to keep track of the resource allocations
  std::unique_ptr<ResourceManager> resource_manager_;

  // Mutex to serialize Payload [de]allocation
  std::mutex payload_mu_;
  // Mutex to serialize Payload Queues deallocation
  std::mutex payload_queues_mu_;

  // Keep some number of Payload objects for reuse to avoid the overhead
  // of creating a Payload for every new request.
  const size_t max_payload_bucket_count_;
  std::vector<std::shared_ptr<Payload>> payload_bucket_;
  std::deque<std::shared_ptr<Payload>> payloads_in_use_;

  // Per-model queue of payloads awaiting dispatch, plus optional queues
  // pinned to specific instances.
  struct PayloadQueue {
    explicit PayloadQueue(size_t max_batch_size, uint64_t max_queue_delay_ns)
    {
      queue_.reset(new InstanceQueue(max_batch_size, max_queue_delay_ns));
    }
    std::unique_ptr<InstanceQueue> queue_;
    std::map<const TritonModelInstance*, std::unique_ptr<InstanceQueue>>
        specific_queues_;
    std::mutex mu_;
    std::condition_variable cv_;
  };
  std::map<const TritonModel*, std::unique_ptr<PayloadQueue>> payload_queues_;
};
}}
// namespace triton::core
3rdparty/core-r22.12/src/repo_agent.cc
0 → 100644
View file @
0a21fff9
// Copyright 2021-2022, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "repo_agent.h"
#include <string>
#include "filesystem.h"
#include "shared_library.h"
#include "triton/common/logging.h"
#include "tritonserver_apis.h"
// For unknown reason, windows will not export the TRITONREPOAGENT_*
// functions declared with dllexport in tritonrepoagent.h. To get
// those functions exported it is (also?) necessary to mark the
// definitions in this file with dllexport as well.
#if defined(_MSC_VER)
#define TRITONAPI_DECLSPEC __declspec(dllexport)
#elif defined(__GNUC__)
#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default")))
#else
#define TRITONAPI_DECLSPEC
#endif
namespace
triton
{
namespace
core
{
// Builds the platform-specific shared-library filename for a repo agent,
// e.g. "checksum" -> "libtritonrepoagent_checksum.so" (Linux) or
// "tritonrepoagent_checksum.dll" (Windows).
std::string
TritonRepoAgentLibraryName(const std::string& agent_name)
{
#ifdef _WIN32
  static constexpr const char* kPrefix = "tritonrepoagent_";
  static constexpr const char* kSuffix = ".dll";
#else
  static constexpr const char* kPrefix = "libtritonrepoagent_";
  static constexpr const char* kSuffix = ".so";
#endif
  return kPrefix + agent_name + kSuffix;
}
// Maps a repo-agent lifecycle action enum value to its literal name, for
// logging and error messages.
std::string
TRITONREPOAGENT_ActionTypeString(const TRITONREPOAGENT_ActionType type)
{
  if (type == TRITONREPOAGENT_ACTION_LOAD) {
    return "TRITONREPOAGENT_ACTION_LOAD";
  }
  if (type == TRITONREPOAGENT_ACTION_LOAD_COMPLETE) {
    return "TRITONREPOAGENT_ACTION_LOAD_COMPLETE";
  }
  if (type == TRITONREPOAGENT_ACTION_LOAD_FAIL) {
    return "TRITONREPOAGENT_ACTION_LOAD_FAIL";
  }
  if (type == TRITONREPOAGENT_ACTION_UNLOAD) {
    return "TRITONREPOAGENT_ACTION_UNLOAD";
  }
  if (type == TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE) {
    return "TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE";
  }
  return "Unknown TRITONREPOAGENT_ActionType";
}
// Maps a repo-agent artifact type enum value to its literal name, for
// logging and error messages.
std::string
TRITONREPOAGENT_ArtifactTypeString(const TRITONREPOAGENT_ArtifactType type)
{
  if (type == TRITONREPOAGENT_ARTIFACT_FILESYSTEM) {
    return "TRITONREPOAGENT_ARTIFACT_FILESYSTEM";
  }
  if (type == TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM) {
    return "TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM";
  }
  return "Unknown TRITONREPOAGENT_ArtifactType";
}
//
// TritonRepoAgent
//
Status
TritonRepoAgent
::
Create
(
const
std
::
string
&
name
,
const
std
::
string
&
libpath
,
std
::
shared_ptr
<
TritonRepoAgent
>*
agent
)
{
std
::
shared_ptr
<
TritonRepoAgent
>
lagent
(
new
TritonRepoAgent
(
name
));
{
std
::
unique_ptr
<
SharedLibrary
>
slib
;
RETURN_IF_ERROR
(
SharedLibrary
::
Acquire
(
&
slib
));
RETURN_IF_ERROR
(
slib
->
OpenLibraryHandle
(
libpath
,
&
lagent
->
dlhandle_
));
RETURN_IF_ERROR
(
slib
->
GetEntrypoint
(
lagent
->
dlhandle_
,
"TRITONREPOAGENT_Initialize"
,
true
/* optional */
,
reinterpret_cast
<
void
**>
(
&
lagent
->
init_fn_
)));
RETURN_IF_ERROR
(
slib
->
GetEntrypoint
(
lagent
->
dlhandle_
,
"TRITONREPOAGENT_Finalize"
,
true
/* optional */
,
reinterpret_cast
<
void
**>
(
&
lagent
->
fini_fn_
)));
RETURN_IF_ERROR
(
slib
->
GetEntrypoint
(
lagent
->
dlhandle_
,
"TRITONREPOAGENT_ModelInitialize"
,
true
/* optional */
,
reinterpret_cast
<
void
**>
(
&
lagent
->
model_init_fn_
)));
RETURN_IF_ERROR
(
slib
->
GetEntrypoint
(
lagent
->
dlhandle_
,
"TRITONREPOAGENT_ModelFinalize"
,
true
/* optional */
,
reinterpret_cast
<
void
**>
(
&
lagent
->
model_fini_fn_
)));
RETURN_IF_ERROR
(
slib
->
GetEntrypoint
(
lagent
->
dlhandle_
,
"TRITONREPOAGENT_ModelAction"
,
false
/* optional */
,
reinterpret_cast
<
void
**>
(
&
lagent
->
model_action_fn_
)));
}
// Initialize if needed
if
(
lagent
->
init_fn_
!=
nullptr
)
{
RETURN_IF_TRITONSERVER_ERROR
(
lagent
->
init_fn_
(
reinterpret_cast
<
TRITONREPOAGENT_Agent
*>
(
lagent
.
get
())));
}
*
agent
=
std
::
move
(
lagent
);
return
Status
::
Success
;
}
// Runs the agent's optional TRITONREPOAGENT_Finalize (logging, not
// propagating, any error it reports) and then closes the library handle.
TritonRepoAgent::~TritonRepoAgent()
{
  // Finalize if needed
  if (fini_fn_ != nullptr) {
    auto err = fini_fn_(reinterpret_cast<TRITONREPOAGENT_Agent*>(this));
    if (err != nullptr) {
      // Destructors must not throw/propagate: log and free the error.
      LOG_ERROR << "~TritonRepoAgent: "
                << Status(
                       TritonCodeToStatusCode(TRITONSERVER_ErrorCode(err)),
                       TRITONSERVER_ErrorMessage(err))
                       .AsString();
      TRITONSERVER_ErrorDelete(err);
    };
  }
  {
    std::unique_ptr<SharedLibrary> slib;
    LOG_STATUS_ERROR(SharedLibrary::Acquire(&slib), "~TritonRepoAgent");
    LOG_STATUS_ERROR(slib->CloseLibraryHandle(dlhandle_), "~TritonRepoAgent");
  }
}
//
// TritonRepoAgentModel
//
Status
TritonRepoAgentModel
::
Create
(
const
TRITONREPOAGENT_ArtifactType
type
,
const
std
::
string
&
location
,
const
inference
::
ModelConfig
&
config
,
const
std
::
shared_ptr
<
TritonRepoAgent
>&
agent
,
const
TritonRepoAgent
::
Parameters
&
agent_parameters
,
std
::
unique_ptr
<
TritonRepoAgentModel
>*
agent_model
)
{
std
::
unique_ptr
<
TritonRepoAgentModel
>
lagent_model
(
new
TritonRepoAgentModel
(
type
,
location
,
config
,
agent
,
agent_parameters
));
if
(
agent
->
AgentModelInitFn
()
!=
nullptr
)
{
RETURN_IF_TRITONSERVER_ERROR
(
agent
->
AgentModelInitFn
()(
reinterpret_cast
<
TRITONREPOAGENT_Agent
*>
(
agent
.
get
()),
reinterpret_cast
<
TRITONREPOAGENT_AgentModel
*>
(
lagent_model
.
get
())));
}
*
agent_model
=
std
::
move
(
lagent_model
);
return
Status
::
Success
;
}
// Drives the agent lifecycle to a terminal state before destruction: for
// whatever action was last reported, issues the remaining actions so the
// agent always observes a complete LOAD/UNLOAD sequence. Then runs the
// optional model-finalize hook and deletes any acquired mutable location.
TritonRepoAgentModel::~TritonRepoAgentModel()
{
  // Need to ensure the proper lifecycle is informed
  if (action_type_set_) {
    switch (current_action_type_) {
      case TRITONREPOAGENT_ACTION_LOAD:
        // Load never completed: report it as failed.
        LOG_TRITONSERVER_ERROR(
            agent_->AgentModelActionFn()(
                reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
                reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
                TRITONREPOAGENT_ACTION_LOAD_FAIL),
            "Inform TRITONREPOAGENT_ACTION_LOAD_FAIL");
        break;
      case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
        // A loaded model must see UNLOAD followed by UNLOAD_COMPLETE.
        // (Written as two sequential calls rather than a case fallthrough,
        // which predates the C++17 [[fallthrough]] attribute.)
        LOG_TRITONSERVER_ERROR(
            agent_->AgentModelActionFn()(
                reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
                reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
                TRITONREPOAGENT_ACTION_UNLOAD),
            "Inform TRITONREPOAGENT_ACTION_UNLOAD");
        LOG_TRITONSERVER_ERROR(
            agent_->AgentModelActionFn()(
                reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
                reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
                TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE),
            "Inform TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE");
        break;
      case TRITONREPOAGENT_ACTION_UNLOAD:
        // Unload was in flight: report its completion.
        LOG_TRITONSERVER_ERROR(
            agent_->AgentModelActionFn()(
                reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
                reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
                TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE),
            "Inform TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE");
        break;
      case TRITONREPOAGENT_ACTION_LOAD_FAIL:
      case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE:
        // Already terminal; nothing to report.
        break;
    }
  }
  if (agent_->AgentModelFiniFn() != nullptr) {
    LOG_TRITONSERVER_ERROR(
        agent_->AgentModelFiniFn()(
            reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
            reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this)),
        "~TritonRepoAgentModel");
  }
  if (!acquired_location_.empty()) {
    DeleteMutableLocation();
  }
}
// Drive the model lifecycle state machine. Validates that 'action_type' is a
// legal transition from the current state:
//   (initial) -> LOAD -> LOAD_COMPLETE|LOAD_FAIL,
//   LOAD_COMPLETE -> UNLOAD -> UNLOAD_COMPLETE.
// On a valid transition the state is updated and the agent's model action
// hook is invoked. Returns INTERNAL for any out-of-order transition.
Status
TritonRepoAgentModel::InvokeAgent(const TRITONREPOAGENT_ActionType action_type)
{
  // The very first action for a model must be LOAD.
  if ((!action_type_set_) && (action_type != TRITONREPOAGENT_ACTION_LOAD)) {
    return Status(
        Status::Code::INTERNAL,
        "Unexpected lifecycle start state " +
            TRITONREPOAGENT_ActionTypeString(action_type));
  }
  switch (action_type) {
    case TRITONREPOAGENT_ACTION_LOAD:
      // LOAD is only valid as the first action (checked above).
      if (action_type_set_) {
        return Status(
            Status::Code::INTERNAL,
            "Unexpected lifecycle state transition from " +
                TRITONREPOAGENT_ActionTypeString(current_action_type_) +
                " to " + TRITONREPOAGENT_ActionTypeString(action_type));
      }
      break;
    case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
    case TRITONREPOAGENT_ACTION_LOAD_FAIL:
      // Both load outcomes require an in-progress LOAD.
      if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD) {
        return Status(
            Status::Code::INTERNAL,
            "Unexpected lifecycle state transition from " +
                TRITONREPOAGENT_ActionTypeString(current_action_type_) +
                " to " + TRITONREPOAGENT_ActionTypeString(action_type));
      }
      break;
    case TRITONREPOAGENT_ACTION_UNLOAD:
      // Only a successfully loaded model can be unloaded.
      if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD_COMPLETE) {
        return Status(
            Status::Code::INTERNAL,
            "Unexpected lifecycle state transition from " +
                TRITONREPOAGENT_ActionTypeString(current_action_type_) +
                " to " + TRITONREPOAGENT_ActionTypeString(action_type));
      }
      break;
    case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE:
      // UNLOAD_COMPLETE must follow UNLOAD.
      if (current_action_type_ != TRITONREPOAGENT_ACTION_UNLOAD) {
        return Status(
            Status::Code::INTERNAL,
            "Unexpected lifecycle state transition from " +
                TRITONREPOAGENT_ActionTypeString(current_action_type_) +
                " to " + TRITONREPOAGENT_ActionTypeString(action_type));
      }
      break;
  }
  // Commit the state BEFORE invoking the agent so the destructor can emit
  // the right recovery notifications even if the hook fails.
  current_action_type_ = action_type;
  action_type_set_ = true;
  RETURN_IF_TRITONSERVER_ERROR(agent_->AgentModelActionFn()(
      reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
      reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this), action_type));
  return Status::Success;
}
// Record a new artifact type/location for the model. Only legal while the
// LOAD action is in flight; otherwise returns INVALID_ARG.
Status
TritonRepoAgentModel::SetLocation(
    const TRITONREPOAGENT_ArtifactType type, const std::string& location)
{
  if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD) {
    const std::string current_state =
        action_type_set_
            ? TRITONREPOAGENT_ActionTypeString(current_action_type_)
            : std::string("not set");
    return Status(
        Status::Code::INVALID_ARG,
        "location can only be updated during TRITONREPOAGENT_ACTION_LOAD, "
        "current action type is " +
            current_state);
  }
  type_ = type;
  location_ = location;
  return Status::Success;
}
// Report the model's current artifact type and location. The returned
// C string aliases 'location_' and stays valid only while it is unchanged.
Status
TritonRepoAgentModel::Location(
    TRITONREPOAGENT_ArtifactType* type, const char** location)
{
  if (location_.empty()) {
    return Status(
        Status::Code::INTERNAL, "Model repository location is not set");
  }
  *type = type_;
  *location = location_.c_str();
  return Status::Success;
}
// Hand the agent a writable scratch directory on the local filesystem.
// Created lazily on first call and reused afterwards; released via
// DeleteMutableLocation() or in the destructor.
Status
TritonRepoAgentModel::AcquireMutableLocation(
    const TRITONREPOAGENT_ArtifactType type, const char** location)
{
  // Only filesystem artifacts are currently supported.
  if (type != TRITONREPOAGENT_ARTIFACT_FILESYSTEM) {
    return Status(
        Status::Code::INVALID_ARG,
        "Unexpected artifact type, expects "
        "'TRITONREPOAGENT_ARTIFACT_FILESYSTEM'");
  }
  if (acquired_location_.empty()) {
    std::string scratch_dir;
    RETURN_IF_ERROR(
        MakeTemporaryDirectory(FileSystemType::LOCAL, &scratch_dir));
    acquired_location_.swap(scratch_dir);
    acquired_type_ = type;
  }
  *location = acquired_location_.c_str();
  return Status::Success;
}
// Delete the scratch directory acquired via AcquireMutableLocation().
// Deletion failures are logged (best effort) and the bookkeeping is reset
// regardless; UNAVAILABLE if nothing was acquired.
Status
TritonRepoAgentModel::DeleteMutableLocation()
{
  if (acquired_location_.empty()) {
    return Status(
        Status::Code::UNAVAILABLE, "No mutable location to be deleted");
  }
  const auto delete_status = DeletePath(acquired_location_);
  if (!delete_status.IsOk()) {
    LOG_ERROR << "Failed to delete previously acquired location '"
              << acquired_location_ << "': " << delete_status.AsString();
  }
  acquired_location_.clear();
  return Status::Success;
}
//
// TritonRepoAgentManager
//
// Meyers singleton: constructed on first use; initialization is
// thread-safe under C++11 semantics.
TritonRepoAgentManager&
TritonRepoAgentManager::Singleton()
{
  static TritonRepoAgentManager instance;
  return instance;
}
// Replace the directory searched for agent shared libraries.
// Guarded by the manager mutex.
Status
TritonRepoAgentManager::SetGlobalSearchPath(const std::string& path)
{
  auto& manager = Singleton();
  std::lock_guard<std::mutex> lock(manager.mu_);
  manager.global_search_path_ = path;
  return Status::Success;
}
// Locate and load (or reuse) the shared library implementing 'agent_name'.
// Agents are cached by library path as weak_ptrs so the library stays loaded
// only while at least one model holds a shared_ptr to it.
Status
TritonRepoAgentManager::CreateAgent(
    const std::string& agent_name, std::shared_ptr<TritonRepoAgent>* agent)
{
  auto& singleton_manager = Singleton();
  std::lock_guard<std::mutex> lock(singleton_manager.mu_);
  // Get the path to the agent shared library. Search path is global
  // agent directory. FIXME expose global path as Triton option
  const std::vector<std::string> search_paths = {JoinPath(
      {singleton_manager.global_search_path_, agent_name})};
  std::string agent_libname = TritonRepoAgentLibraryName(agent_name);
  std::string libpath;
  // Take the first search path that actually contains the library file.
  for (const auto& path : search_paths) {
    const auto full_path = JoinPath({path, agent_libname});
    bool exists = false;
    RETURN_IF_ERROR(FileExists(full_path, &exists));
    if (exists) {
      libpath = full_path;
      break;
    }
  }
  if (libpath.empty()) {
    return Status(
        Status::Code::INVALID_ARG,
        "unable to find '" + agent_libname + "' for repo agent '" +
            agent_name +
            "', searched: " + singleton_manager.global_search_path_);
  }
  const auto& itr = singleton_manager.agent_map_.find(libpath);
  if (itr != singleton_manager.agent_map_.end()) {
    // Found in map. If the weak_ptr is still valid that means that
    // there are other models using the agent and we just reuse that
    // same agent. If the weak_ptr is not valid then agent has been
    // unloaded so we need to remove the weak_ptr from the map and
    // create the agent again.
    *agent = itr->second.lock();
    if (*agent != nullptr) {
      return Status::Success;
    }
    singleton_manager.agent_map_.erase(itr);
  }
  // Not cached (or expired): load the library and register a weak reference.
  RETURN_IF_ERROR(TritonRepoAgent::Create(agent_name, libpath, agent));
  singleton_manager.agent_map_.insert({libpath, *agent});
  return Status::Success;
}
// Snapshot the currently-live agents as a map of agent name -> library path.
// Expired (unloaded) agents are skipped.
Status
TritonRepoAgentManager::AgentState(
    std::unique_ptr<std::unordered_map<std::string, std::string>>* agent_state)
{
  auto& manager = Singleton();
  std::lock_guard<std::mutex> lock(manager.mu_);
  std::unique_ptr<std::unordered_map<std::string, std::string>> state_map(
      new std::unordered_map<std::string, std::string>);
  for (const auto& agent_pair : manager.agent_map_) {
    const auto& libpath = agent_pair.first;
    // lock() yields nullptr when the agent has already been unloaded.
    const auto agent = agent_pair.second.lock();
    if (agent != nullptr) {
      state_map->insert({agent->Name(), libpath});
    }
  }
  *agent_state = std::move(state_map);
  return Status::Success;
}
extern
"C"
{
// C API: report the repo-agent API version this core implements.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ApiVersion(uint32_t* major, uint32_t* minor)
{
  *major = TRITONREPOAGENT_API_VERSION_MAJOR;
  *minor = TRITONREPOAGENT_API_VERSION_MINOR;
  return nullptr;  // success
}
// C API shim: forward to TritonRepoAgentModel::Location. 'agent' is part of
// the C API signature but not needed here.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryLocation(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    TRITONREPOAGENT_ArtifactType* artifact_type, const char** location)
{
  auto* agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
  RETURN_TRITONSERVER_ERROR_IF_ERROR(
      agent_model->Location(artifact_type, location));
  return nullptr;  // success
}
// C API shim: forward to TritonRepoAgentModel::AcquireMutableLocation.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryLocationAcquire(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const TRITONREPOAGENT_ArtifactType artifact_type, const char** location)
{
  auto* agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
  RETURN_TRITONSERVER_ERROR_IF_ERROR(
      agent_model->AcquireMutableLocation(artifact_type, location));
  return nullptr;  // success
}
// C API shim: release the model's acquired scratch location. The 'location'
// argument is unused because a model holds at most one mutable location.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryLocationRelease(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const char* location)
{
  auto* agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
  RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->DeleteMutableLocation());
  return nullptr;  // success
}
// C API shim: forward to TritonRepoAgentModel::SetLocation (only valid
// during the LOAD action).
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryUpdate(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const TRITONREPOAGENT_ArtifactType artifact_type, const char* location)
{
  auto* agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
  RETURN_TRITONSERVER_ERROR_IF_ERROR(
      agent_model->SetLocation(artifact_type, location));
  return nullptr;  // success
}
// C API shim: number of agent parameters configured for this model.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelParameterCount(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    uint32_t* count)
{
  auto* agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
  *count = agent_model->AgentParameters().size();
  return nullptr;  // success
}
// C API shim: fetch the name/value of the index'th agent parameter. The
// returned C strings alias the model's parameter storage.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelParameter(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const uint32_t index, const char** parameter_name,
    const char** parameter_value)
{
  auto* agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
  const auto& params = agent_model->AgentParameters();
  if (index >= params.size()) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG,
        "index out of range for model parameters");
  }
  const auto& param = params[index];
  *parameter_name = param.first.c_str();
  *parameter_value = param.second.c_str();
  return nullptr;  // success
}
// C API shim: serialize the model config to JSON and return it wrapped in a
// TRITONSERVER_Message that the agent must release.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelConfig(
    TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
    const uint32_t config_version, TRITONSERVER_Message** model_config)
{
  auto* agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
  std::string model_config_json;
  RETURN_TRITONSERVER_ERROR_IF_ERROR(ModelConfigToJson(
      agent_model->Config(), config_version, &model_config_json));
  return TRITONSERVER_MessageNewFromSerializedJson(
      model_config, model_config_json.c_str(), model_config_json.length());
}
// C API shim: read the opaque per-model state pointer owned by the agent.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelState(TRITONREPOAGENT_AgentModel* model, void** state)
{
  auto* agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
  *state = agent_model->State();
  return nullptr;  // success
}
// C API shim: store the opaque per-model state pointer for the agent.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelSetState(TRITONREPOAGENT_AgentModel* model, void* state)
{
  auto* agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
  agent_model->SetState(state);
  return nullptr;  // success
}
// C API shim: read the opaque agent-global state pointer.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_State(TRITONREPOAGENT_Agent* agent, void** state)
{
  auto* repo_agent = reinterpret_cast<TritonRepoAgent*>(agent);
  *state = repo_agent->State();
  return nullptr;  // success
}
// C API shim: store the opaque agent-global state pointer.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_SetState(TRITONREPOAGENT_Agent* agent, void* state)
{
  auto* repo_agent = reinterpret_cast<TritonRepoAgent*>(agent);
  repo_agent->SetState(state);
  return nullptr;  // success
}
}
// extern C
}}
// namespace triton::core
3rdparty/core-r22.12/src/repo_agent.h
0 → 100644
View file @
0a21fff9
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "tritonserver_apis.h"
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "constants.h"
#include "model_config_utils.h"
namespace
triton
{
namespace
core
{
// Returns the platform-specific shared-library filename for the named agent
// (implemented in repo_agent.cc).
std::string TritonRepoAgentLibraryName(const std::string& agent_name);
// Human-readable name for a lifecycle action type; used in error messages.
std::string TRITONREPOAGENT_ActionTypeString(
    const TRITONREPOAGENT_ActionType type);
// Human-readable name for an artifact type; used in error messages.
std::string TRITONREPOAGENT_ArtifactTypeString(
    const TRITONREPOAGENT_ArtifactType type);
// Wraps one repo-agent shared library: owns the dlopen handle, the resolved
// entry points, and the agent-global opaque state pointer. Instances are
// shared between models via shared_ptr (see TritonRepoAgentManager).
class TritonRepoAgent {
 public:
  // Ordered (name, value) pairs passed to the agent for each model.
  using Parameters = std::vector<std::pair<std::string, std::string>>;
  // Signatures of the optional entry points resolved from the library.
  using TritonRepoAgentInitFn_t =
      TRITONSERVER_Error* (*)(TRITONREPOAGENT_Agent* agent);
  using TritonRepoAgentFiniFn_t =
      TRITONSERVER_Error* (*)(TRITONREPOAGENT_Agent* agent);
  using TritonRepoAgentModelInitFn_t = TRITONSERVER_Error* (*)(
      TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model);
  using TritonRepoAgentModelFiniFn_t = TRITONSERVER_Error* (*)(
      TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model);
  using TritonRepoAgentModelActionFn_t = TRITONSERVER_Error* (*)(
      TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
      const TRITONREPOAGENT_ActionType action_type);

  static Status Create(
      const std::string& name, const std::string& libpath,
      std::shared_ptr<TritonRepoAgent>* agent);
  ~TritonRepoAgent();

  const std::string& Name() { return name_; }
  // Opaque agent-global state owned by the agent implementation.
  void* State() { return state_; }
  void SetState(void* state) { state_ = state; }

  TritonRepoAgentModelActionFn_t AgentModelActionFn() const
  {
    return model_action_fn_;
  }
  TritonRepoAgentModelInitFn_t AgentModelInitFn() const
  {
    return model_init_fn_;
  }
  TritonRepoAgentModelFiniFn_t AgentModelFiniFn() const
  {
    return model_fini_fn_;
  }

 protected:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgent);
  TritonRepoAgent(const std::string& name)
      : name_(name), state_(nullptr), dlhandle_(nullptr), init_fn_(nullptr),
        fini_fn_(nullptr), model_init_fn_(nullptr), model_fini_fn_(nullptr),
        model_action_fn_(nullptr)
  {
  }
  const std::string name_;
  void* state_;
  // dlopen / dlsym handles
  void* dlhandle_;
  TritonRepoAgentInitFn_t init_fn_;
  TritonRepoAgentFiniFn_t fini_fn_;
  TritonRepoAgentModelInitFn_t model_init_fn_;
  TritonRepoAgentModelFiniFn_t model_fini_fn_;
  TritonRepoAgentModelActionFn_t model_action_fn_;
};
// A model's view of a repo-agent instance: tracks the model's artifact
// type/location, the lifecycle action state machine, and any scratch
// directory acquired on the model's behalf. See repo_agent.cc for the
// lifecycle transition rules enforced by InvokeAgent().
class TritonRepoAgentModel {
 public:
  static Status Create(
      const TRITONREPOAGENT_ArtifactType type, const std::string& location,
      const inference::ModelConfig& config,
      const std::shared_ptr<TritonRepoAgent>& agent,
      const TritonRepoAgent::Parameters& agent_parameters,
      std::unique_ptr<TritonRepoAgentModel>* agent_model);
  ~TritonRepoAgentModel();
  // Opaque per-model state owned by the agent implementation.
  void* State() { return state_; }
  void SetState(void* state) { state_ = state; }
  // Validate and apply a lifecycle transition, then invoke the agent's
  // model-action hook.
  Status InvokeAgent(const TRITONREPOAGENT_ActionType action_type);
  const TritonRepoAgent::Parameters& AgentParameters()
  {
    return agent_parameters_;
  }
  // Update the artifact location; only legal during LOAD.
  Status SetLocation(
      const TRITONREPOAGENT_ArtifactType type, const std::string& location);
  // Report the current artifact type/location.
  Status Location(TRITONREPOAGENT_ArtifactType* type, const char** location);
  // Lazily create (and report) a local scratch directory for the agent.
  Status AcquireMutableLocation(
      const TRITONREPOAGENT_ArtifactType type, const char** location);
  Status DeleteMutableLocation();
  // Return the model config by const reference. Previously declared as
  // 'const inference::ModelConfig Config()', which copied the entire config
  // message on every call; returning a const reference is backward
  // compatible for callers and avoids the copy.
  const inference::ModelConfig& Config() const { return config_; }

 private:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModel);
  TritonRepoAgentModel(
      const TRITONREPOAGENT_ArtifactType type, const std::string& location,
      const inference::ModelConfig& config,
      const std::shared_ptr<TritonRepoAgent>& agent,
      const TritonRepoAgent::Parameters& agent_parameters)
      : state_(nullptr), config_(config), agent_(agent),
        agent_parameters_(agent_parameters), type_(type), location_(location),
        action_type_set_(false),
        current_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE)
  {
  }
  void* state_;
  const inference::ModelConfig config_;
  const std::shared_ptr<TritonRepoAgent> agent_;
  const TritonRepoAgent::Parameters agent_parameters_;
  TRITONREPOAGENT_ArtifactType type_;
  std::string location_;
  // Only meaningful after AcquireMutableLocation(); default-initialized so a
  // premature read is defined behavior (was previously uninitialized).
  TRITONREPOAGENT_ArtifactType acquired_type_{
      TRITONREPOAGENT_ARTIFACT_FILESYSTEM};
  std::string acquired_location_;
  bool action_type_set_;
  TRITONREPOAGENT_ActionType current_action_type_;
};
// Process-wide registry of loaded repo agents, keyed by shared-library path.
// Holds weak_ptrs so a library is released as soon as no model references
// its agent. All static entry points lock the singleton's mutex.
class TritonRepoAgentManager {
 public:
  static Status SetGlobalSearchPath(const std::string& path);
  static Status CreateAgent(
      const std::string& agent_name, std::shared_ptr<TritonRepoAgent>* agent);
  static Status AgentState(
      std::unique_ptr<std::unordered_map<std::string, std::string>>*
          agent_state);

 private:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentManager);
  TritonRepoAgentManager()
      : global_search_path_("/opt/tritonserver/repoagents")
  {
  }
  static TritonRepoAgentManager& Singleton();
  std::mutex mu_;
  std::string global_search_path_;
  std::unordered_map<std::string, std::weak_ptr<TritonRepoAgent>> agent_map_;
};
}}
// namespace triton::core
3rdparty/core-r22.12/src/response_allocator.h
0 → 100644
View file @
0a21fff9
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "tritonserver_apis.h"
namespace
triton
{
namespace
core
{
//
// Implementation for TRITONSERVER_ResponseAllocator.
//
// Implementation for TRITONSERVER_ResponseAllocator: an immutable-after-setup
// bundle of client-supplied callbacks used to allocate/release response
// output buffers. Query and buffer-attributes callbacks are optional and
// start out null.
class ResponseAllocator {
 public:
  explicit ResponseAllocator(
      TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn,
      TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn,
      TRITONSERVER_ResponseAllocatorStartFn_t start_fn)
      : alloc_fn_(alloc_fn), buffer_attributes_fn_(nullptr),
        query_fn_(nullptr), release_fn_(release_fn), start_fn_(start_fn)
  {
  }

  // Optional callbacks, installed after construction.
  void SetQueryFunction(TRITONSERVER_ResponseAllocatorQueryFn_t query_fn)
  {
    query_fn_ = query_fn;
  }
  void SetBufferAttributesFunction(
      TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn)
  {
    buffer_attributes_fn_ = buffer_attributes_fn;
  }

  TRITONSERVER_ResponseAllocatorAllocFn_t AllocFn() const
  {
    return alloc_fn_;
  }
  TRITONSERVER_ResponseAllocatorBufferAttributesFn_t BufferAttributesFn() const
  {
    return buffer_attributes_fn_;
  }
  TRITONSERVER_ResponseAllocatorQueryFn_t QueryFn() const
  {
    return query_fn_;
  }
  TRITONSERVER_ResponseAllocatorReleaseFn_t ReleaseFn() const
  {
    return release_fn_;
  }
  TRITONSERVER_ResponseAllocatorStartFn_t StartFn() const
  {
    return start_fn_;
  }

 private:
  TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn_;
  TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn_;
  TRITONSERVER_ResponseAllocatorQueryFn_t query_fn_;
  TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn_;
  TRITONSERVER_ResponseAllocatorStartFn_t start_fn_;
};
}}
// namespace triton::core
3rdparty/core-r22.12/src/response_cache.cc
0 → 100644
View file @
0a21fff9
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "response_cache.h"

#include <utility>

#include "infer_stats.h"
#include "triton/common/logging.h"
namespace
{
// Which pair of request timestamps a ScopedTimer records.
enum class ScopedTimerType { INSERTION, LOOKUP };

// RAII timer: captures the cache lookup/insertion start timestamp on the
// request at construction, the matching end timestamp at destruction, and
// accumulates the elapsed nanoseconds into the caller-owned 'duration'.
class ScopedTimer {
 public:
  explicit ScopedTimer(
      triton::core::InferenceRequest& request, uint64_t& duration,
      ScopedTimerType type)
      : request_(request), duration_(duration), type_(type)
  {
    switch (type_) {
      case ScopedTimerType::LOOKUP:
        request_.CaptureCacheLookupStartNs();
        break;
      case ScopedTimerType::INSERTION:
        request_.CaptureCacheInsertionStartNs();
        break;
    }
  }

  ~ScopedTimer()
  {
    switch (type_) {
      case ScopedTimerType::LOOKUP:
        request_.CaptureCacheLookupEndNs();
        duration_ +=
            request_.CacheLookupEndNs() - request_.CacheLookupStartNs();
        break;
      case ScopedTimerType::INSERTION:
        request_.CaptureCacheInsertionEndNs();
        duration_ +=
            request_.CacheInsertionEndNs() - request_.CacheInsertionStartNs();
        break;
    }
  }

 private:
  // Request whose timestamps are captured; must outlive the timer.
  triton::core::InferenceRequest& request_;
  // Accumulator owned by the cache (total lookup/insertion latency).
  uint64_t& duration_;
  ScopedTimerType type_;
};
// Render a pointer value as text using the stream formatter
// (implementation-defined format, typically hexadecimal).
std::string
PointerToString(void* ptr)
{
  std::ostringstream stream;
  stream << ptr;
  return stream.str();
}
}
// namespace
namespace
triton
{
namespace
core
{
// Factory: constructs the cache, translating any constructor exception
// (e.g. buffer allocation failure) into an INTERNAL status.
Status
RequestResponseCache::Create(
    uint64_t cache_size, std::unique_ptr<RequestResponseCache>* cache)
{
  try {
    cache->reset(new RequestResponseCache(cache_size));
  }
  catch (const std::exception& ex) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to initialize Response Cache: " + std::string(ex.what()));
  }
  return Status::Success;
}
// Constructor: allocates a single raw buffer of 'size' bytes and hands it to
// a Boost.Interprocess managed_external_buffer, which sub-allocates cache
// entries from it. Throws std::runtime_error on allocation failure
// (Create() converts this to a Status).
RequestResponseCache::RequestResponseCache(const uint64_t size)
{
  // Allocate buffer
  buffer_ = malloc(size);
  // Exit early if buffer allocation failed
  if (buffer_ == nullptr) {
    throw std::runtime_error("failed to allocate buffer");
  }
  // Create cache as managed buffer
  managed_buffer_ = boost::interprocess::managed_external_buffer(
      boost::interprocess::create_only_t{}, buffer_, size);
  LOG_INFO << "Response Cache is created at '" << PointerToString(buffer_)
           << "' with size " << size;
}
// Destructor: returns every output buffer to the managed allocator, sanity
// checks that nothing leaked, then frees the backing buffer itself.
RequestResponseCache::~RequestResponseCache()
{
  // Deallocate each chunk from managed buffer
  for (auto& iter : cache_) {
    auto& entry = iter.second;
    for (auto& output : entry.outputs_) {
      if (output.buffer_ != nullptr) {
        managed_buffer_.deallocate(output.buffer_);
      }
    }
  }
  // Validate we freed all underlying memory managed by cache
  if (!managed_buffer_.all_memory_deallocated()) {
    // Destructors can't throw exceptions, so only log the leak.
    LOG_ERROR << "failed to free managed cache memory";
  }
  // Free total cache buffer
  if (buffer_ != nullptr) {
    free(buffer_);
  }
}
// Look up the cached response for 'request' and, on a hit, populate
// 'response' from the cached entry and promote the key in the LRU list.
// Returns INTERNAL on a miss (caller treats it as "not cached").
// Thread-safe via cache_mtx_; lookup latency is accumulated via ScopedTimer.
Status
RequestResponseCache::Lookup(
    InferenceResponse* const response, InferenceRequest* const request)
{
  // Lock on cache lookup
  std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
  if (request == nullptr) {
    return Status(
        Status::Code::INTERNAL, "Cache Lookup passed a nullptr request");
  }
  // Capture start latency now and end latency when timer goes out of scope
  ScopedTimer timer(
      *request, total_lookup_latency_ns_, ScopedTimerType::LOOKUP);
  // Hash the request and set cache key if it hasn't already been set
  if (!request->CacheKeyIsSet()) {
    RETURN_IF_ERROR(HashAndSet(request));
  }
  const uint64_t key = request->CacheKey();
  num_lookups_++;
  LOG_VERBOSE(1) << request->LogRequest()
                 << "Looking up key [" + std::to_string(key) + "] in cache.";
  // Search cache for request hash key
  auto iter = cache_.find(key);
  if (iter == cache_.end()) {
    num_misses_++;
    LOG_VERBOSE(1) << request->LogRequest()
                   << "MISS for key [" + std::to_string(key) + "] in cache.";
    return Status(
        Status::Code::INTERNAL,
        request->LogRequest() + "key not found in cache");
  }
  // If find succeeds, it's a cache hit
  num_hits_++;
  LOG_VERBOSE(1) << request->LogRequest()
                 << "HIT for key [" + std::to_string(key) + "] in cache.";
  // Reference the entry in place. The original code copied the whole
  // CacheEntry (including its outputs vector) on every hit; a reference
  // avoids that copy. unordered_map references stay valid across UpdateLRU.
  auto& entry = iter->second;
  // Build InferenceResponse from CacheEntry
  RETURN_IF_ERROR(BuildInferenceResponse(entry, response));
  // Update this key to front of LRU list
  UpdateLRU(iter);
  LOG_VERBOSE(1) << request->LogRequest()
                 << "Using cached response for key [" + std::to_string(key) +
                        "].";
  return Status::Success;
}
// Insert the serialized outputs of 'response' into the cache keyed by the
// request's hash. Returns ALREADY_EXISTS if the key is present and INTERNAL
// on failure. Thread-safe via cache_mtx_; insertion latency is accumulated
// via ScopedTimer.
Status
RequestResponseCache::Insert(
    const InferenceResponse& response, InferenceRequest* const request)
{
  // Lock on cache insertion
  std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
  if (request == nullptr) {
    return Status(
        Status::Code::INTERNAL, "Cache Insert passed a nullptr request");
  }
  // Capture start latency now and end latency when timer goes out of scope
  ScopedTimer timer(
      *request, total_insertion_latency_ns_, ScopedTimerType::INSERTION);
  // Hash the request and set cache key if it hasn't already been set
  if (!request->CacheKeyIsSet()) {
    RETURN_IF_ERROR(HashAndSet(request));
  }
  const uint64_t key = request->CacheKey();
  // Exit early if key already exists in cache
  auto iter = cache_.find(key);
  if (iter != cache_.end()) {
    return Status(
        Status::Code::ALREADY_EXISTS, request->LogRequest() + "key [" +
                                          std::to_string(key) +
                                          "] already exists in cache");
  }
  // Construct cache entry from response
  auto entry = CacheEntry();
  RETURN_IF_ERROR(BuildCacheEntry(response, &entry));
  // Insert entry into cache
  LOG_VERBOSE(1) << request->LogRequest()
                 << "Inserting key [" + std::to_string(key) + "] into cache.";
  // Move the entry into the map: it is not used again below and moving
  // avoids copying its outputs vector (the original copied it).
  auto cache_pair = cache_.insert({key, std::move(entry)});
  // Exit early if cache insertion failed
  if (!cache_pair.second) {
    LOG_ERROR << request->LogRequest() << "Failed to insert key into map.";
    return Status(
        Status::Code::INTERNAL,
        request->LogRequest() + "Cache insertion failed");
  }
  // Update LRU with new cache entry
  auto cache_iter = cache_pair.first;
  UpdateLRU(cache_iter);
  return Status::Success;
}
// LRU
// Evict the least-recently-used entry: deallocate its output buffers from
// the managed buffer, then remove it from both the cache map and LRU list.
// Returns INTERNAL if the cache is empty or in an inconsistent state.
Status
RequestResponseCache::Evict()
{
  // Lock on cache eviction
  std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
  // Nothing to evict if cache is empty
  if (NumEntries() == 0) {
    return Status(
        Status::Code::INTERNAL, "Cache is empty, nothing to evict.");
  }
  // Least recently used key in back of LRU list
  uint64_t lru_key = lru_.back();
  LOG_VERBOSE(1) << "Evicting key [" + std::to_string(lru_key) +
                        "] from cache.";
  // Find cache entry for least recently used key
  auto iter = cache_.find(lru_key);
  // Error check if key isn't in cache, but this shouldn't happen in evict
  // and probably indicates a bug
  if (iter == cache_.end()) {
    return Status(
        Status::Code::INTERNAL,
        "key [" + std::to_string(lru_key) +
            "] not found in cache during eviction: this indicates a bug in the "
            "code");
  }
  // Reference the entry in place rather than copying it: the original copied
  // the whole CacheEntry (outputs vector included) just to read the buffer
  // pointers. The reference is used strictly before cache_.erase() below.
  auto& entry = iter->second;
  // Free managed memory used in cache entry's outputs
  for (auto& output : entry.outputs_) {
    // Lock on buffer deallocation
    std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
    managed_buffer_.deallocate(output.buffer_);
  }
  // Remove LRU entry from cache
  cache_.erase(lru_key);
  // Remove LRU key from LRU list
  lru_.pop_back();
  // Increment number of evictions
  num_evictions_++;
  return Status::Success;
}
// Helpers
// Move 'cache_iter''s key to the front (most-recently-used end) of the LRU
// list and refresh the entry's stored LRU iterator. Invariant maintained:
// lru_.front() is the most recently used key, lru_.back() is the next
// eviction candidate.
void
RequestResponseCache::UpdateLRU(
    std::unordered_map<uint64_t, CacheEntry>::iterator& cache_iter)
{
  // Lock on cache update
  std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
  const auto& key = cache_iter->first;
  auto& cache_entry = cache_iter->second;
  // Remove key from LRU list if it was already in there.
  // NOTE: this is a linear scan of lru_, O(n) in the number of cached
  // entries, even though the entry carries lru_iter_.
  auto lru_iter = std::find(lru_.begin(), lru_.end(), key);
  if (lru_iter != lru_.end()) {
    lru_.erase(lru_iter);
  }
  // Add key to front of LRU list since it's most recently used
  lru_.push_front(key);
  // Set CacheEntry LRU iterator to new LRU key location
  cache_entry.lru_iter_ = lru_.begin();
}
// Copy every output of 'response' into cache-owned buffers and append them to
// 'entry'. Only CPU / CPU-pinned output buffers are supported. On any failure
// a non-Success Status is returned; buffers already allocated into 'entry'
// for earlier outputs are NOT released here — presumably the caller discards
// the partially-built entry and its buffers are reclaimed via eviction.
// TODO confirm ownership on the error path.
Status
RequestResponseCache::BuildCacheEntry(
    const InferenceResponse& response, CacheEntry* const entry)
{
  // Build cache entry data from response outputs
  for (const auto& response_output : response.Outputs()) {
    auto cache_output = Output();
    // Fetch output buffer details
    const void* response_buffer = nullptr;
    size_t response_byte_size = 0;
    TRITONSERVER_MemoryType response_memory_type;
    int64_t response_memory_type_id;
    void* userp;
    RETURN_IF_ERROR(response_output.DataBuffer(
        &response_buffer, &response_byte_size, &response_memory_type,
        &response_memory_type_id, &userp));
    // TODO: Handle other memory types
    if (response_memory_type != TRITONSERVER_MEMORY_CPU &&
        response_memory_type != TRITONSERVER_MEMORY_CPU_PINNED) {
      return Status(
          Status::Code::INTERNAL,
          "Only input buffers in CPU memory are allowed in cache currently");
    }
    // Exit early if response buffer from output is invalid
    if (response_buffer == nullptr) {
      return Status(
          Status::Code::INTERNAL, "Response buffer from output was nullptr");
    }
    // Lock on managed buffer references. The scope covers eviction,
    // allocation, and the copy so no other thread can steal the space
    // freed by Evict() before we allocate it.
    {
      std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
      // Exit early if cache entry will be larger than available cache size:
      // no amount of eviction could ever make this output fit.
      if (response_byte_size > managed_buffer_.get_size()) {
        return Status(
            Status::Code::INTERNAL,
            "Cache entry is larger than total cache size");
      }
      // If cache doesn't have enough space, evict until enough space available
      // NOTE: FreeBytes() doesn't account for allocator overhead so allocation
      // may fail even if response_byte_size is less than FreeBytes()
      while (response_byte_size > FreeBytes()) {
        LOG_VERBOSE(1) << "EVICT: Response larger than remaining available "
                          "memory, attempting to evict from cache.";
        RETURN_IF_ERROR(Evict());
      }
      // Attempt to allocate buffer until success or eviction from cache fails.
      // This loop terminates: each iteration either allocates successfully or
      // calls Evict(), which returns an error (propagated by RETURN_IF_ERROR)
      // once the cache is empty.
      while (cache_output.buffer_ == nullptr) {
        // Allocate buffer for response output in cache entry (nothrow form:
        // failure is signalled by nullptr, not an exception)
        cache_output.buffer_ =
            managed_buffer_.allocate(response_byte_size, std::nothrow_t{});
        // Attempt to evict if allocation fails (e.g. due to fragmentation or
        // allocator overhead, per the NOTE above)
        if (cache_output.buffer_ == nullptr) {
          LOG_VERBOSE(1) << "FAILED to allocate buffer in cache. Attempting to "
                            "evict an entry.";
          // Exit out if Eviction fails
          RETURN_IF_ERROR(Evict());
        }
      }
      // Copy data from response buffer to cache entry output buffer
      // TODO: Handle other memory types
      std::memcpy(cache_output.buffer_, response_buffer, response_byte_size);
      // Set output metadata
      cache_output.name_ = response_output.Name();
      cache_output.dtype_ = response_output.DType();
      cache_output.shape_ = response_output.Shape();
      cache_output.buffer_size_ = static_cast<uint64_t>(response_byte_size);
    }
    // Add each output to cache entry
    entry->outputs_.push_back(cache_output);
  }
  return Status::Success;
}
// Reconstruct an inference response from a cached entry: for every cached
// output, add a matching output to 'response', allocate a CPU buffer for it,
// and copy the cached bytes in. 'response' must be non-null and must not
// already contain outputs. Holds cache_mtx_ while reading the entry.
Status
RequestResponseCache::BuildInferenceResponse(
    const CacheEntry& entry, InferenceResponse* const response)
{
  if (response == nullptr) {
    return Status(Status::Code::INTERNAL, "invalid response ptr passed in");
  }

  // Lock on cache references
  {
    std::lock_guard<std::recursive_mutex> lk(cache_mtx_);

    // The response must start empty; we only ever append outputs.
    if (!response->Outputs().empty()) {
      return Status(
          Status::Code::INTERNAL,
          "InferenceResponse already contains some outputs");
    }

    for (const auto& cached : entry.outputs_) {
      // Register an output with matching name/dtype/shape on the response.
      InferenceResponse::Output* out = nullptr;
      RETURN_IF_ERROR(response->AddOutput(
          cached.name_, cached.dtype_, cached.shape_, &out));
      if (out == nullptr) {
        return Status(
            Status::Code::INTERNAL,
            "InferenceResponse::Output pointer as nullptr");
      }

      // Request a CPU buffer for the output data.
      TRITONSERVER_MemoryType mem_type = TRITONSERVER_MEMORY_CPU;
      int64_t mem_type_id = 0;
      void* dest;
      RETURN_IF_ERROR(out->AllocateDataBuffer(
          &dest, cached.buffer_size_, &mem_type, &mem_type_id));

      // TODO: Handle other memory types
      const bool is_cpu = (mem_type == TRITONSERVER_MEMORY_CPU) ||
                          (mem_type == TRITONSERVER_MEMORY_CPU_PINNED);
      if (!is_cpu) {
        return Status(
            Status::Code::INTERNAL,
            "Only input buffers in CPU memory are allowed in cache currently");
      }
      if (dest == nullptr) {
        return Status(
            Status::Code::INTERNAL,
            "failed to allocate buffer for output '" + cached.name_ + "'");
      }

      // Copy cached bytes into the freshly allocated response buffer.
      std::memcpy(dest, cached.buffer_, cached.buffer_size_);

      // TODO: Add field to InferenceResponse to indicate this was from cache
      // response.cached = true;
    }
  }

  return Status::Success;
}
// Fold the raw bytes of every data buffer of 'input' into the running hash
// '*seed'. Inputs may be split across several buffers (non-contiguous
// memory), so each chunk is hashed in order. Only CPU / CPU-pinned buffers
// are supported.
Status
RequestResponseCache::HashInputBuffers(
    const InferenceRequest::Input* input, size_t* seed)
{
  const size_t buffer_count = input->DataBufferCount();
  for (size_t i = 0; i < buffer_count; ++i) {
    const void* base;
    size_t byte_size;
    TRITONSERVER_MemoryType mem_type;
    int64_t mem_type_id;
    RETURN_IF_ERROR(
        input->DataBuffer(i, &base, &byte_size, &mem_type, &mem_type_id));

    // TODO: Handle other memory types
    if ((mem_type != TRITONSERVER_MEMORY_CPU) &&
        (mem_type != TRITONSERVER_MEMORY_CPU_PINNED)) {
      return Status(
          Status::Code::INTERNAL,
          "Only input buffers in CPU memory are allowed in cache currently");
    }

    // Mix every byte of this chunk into the seed.
    const auto* bytes = static_cast<const unsigned char*>(base);
    for (uint64_t b = 0; b < byte_size; b++) {
      boost::hash_combine(*seed, bytes[b]);
    }
  }
  return Status::Success;
}
// Fold every input of 'request' (name plus raw buffer bytes) into the
// running hash '*seed'. Inputs are visited in sorted name order so the hash
// is independent of the order inputs were added to the request.
Status
RequestResponseCache::HashInputs(const InferenceRequest& request, size_t* seed)
{
  const auto& inputs = request.ImmutableInputs();
  // Copy into an ordered map keyed by input name for deterministic iteration.
  std::map<std::string, InferenceRequest::Input*> sorted_inputs(
      inputs.begin(), inputs.end());
  for (const auto& name_and_input : sorted_inputs) {
    // Hash the input's name, then its raw data buffers.
    boost::hash_combine(*seed, name_and_input.second->Name());
    RETURN_IF_ERROR(HashInputBuffers(name_and_input.second, seed));
  }
  return Status::Success;
}
// Compute the cache key for 'request' by hashing its model name, resolved
// model version, and all input names/bytes, storing the result in '*key'.
Status
RequestResponseCache::Hash(const InferenceRequest& request, uint64_t* key)
{
  std::size_t hash_state = 0;
  // Same inputs against a different model (or version) must hash differently.
  boost::hash_combine(hash_state, request.ModelName());
  boost::hash_combine(hash_state, request.ActualModelVersion());
  // Fold in every input's name and raw data.
  RETURN_IF_ERROR(HashInputs(request, &hash_state));
  *key = static_cast<uint64_t>(hash_state);
  return Status::Success;
}
// Compute the cache key for 'request' and store it on the request object so
// later Lookup/Insert calls can reuse it without re-hashing.
Status
RequestResponseCache::HashAndSet(InferenceRequest* const request)
{
  uint64_t cache_key = 0;
  RETURN_IF_ERROR(Hash(*request, &cache_key));
  request->SetCacheKey(cache_key);
  return Status::Success;
}
}}
// namespace triton::core
3rdparty/core-r22.12/src/response_cache.h
0 → 100644
View file @
0a21fff9
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <list>
#include <string>
#include <unordered_map>
#include "infer_request.h"
#include "infer_response.h"
#include "model.h"
#include "status.h"
#include <boost/functional/hash.hpp>
#include <boost/interprocess/managed_external_buffer.hpp>
namespace
triton
{
namespace
core
{
// Assuming CPU memory only for now
//
// One cached output tensor: a buffer allocated from the cache's managed
// memory plus the metadata (name/dtype/shape) needed to rebuild an
// InferenceResponse::Output from it.
struct Output
{
  // Output tensor data buffer. Default-initialized to nullptr so a
  // freshly-constructed Output is unambiguously "not yet allocated" —
  // callers test this against nullptr to drive allocation retries.
  void* buffer_ = nullptr;
  // Size of "buffer" above
  uint64_t buffer_size_ = 0;
  // Name of the output
  std::string name_;
  // Datatype of the output
  inference::DataType dtype_;
  // Shape of the output
  std::vector<int64_t> shape_;
};
// One cache slot: the cached output tensors for a single request hash, plus
// the entry's position in the LRU list for O(1) recency bookkeeping.
struct CacheEntry
{
  // Defaulted rather than an empty user-provided body; members are
  // default-initialized either way.
  explicit CacheEntry() = default;
  // Point to key in LRU list for maintaining LRU order. NOTE: only valid
  // after UpdateLRU has been called for this entry — presumably never read
  // before that; TODO confirm.
  std::list<uint64_t>::iterator lru_iter_;
  // each output buffer = managed_buffer.allocate(size, ...)
  std::vector<Output> outputs_;
};
// LRU request/response cache. Maps a hash of an inference request (model
// name, version, and input bytes) to the response outputs produced for it,
// storing output data in a boost managed_external_buffer of fixed total
// size. Metrics accessors lock the relevant recursive mutex, so they are
// individually thread-safe.
class RequestResponseCache {
 public:
  ~RequestResponseCache();
  // Create the request/response cache object with 'cache_size' bytes of
  // backing storage. Returns the cache via 'cache' on success.
  static Status Create(
      uint64_t cache_size, std::unique_ptr<RequestResponseCache>* cache);
  // Hash inference request for cache access and store it in "request" object.
  // This will also be called internally in Lookup/Insert if the request hasn't
  // already stored it's hash. It is up to the user to update the hash in the
  // request if modifying any hashed fields of the request object after storing.
  // Return Status object indicating success or failure.
  Status HashAndSet(InferenceRequest* const request);
  // Lookup 'request' hash in cache and return the inference response in
  // 'response' on cache hit or nullptr on cache miss
  // Return Status object indicating success or failure.
  Status Lookup(
      InferenceResponse* const response, InferenceRequest* const request);
  // Insert response into cache, evict entries to make space if necessary
  // Return Status object indicating success or failure.
  Status Insert(
      const InferenceResponse& response, InferenceRequest* const request);
  // Evict entry from cache based on policy (least recently used)
  // Return Status object indicating success or failure.
  Status Evict();
  // Returns number of items in cache
  size_t NumEntries()
  {
    std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
    return cache_.size();
  }
  // Returns number of items evicted in cache lifespan
  size_t NumEvictions()
  {
    std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
    return num_evictions_;
  }
  // Returns number of lookups in cache lifespan, should sum to hits + misses
  size_t NumLookups()
  {
    std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
    return num_lookups_;
  }
  // Returns number of cache hits in cache lifespan
  size_t NumHits()
  {
    std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
    return num_hits_;
  }
  // Returns number of cache misses in cache lifespan
  size_t NumMisses()
  {
    std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
    return num_misses_;
  }
  // Returns the total lookup latency (nanoseconds) of all lookups in cache
  // lifespan
  uint64_t TotalLookupLatencyNs()
  {
    std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
    return total_lookup_latency_ns_;
  }
  // Returns the total insertion latency (nanoseconds) of all insertions in
  // cache lifespan
  uint64_t TotalInsertionLatencyNs()
  {
    std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
    return total_insertion_latency_ns_;
  }
  // Returns total number of bytes allocated for cache
  size_t TotalBytes()
  {
    std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
    return managed_buffer_.get_size();
  }
  // Returns number of free bytes in cache
  size_t FreeBytes()
  {
    std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
    return managed_buffer_.get_free_memory();
  }
  // Returns number of bytes in use by cache
  size_t AllocatedBytes()
  {
    std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
    return managed_buffer_.get_size() - managed_buffer_.get_free_memory();
  }
  // Returns fraction of bytes allocated over total cache size between [0, 1]
  // NOTE(review): divides by TotalBytes() with no zero check — presumably a
  // zero-sized cache is rejected in Create(); confirm.
  double TotalUtilization()
  {
    std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
    return static_cast<double>(AllocatedBytes()) /
           static_cast<double>(TotalBytes());
  }

 private:
  explicit RequestResponseCache(const uint64_t cache_size);
  // Update LRU ordering on lookup: move the entry's key to the front of the
  // LRU list and refresh the entry's stored LRU iterator.
  void UpdateLRU(std::unordered_map<uint64_t, CacheEntry>::iterator&);
  // Build CacheEntry from InferenceResponse by copying its outputs into
  // cache-owned buffers
  Status BuildCacheEntry(
      const InferenceResponse& response, CacheEntry* const entry);
  // Build InferenceResponse from CacheEntry by copying cached outputs into
  // newly allocated response buffers
  Status BuildInferenceResponse(
      const CacheEntry& entry, InferenceResponse* const response);
  // Helper function to hash data buffers used by "input" into "seed"
  Status HashInputBuffers(const InferenceRequest::Input* input, size_t* seed);
  // Helper function to hash each input in "request" into "seed"
  Status HashInputs(const InferenceRequest& request, size_t* seed);
  // Helper function to hash request and store it in "key"
  Status Hash(const InferenceRequest& request, uint64_t* key);
  // Cache buffer: backing storage that managed_buffer_ carves allocations
  // from. NOTE(review): declared without an initializer — presumably set in
  // the constructor; confirm.
  void* buffer_;
  // Managed buffer: boost allocator over buffer_ used for all cached output
  // data
  boost::interprocess::managed_external_buffer managed_buffer_;
  // key -> CacheEntry containing values and list iterator for LRU management
  std::unordered_map<uint64_t, CacheEntry> cache_;
  // List of keys sorted from most to least recently used
  std::list<uint64_t> lru_;
  // Cache metrics, all guarded by cache_mtx_
  size_t num_evictions_ = 0;
  size_t num_lookups_ = 0;
  size_t num_hits_ = 0;
  size_t num_misses_ = 0;
  uint64_t total_lookup_latency_ns_ = 0;
  uint64_t total_insertion_latency_ns_ = 0;
  // Mutex for buffer synchronization (recursive: buffer accessors call each
  // other, e.g. TotalUtilization -> AllocatedBytes/TotalBytes)
  std::recursive_mutex buffer_mtx_;
  // Mutex for cache synchronization
  std::recursive_mutex cache_mtx_;
};
}}
// namespace triton::core
3rdparty/core-r22.12/src/scheduler.h
0 → 100644
View file @
0a21fff9
// Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <functional>
#include "infer_request.h"
#include "status.h"
namespace
triton
{
namespace
core
{
// Scheduler interface. Abstract base for the per-model schedulers that
// accept inference requests and dispatch them to runners.
class Scheduler {
 public:
  virtual ~Scheduler() {}

  // The prototype for the initialization function that will be called
  // by the "standard" schedulers created based on a model's
  // scheduling_choice settings. The init function is called once by
  // the runner that will later execute requests for 'runner_idx'. A
  // non-OK error status indicates an initialization error that
  // prevents scheduler from using the runner.
  using StandardInitFunc = std::function<Status(uint32_t runner_idx)>;

  // The prototype for the warmup function that will be called by the
  // "standard" schedulers created based on a model's
  // scheduling_choice settings. The warmup function is called once by
  // the runner that will later execute requests for 'runner_idx'. A
  // non-OK error status indicates an error that prevents scheduler
  // from sending warmup requests to the runner.
  using StandardWarmupFunc = std::function<Status(uint32_t runner_idx)>;

  // The prototype for the run function that will be called by the
  // "standard" schedulers created based on a model's
  // scheduling_choice settings. The run function must accept a
  // 'runner_idx' indicating which runner should execute the
  // 'requests'. Ownership of the 'requests' is transferred to the
  // runner which is responsible for generating responses and
  // releasing the requests.
  using StandardRunFunc = std::function<void(
      uint32_t runner_idx,
      std::vector<std::unique_ptr<InferenceRequest>>&& requests)>;

  // Enqueue a request with the scheduler. If Status::Success is returned
  // then the backend has taken ownership of the request object and so
  // 'request' will be nullptr. If non-success is returned then the
  // caller still retains ownership of 'request'.
  virtual Status Enqueue(std::unique_ptr<InferenceRequest>& request) = 0;

  // Return the number of in-flight inferences tracked by the scheduler.
  virtual size_t InflightInferenceCount() = 0;

  // Instruct the scheduler to stop processing future requests unless they are
  // considered as in-flight.
  virtual void Stop() = 0;
};
}}
// namespace triton::core
Prev
1
…
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment