OpenDAS / Lmdeploy

Commit b30f3cdb, authored Nov 14, 2023 by xiabo

Add the downloaded code

Parent: e38ee081
Changes: 157
Showing 20 changed files with 8354 additions and 0 deletions (+8354, -0)
3rdparty/core-r22.12/src/label_provider.cc            +95    -0
3rdparty/core-r22.12/src/label_provider.h             +65    -0
3rdparty/core-r22.12/src/libtritonserver.ldscript     +32    -0
3rdparty/core-r22.12/src/memory.cc                    +238   -0
3rdparty/core-r22.12/src/memory.h                     +174   -0
3rdparty/core-r22.12/src/metric_family.cc             +321   -0
3rdparty/core-r22.12/src/metric_family.h              +111   -0
3rdparty/core-r22.12/src/metric_model_reporter.cc     +168   -0
3rdparty/core-r22.12/src/metric_model_reporter.h      +138   -0
3rdparty/core-r22.12/src/metrics.cc                   +1035  -0
3rdparty/core-r22.12/src/metrics.h                    +335   -0
3rdparty/core-r22.12/src/model.cc                     +137   -0
3rdparty/core-r22.12/src/model.h                      +162   -0
3rdparty/core-r22.12/src/model_config_cuda.cc         +61    -0
3rdparty/core-r22.12/src/model_config_cuda.h          +40    -0
3rdparty/core-r22.12/src/model_config_utils.cc        +2294  -0
3rdparty/core-r22.12/src/model_config_utils.h         +282   -0
3rdparty/core-r22.12/src/model_lifecycle.cc           +740   -0
3rdparty/core-r22.12/src/model_lifecycle.h            +324   -0
3rdparty/core-r22.12/src/model_repository_manager.cc  +1602  -0
Too many changes to show: only 157 of 157+ files are displayed.
3rdparty/core-r22.12/src/label_provider.cc (new file, mode 100644)
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "label_provider.h"
#include <iostream>
#include <iterator>
#include <sstream>
#include "filesystem.h"
namespace triton { namespace core {

const std::string&
LabelProvider::GetLabel(const std::string& name, size_t index) const
{
  static const std::string not_found;

  auto itr = label_map_.find(name);
  if (itr == label_map_.end()) {
    return not_found;
  }

  if (itr->second.size() <= index) {
    return not_found;
  }

  return itr->second[index];
}

Status
LabelProvider::AddLabels(const std::string& name, const std::string& filepath)
{
  std::string label_file_contents;
  RETURN_IF_ERROR(ReadTextFile(filepath, &label_file_contents));

  auto p = label_map_.insert(std::make_pair(name, std::vector<std::string>()));
  if (!p.second) {
    return Status(
        Status::Code::INTERNAL, "multiple label files for '" + name + "'");
  }

  auto itr = p.first;

  std::istringstream label_file_stream(label_file_contents);
  std::string line;
  while (std::getline(label_file_stream, line)) {
    itr->second.push_back(line);
  }

  return Status::Success;
}

const std::vector<std::string>&
LabelProvider::GetLabels(const std::string& name)
{
  static const std::vector<std::string> not_found;
  auto itr = label_map_.find(name);
  if (itr == label_map_.end()) {
    return not_found;
  }
  return itr->second;
}

Status
LabelProvider::AddLabels(
    const std::string& name, const std::vector<std::string>& labels)
{
  label_map_.emplace(name, labels);
  return Status::Success;
}

}}  // namespace triton::core
3rdparty/core-r22.12/src/label_provider.h (new file, mode 100644)
// Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "constants.h"
#include "status.h"
namespace triton { namespace core {

// Provides classification labels.
class LabelProvider {
 public:
  LabelProvider() = default;

  // Return the label associated with 'name' for a given
  // 'index'. Return empty string if no label is available.
  const std::string& GetLabel(const std::string& name, size_t index) const;

  // Associate with 'name' a set of labels initialized from a given
  // 'filepath'. Within the file each label is specified on its own
  // line. The first label (line 0) is the index-0 label, the second
  // label (line 1) is the index-1 label, etc.
  Status AddLabels(const std::string& name, const std::string& filepath);

  // Return the labels associated with 'name'. Return empty vector if no labels
  // are available.
  const std::vector<std::string>& GetLabels(const std::string& name);

  // Associate with 'name' a set of 'labels'
  Status AddLabels(
      const std::string& name, const std::vector<std::string>& labels);

 private:
  DISALLOW_COPY_AND_ASSIGN(LabelProvider);

  std::unordered_map<std::string, std::vector<std::string>> label_map_;
};

}}  // namespace triton::core
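For reference, a minimal usage sketch of the class above (a hypothetical caller, not part of this commit). It assumes the labels are provided in memory rather than read from a file, so no filesystem access is involved:

// Hypothetical caller of LabelProvider, for illustration only.
#include <iostream>
#include "label_provider.h"

void ExampleLabelLookup()
{
  triton::core::LabelProvider provider;

  // Associate the output tensor "output0" with in-memory labels.
  provider.AddLabels("output0", {"cat", "dog", "bird"});

  // Index-based lookup; an empty string is returned for out-of-range indices.
  std::cout << provider.GetLabel("output0", 1) << std::endl;   // "dog"
  std::cout << provider.GetLabel("output0", 10) << std::endl;  // ""
}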
3rdparty/core-r22.12/src/libtritonserver.ldscript (new file, mode 100644)
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
{
  global:
    TRITONSERVER_*;
    TRITONBACKEND_*;
    TRITONREPOAGENT_*;
  local: *;
};
3rdparty/core-r22.12/src/memory.cc (new file, mode 100644)
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "memory.h"
#include "pinned_memory_manager.h"
#include "triton/common/logging.h"
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#include "cuda_memory_manager.h"
#endif // TRITON_ENABLE_GPU
namespace triton { namespace core {

//
// MemoryReference
//
MemoryReference::MemoryReference() : Memory() {}

const char*
MemoryReference::BufferAt(
    size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
    int64_t* memory_type_id) const
{
  if (idx >= buffer_.size()) {
    *byte_size = 0;
    *memory_type = TRITONSERVER_MEMORY_CPU;
    *memory_type_id = 0;
    return nullptr;
  }
  *memory_type = buffer_[idx].buffer_attributes_.MemoryType();
  *memory_type_id = buffer_[idx].buffer_attributes_.MemoryTypeId();
  *byte_size = buffer_[idx].buffer_attributes_.ByteSize();
  return buffer_[idx].buffer_;
}

const char*
MemoryReference::BufferAt(size_t idx, BufferAttributes** buffer_attributes)
{
  if (idx >= buffer_.size()) {
    *buffer_attributes = nullptr;
    return nullptr;
  }

  *buffer_attributes = &(buffer_[idx].buffer_attributes_);
  return buffer_[idx].buffer_;
}

size_t
MemoryReference::AddBuffer(
    const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
    int64_t memory_type_id)
{
  total_byte_size_ += byte_size;
  buffer_count_++;
  buffer_.emplace_back(buffer, byte_size, memory_type, memory_type_id);
  return buffer_.size() - 1;
}

size_t
MemoryReference::AddBuffer(
    const char* buffer, BufferAttributes* buffer_attributes)
{
  total_byte_size_ += buffer_attributes->ByteSize();
  buffer_count_++;
  buffer_.emplace_back(buffer, buffer_attributes);
  return buffer_.size() - 1;
}

size_t
MemoryReference::AddBufferFront(
    const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
    int64_t memory_type_id)
{
  total_byte_size_ += byte_size;
  buffer_count_++;
  buffer_.emplace(
      buffer_.begin(), buffer, byte_size, memory_type, memory_type_id);
  return buffer_.size() - 1;
}

//
// MutableMemory
//
MutableMemory::MutableMemory(
    char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
    int64_t memory_type_id)
    : Memory(), buffer_(buffer),
      buffer_attributes_(
          BufferAttributes(byte_size, memory_type, memory_type_id, nullptr))
{
  total_byte_size_ = byte_size;
  buffer_count_ = (byte_size == 0) ? 0 : 1;
}

const char*
MutableMemory::BufferAt(
    size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
    int64_t* memory_type_id) const
{
  if (idx != 0) {
    *byte_size = 0;
    *memory_type = TRITONSERVER_MEMORY_CPU;
    *memory_type_id = 0;
    return nullptr;
  }
  *byte_size = total_byte_size_;
  *memory_type = buffer_attributes_.MemoryType();
  *memory_type_id = buffer_attributes_.MemoryTypeId();
  return buffer_;
}

const char*
MutableMemory::BufferAt(size_t idx, BufferAttributes** buffer_attributes)
{
  if (idx != 0) {
    *buffer_attributes = nullptr;
    return nullptr;
  }

  *buffer_attributes = &buffer_attributes_;
  return buffer_;
}

char*
MutableMemory::MutableBuffer(
    TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id)
{
  if (memory_type != nullptr) {
    *memory_type = buffer_attributes_.MemoryType();
  }
  if (memory_type_id != nullptr) {
    *memory_type_id = buffer_attributes_.MemoryTypeId();
  }

  return buffer_;
}

//
// AllocatedMemory
//
AllocatedMemory::AllocatedMemory(
    size_t byte_size, TRITONSERVER_MemoryType memory_type,
    int64_t memory_type_id)
    : MutableMemory(nullptr, byte_size, memory_type, memory_type_id)
{
  if (total_byte_size_ != 0) {
    // Allocate memory with the following fallback policy:
    // CUDA memory -> pinned system memory -> non-pinned system memory
    switch (buffer_attributes_.MemoryType()) {
#ifdef TRITON_ENABLE_GPU
      case TRITONSERVER_MEMORY_GPU: {
        auto status = CudaMemoryManager::Alloc(
            (void**)&buffer_, total_byte_size_,
            buffer_attributes_.MemoryTypeId());
        if (!status.IsOk()) {
          static bool warning_logged = false;
          if (!warning_logged) {
            LOG_WARNING << status.Message()
                        << ", falling back to pinned system memory";
            warning_logged = true;
          }
          goto pinned_memory_allocation;
        }
        break;
      }
      pinned_memory_allocation:
#endif  // TRITON_ENABLE_GPU
      default: {
        TRITONSERVER_MemoryType memory_type = buffer_attributes_.MemoryType();
        auto status = PinnedMemoryManager::Alloc(
            (void**)&buffer_, total_byte_size_, &memory_type, true);
        buffer_attributes_.SetMemoryType(memory_type);
        if (!status.IsOk()) {
          LOG_ERROR << status.Message();
          buffer_ = nullptr;
        }
        break;
      }
    }
  }
  total_byte_size_ = (buffer_ == nullptr) ? 0 : total_byte_size_;
}

AllocatedMemory::~AllocatedMemory()
{
  if (buffer_ != nullptr) {
    switch (buffer_attributes_.MemoryType()) {
      case TRITONSERVER_MEMORY_GPU: {
#ifdef TRITON_ENABLE_GPU
        auto status = CudaMemoryManager::Free(
            buffer_, buffer_attributes_.MemoryTypeId());
        if (!status.IsOk()) {
          LOG_ERROR << status.Message();
        }
#endif  // TRITON_ENABLE_GPU
        break;
      }

      default: {
        auto status = PinnedMemoryManager::Free(buffer_);
        if (!status.IsOk()) {
          LOG_ERROR << status.Message();
          buffer_ = nullptr;
        }
        break;
      }
    }

    buffer_ = nullptr;
  }
}

}}  // namespace triton::core
3rdparty/core-r22.12/src/memory.h (new file, mode 100644)
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <vector>
#include "buffer_attributes.h"
#include "constants.h"
#include "status.h"
namespace triton { namespace core {

//
// Memory used to access data in inference requests
//
class Memory {
 public:
  // Get the 'idx'-th data block in the buffer. Using index to avoid
  // maintaining internal state such that one buffer can be shared
  // across multiple providers.
  // 'idx' zero base index. Valid indices are continuous.
  // 'byte_size' returns the byte size of the chunk of bytes.
  // 'memory_type' returns the memory type of the chunk of bytes.
  // 'memory_type_id' returns the memory type id of the chunk of bytes.
  // Return the pointer to the data block. Returns nullptr if 'idx' is
  // out of range
  virtual const char* BufferAt(
      size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
      int64_t* memory_type_id) const = 0;

  // Similar to the above BufferAt but with BufferAttributes.
  virtual const char* BufferAt(
      size_t idx, BufferAttributes** buffer_attributes) = 0;

  // Get the number of contiguous buffers composing the memory.
  size_t BufferCount() const { return buffer_count_; }

  // Return the total byte size of the data buffer
  size_t TotalByteSize() const { return total_byte_size_; }

 protected:
  Memory() : total_byte_size_(0), buffer_count_(0) {}

  size_t total_byte_size_;
  size_t buffer_count_;
};

//
// MemoryReference
//
class MemoryReference : public Memory {
 public:
  // Create a read-only data buffer as a reference to other data buffer
  MemoryReference();

  //\see Memory::BufferAt()
  const char* BufferAt(
      size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
      int64_t* memory_type_id) const override;
  const char* BufferAt(
      size_t idx, BufferAttributes** buffer_attributes) override;

  // Add a 'buffer' with 'byte_size' as part of this data buffer
  // Return the index of the buffer
  size_t AddBuffer(
      const char* buffer, size_t byte_size,
      TRITONSERVER_MemoryType memory_type, int64_t memory_type_id);

  size_t AddBuffer(const char* buffer, BufferAttributes* buffer_attributes);

  // Add a 'buffer' with 'byte_size' as part of this data buffer in the front
  // Return the index of the buffer
  size_t AddBufferFront(
      const char* buffer, size_t byte_size,
      TRITONSERVER_MemoryType memory_type, int64_t memory_type_id);

 private:
  struct Block {
    Block(
        const char* buffer, size_t byte_size,
        TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)
        : buffer_(buffer),
          buffer_attributes_(
              BufferAttributes(byte_size, memory_type, memory_type_id, nullptr))
    {
    }

    Block(const char* buffer, BufferAttributes* buffer_attributes)
        : buffer_(buffer), buffer_attributes_(*buffer_attributes)
    {
    }

    const char* buffer_;
    BufferAttributes buffer_attributes_;
  };
  std::vector<Block> buffer_;
};

//
// MutableMemory
//
class MutableMemory : public Memory {
 public:
  // Create a mutable data buffer referencing to other data buffer.
  MutableMemory(
      char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
      int64_t memory_type_id);

  virtual ~MutableMemory() {}

  //\see Memory::BufferAt()
  const char* BufferAt(
      size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
      int64_t* memory_type_id) const override;

  //\see Memory::BufferAt()
  const char* BufferAt(
      size_t idx, BufferAttributes** buffer_attributes) override;

  // Return a pointer to the base address of the mutable buffer. If
  // non-null 'memory_type' returns the memory type of the chunk of
  // bytes. If non-null 'memory_type_id' returns the memory type id of
  // the chunk of bytes.
  char* MutableBuffer(
      TRITONSERVER_MemoryType* memory_type = nullptr,
      int64_t* memory_type_id = nullptr);

  DISALLOW_COPY_AND_ASSIGN(MutableMemory);

 protected:
  MutableMemory() : Memory() {}

  char* buffer_;
  BufferAttributes buffer_attributes_;
};

//
// AllocatedMemory
//
class AllocatedMemory : public MutableMemory {
 public:
  // Create a continuous data buffer with 'byte_size', 'memory_type' and
  // 'memory_type_id'. Note that the buffer may be created on different memory
  // type and memory type id if the original request type and id can not be
  // satisfied, thus the function caller should always check the actual memory
  // type and memory type id before use.
  AllocatedMemory(
      size_t byte_size, TRITONSERVER_MemoryType memory_type,
      int64_t memory_type_id);

  ~AllocatedMemory() override;
};

}}  // namespace triton::core
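As a quick orientation, here is a minimal sketch (a hypothetical caller, not part of this commit) of how the classes above are typically combined: an AllocatedMemory owns storage and may silently fall back to another memory type, while a MemoryReference only records views onto existing buffers. It assumes a CPU-only build so the pinned/non-pinned fallback path is taken.

// Hypothetical caller of the Memory classes, for illustration only.
#include <cstring>
#include "memory.h"

void ExampleMemoryUsage()
{
  // Allocate a contiguous 16-byte buffer; the constructor may fall back to a
  // different memory type, so query the actual placement before use.
  triton::core::AllocatedMemory allocated(
      16, TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */);
  TRITONSERVER_MemoryType actual_type;
  int64_t actual_id;
  char* base = allocated.MutableBuffer(&actual_type, &actual_id);
  if (base != nullptr) {
    std::memset(base, 0, allocated.TotalByteSize());
  }

  // Reference existing chunks without copying and walk them by index.
  static const char chunk[] = "abcd";
  triton::core::MemoryReference ref;
  ref.AddBuffer(chunk, sizeof(chunk), TRITONSERVER_MEMORY_CPU, 0);
  for (size_t i = 0; i < ref.BufferCount(); ++i) {
    size_t byte_size;
    TRITONSERVER_MemoryType type;
    int64_t type_id;
    const char* buf = ref.BufferAt(i, &byte_size, &type, &type_id);
    (void)buf;  // a real caller would copy or hand this view to a backend
  }
}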
3rdparty/core-r22.12/src/metric_family.cc (new file, mode 100644)
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifdef TRITON_ENABLE_METRICS
#include "metric_family.h"
#include "metrics.h"
#include "triton/common/logging.h"
namespace triton { namespace core {

//
// Implementation for TRITONSERVER_MetricFamily.
//
MetricFamily::MetricFamily(
    TRITONSERVER_MetricKind kind, const char* name, const char* description)
{
  auto registry = Metrics::GetRegistry();

  switch (kind) {
    case TRITONSERVER_METRIC_KIND_COUNTER:
      family_ = reinterpret_cast<void*>(&prometheus::BuildCounter()
                                             .Name(name)
                                             .Help(description)
                                             .Register(*registry));
      break;
    case TRITONSERVER_METRIC_KIND_GAUGE:
      family_ = reinterpret_cast<void*>(&prometheus::BuildGauge()
                                             .Name(name)
                                             .Help(description)
                                             .Register(*registry));
      break;
    default:
      throw std::invalid_argument(
          "Unsupported kind passed to MetricFamily constructor.");
  }

  kind_ = kind;
}

void*
MetricFamily::Add(std::map<std::string, std::string> label_map, Metric* metric)
{
  void* prom_metric = nullptr;
  switch (kind_) {
    case TRITONSERVER_METRIC_KIND_COUNTER: {
      auto counter_family_ptr =
          reinterpret_cast<prometheus::Family<prometheus::Counter>*>(family_);
      auto counter_ptr = &counter_family_ptr->Add(label_map);
      prom_metric = reinterpret_cast<void*>(counter_ptr);
      break;
    }
    case TRITONSERVER_METRIC_KIND_GAUGE: {
      auto gauge_family_ptr =
          reinterpret_cast<prometheus::Family<prometheus::Gauge>*>(family_);
      auto gauge_ptr = &gauge_family_ptr->Add(label_map);
      prom_metric = reinterpret_cast<void*>(gauge_ptr);
      break;
    }
    default:
      throw std::invalid_argument(
          "Unsupported family kind passed to Metric constructor.");
  }

  std::lock_guard<std::mutex> lk(metric_mtx_);
  ++prom_metric_ref_cnt_[prom_metric];
  child_metrics_.insert(metric);
  return prom_metric;
}

void
MetricFamily::Remove(void* prom_metric, Metric* metric)
{
  {
    // Remove reference to dependent Metric object
    std::lock_guard<std::mutex> lk(metric_mtx_);
    child_metrics_.erase(metric);
  }

  if (prom_metric == nullptr) {
    return;
  }

  {
    std::lock_guard<std::mutex> lk(metric_mtx_);
    const auto it = prom_metric_ref_cnt_.find(prom_metric);
    if (it != prom_metric_ref_cnt_.end()) {
      --it->second;
      if (it->second == 0) {
        prom_metric_ref_cnt_.erase(it);
      } else {
        // Done as it is not the last reference
        return;
      }
    }
  }

  switch (kind_) {
    case TRITONSERVER_METRIC_KIND_COUNTER: {
      auto counter_family_ptr =
          reinterpret_cast<prometheus::Family<prometheus::Counter>*>(family_);
      auto counter_ptr = reinterpret_cast<prometheus::Counter*>(prom_metric);
      counter_family_ptr->Remove(counter_ptr);
      break;
    }
    case TRITONSERVER_METRIC_KIND_GAUGE: {
      auto gauge_family_ptr =
          reinterpret_cast<prometheus::Family<prometheus::Gauge>*>(family_);
      auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(prom_metric);
      gauge_family_ptr->Remove(gauge_ptr);
      break;
    }
    default:
      // Invalid kind should be caught in constructor
      LOG_ERROR << "Unsupported kind in Metric destructor.";
      break;
  }
}

void
MetricFamily::InvalidateReferences()
{
  std::lock_guard<std::mutex> lk(metric_mtx_);
  for (auto& metric : child_metrics_) {
    if (metric != nullptr) {
      metric->Invalidate();
    }
  }
  child_metrics_.clear();
}

MetricFamily::~MetricFamily()
{
  if (NumMetrics() > 0) {
    LOG_WARNING << "MetricFamily was deleted before its child Metrics, this "
                   "should not happen. Make sure to delete all child Metrics "
                   "before deleting their MetricFamily.";
  }
  InvalidateReferences();
  // DLIS-4072: Support for removing metric families from registry
}

//
// Implementation for TRITONSERVER_Metric.
//
Metric::Metric(
    TRITONSERVER_MetricFamily* family,
    std::vector<const InferenceParameter*> labels)
{
  family_ = reinterpret_cast<MetricFamily*>(family);
  kind_ = family_->Kind();

  // Create map of labels from InferenceParameters
  std::map<std::string, std::string> label_map;
  for (const auto& param : labels) {
    if (param->Type() != TRITONSERVER_PARAMETER_STRING) {
      throw std::invalid_argument(
          "Parameter [" + param->Name() +
          "] must have a type of TRITONSERVER_PARAMETER_STRING to be "
          "added as a label.");
    }

    label_map[param->Name()] =
        std::string(reinterpret_cast<const char*>(param->ValuePointer()));
  }

  metric_ = family_->Add(label_map, this);
}

Metric::~Metric()
{
  if (family_ != nullptr) {
    family_->Remove(metric_, this);
  } else {
    LOG_WARNING << "Corresponding MetricFamily was deleted before this Metric, "
                   "this should not happen. Make sure to delete a Metric "
                   "before deleting its MetricFamily.";
  }

  // Catch lifetime management / invalid reference issues
  Invalidate();
}

void
Metric::Invalidate()
{
  family_ = nullptr;
  metric_ = nullptr;
}

TRITONSERVER_Error*
Metric::Value(double* value)
{
  if (metric_ == nullptr) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INTERNAL,
        "Could not get metric value. Metric has been invalidated.");
  }

  switch (kind_) {
    case TRITONSERVER_METRIC_KIND_COUNTER: {
      auto counter_ptr = reinterpret_cast<prometheus::Counter*>(metric_);
      LOG_VERBOSE(1) << "SETTING COUNTER METRIC FROM: " << *value << " to "
                     << counter_ptr->Value();
      *value = counter_ptr->Value();
      break;
    }
    case TRITONSERVER_METRIC_KIND_GAUGE: {
      auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(metric_);
      LOG_VERBOSE(1) << "SETTING GAUGE METRIC FROM: " << *value << " to "
                     << gauge_ptr->Value();
      *value = gauge_ptr->Value();
      break;
    }
    default:
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_UNSUPPORTED,
          "Unsupported TRITONSERVER_MetricKind");
  }

  return nullptr;  // Success
}

TRITONSERVER_Error*
Metric::Increment(double value)
{
  if (metric_ == nullptr) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INTERNAL,
        "Could not increment metric value. Metric has been invalidated.");
  }

  switch (kind_) {
    case TRITONSERVER_METRIC_KIND_COUNTER: {
      if (value < 0.0) {
        return TRITONSERVER_ErrorNew(
            TRITONSERVER_ERROR_INVALID_ARG,
            "TRITONSERVER_METRIC_KIND_COUNTER can only be incremented "
            "monotonically by non-negative values.");
      }

      auto counter_ptr = reinterpret_cast<prometheus::Counter*>(metric_);
      counter_ptr->Increment(value);
      break;
    }
    case TRITONSERVER_METRIC_KIND_GAUGE: {
      auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(metric_);
      // Gauge::Increment works for both positive and negative values as of
      // prometheus-cpp v1.0 but for now on v0.7 we defer call to
      // Increment/Decrement based on the sign of value
      // https://github.com/jupp0r/prometheus-cpp/blob/master/core/src/gauge.cc
      if (value < 0.0) {
        gauge_ptr->Decrement(-1.0 * value);
      } else {
        gauge_ptr->Increment(value);
      }
      break;
    }
    default:
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_UNSUPPORTED,
          "Unsupported TRITONSERVER_MetricKind");
  }

  return nullptr;  // Success
}

TRITONSERVER_Error*
Metric::Set(double value)
{
  if (metric_ == nullptr) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INTERNAL,
        "Could not set metric value. Metric has been invalidated.");
  }

  switch (kind_) {
    case TRITONSERVER_METRIC_KIND_COUNTER: {
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_UNSUPPORTED,
          "TRITONSERVER_METRIC_KIND_COUNTER does not support Set");
    }
    case TRITONSERVER_METRIC_KIND_GAUGE: {
      auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(metric_);
      gauge_ptr->Set(value);
      break;
    }
    default:
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_UNSUPPORTED,
          "Unsupported TRITONSERVER_MetricKind");
  }

  return nullptr;  // Success
}

}}  // namespace triton::core

#endif  // TRITON_ENABLE_METRICS
3rdparty/core-r22.12/src/metric_family.h (new file, mode 100644)
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#ifdef TRITON_ENABLE_METRICS
#include <mutex>
#include <set>
#include <unordered_map>
#include "infer_parameter.h"
#include "prometheus/registry.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {

//
// Implementation for TRITONSERVER_MetricFamily.
//
class Metric;
class MetricFamily {
 public:
  MetricFamily(
      TRITONSERVER_MetricKind kind, const char* name, const char* description);
  ~MetricFamily();

  void* Family() const { return family_; }
  TRITONSERVER_MetricKind Kind() const { return kind_; }

  void* Add(std::map<std::string, std::string> label_map, Metric* metric);
  void Remove(void* prom_metric, Metric* metric);

  int NumMetrics()
  {
    std::lock_guard<std::mutex> lk(metric_mtx_);
    return child_metrics_.size();
  }

 private:
  // If a MetricFamily is deleted before its dependent Metric, we want to
  // invalidate the reference so we don't access invalid memory.
  void InvalidateReferences();

  void* family_;
  TRITONSERVER_MetricKind kind_;
  // Synchronize access of related metric objects
  std::mutex metric_mtx_;
  // Prometheus returns the existing metric pointer if the metric with the same
  // set of labels are requested, as a result, different Metric objects may
  // refer to the same prometheus metric. So we must track the reference count
  // of the metric and request prometheus to remove it only when all references
  // are released.
  std::unordered_map<void*, size_t> prom_metric_ref_cnt_;
  // Maintain references to metrics created from this metric family to
  // invalidate their references if a family is deleted before its metric
  std::set<Metric*> child_metrics_;
};

//
// Implementation for TRITONSERVER_Metric.
//
class Metric {
 public:
  Metric(
      TRITONSERVER_MetricFamily* family,
      std::vector<const InferenceParameter*> labels);
  ~Metric();

  MetricFamily* Family() const { return family_; }
  TRITONSERVER_MetricKind Kind() const { return kind_; }

  TRITONSERVER_Error* Value(double* value);
  TRITONSERVER_Error* Increment(double value);
  TRITONSERVER_Error* Set(double value);

  // If a MetricFamily is deleted before its dependent Metric, we want to
  // invalidate the references so we don't access invalid memory.
  void Invalidate();

 private:
  void* metric_;
  MetricFamily* family_;
  TRITONSERVER_MetricKind kind_;
};

}}  // namespace triton::core

#endif  // TRITON_ENABLE_METRICS
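A minimal sketch of how these wrappers can be exercised (a hypothetical caller, not part of this commit). It drives MetricFamily::Add()/Remove() directly and passes a null Metric*, which the bookkeeping above tolerates because child_metrics_ is only consulted for invalidation; real callers go through the Metric class or the server's metrics C API instead.

// Hypothetical caller of MetricFamily, for illustration only.
#ifdef TRITON_ENABLE_METRICS
#include "metric_family.h"
#include "prometheus/counter.h"

void ExampleCounterFamily()
{
  triton::core::MetricFamily family(
      TRITONSERVER_METRIC_KIND_COUNTER, "nv_example_total",
      "Example counter registered on the shared Metrics registry");

  // Each unique label set maps to one prometheus::Counter; the family
  // ref-counts it so shared label sets are removed only once.
  void* handle = family.Add({{"model", "resnet50"}}, /* metric */ nullptr);
  static_cast<prometheus::Counter*>(handle)->Increment(1.0);

  family.Remove(handle, /* metric */ nullptr);
}
#endif  // TRITON_ENABLE_METRICS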
3rdparty/core-r22.12/src/metric_model_reporter.cc (new file, mode 100644)
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "metric_model_reporter.h"
#ifdef TRITON_ENABLE_METRICS
#include "constants.h"
#include "metrics.h"
namespace triton { namespace core {

Status
MetricModelReporter::Create(
    const std::string& model_name, const int64_t model_version,
    const int device, const triton::common::MetricTagsMap& model_tags,
    std::shared_ptr<MetricModelReporter>* metric_model_reporter)
{
  static std::mutex mtx;
  static std::unordered_map<size_t, std::weak_ptr<MetricModelReporter>>
      reporter_map;

  std::map<std::string, std::string> labels;
  GetMetricLabels(&labels, model_name, model_version, device, model_tags);
  auto hash_labels = Metrics::HashLabels(labels);

  std::lock_guard<std::mutex> lock(mtx);

  const auto& itr = reporter_map.find(hash_labels);
  if (itr != reporter_map.end()) {
    // Found in map. If the weak_ptr is still valid that means that
    // there are other models using the reporter and we just reuse that
    // same reporter. If the weak_ptr is not valid then we need to remove
    // the weak_ptr from the map and create the reporter again.
    *metric_model_reporter = itr->second.lock();
    if (*metric_model_reporter != nullptr) {
      return Status::Success;
    }

    reporter_map.erase(itr);
  }

  metric_model_reporter->reset(
      new MetricModelReporter(model_name, model_version, device, model_tags));
  reporter_map.insert({hash_labels, *metric_model_reporter});
  return Status::Success;
}

MetricModelReporter::MetricModelReporter(
    const std::string& model_name, const int64_t model_version,
    const int device, const triton::common::MetricTagsMap& model_tags)
{
  std::map<std::string, std::string> labels;
  GetMetricLabels(&labels, model_name, model_version, device, model_tags);

  metric_inf_success_ =
      CreateCounterMetric(Metrics::FamilyInferenceSuccess(), labels);
  metric_inf_failure_ =
      CreateCounterMetric(Metrics::FamilyInferenceFailure(), labels);
  metric_inf_count_ =
      CreateCounterMetric(Metrics::FamilyInferenceCount(), labels);
  metric_inf_exec_count_ =
      CreateCounterMetric(Metrics::FamilyInferenceExecutionCount(), labels);
  metric_inf_request_duration_us_ =
      CreateCounterMetric(Metrics::FamilyInferenceRequestDuration(), labels);
  metric_inf_queue_duration_us_ =
      CreateCounterMetric(Metrics::FamilyInferenceQueueDuration(), labels);
  metric_inf_compute_input_duration_us_ = CreateCounterMetric(
      Metrics::FamilyInferenceComputeInputDuration(), labels);
  metric_inf_compute_infer_duration_us_ = CreateCounterMetric(
      Metrics::FamilyInferenceComputeInferDuration(), labels);
  metric_inf_compute_output_duration_us_ = CreateCounterMetric(
      Metrics::FamilyInferenceComputeOutputDuration(), labels);
  metric_cache_hit_count_ =
      CreateCounterMetric(Metrics::FamilyCacheHitCount(), labels);
  metric_cache_hit_lookup_duration_us_ =
      CreateCounterMetric(Metrics::FamilyCacheHitLookupDuration(), labels);
  metric_cache_miss_count_ =
      CreateCounterMetric(Metrics::FamilyCacheMissCount(), labels);
  metric_cache_miss_lookup_duration_us_ =
      CreateCounterMetric(Metrics::FamilyCacheMissLookupDuration(), labels);
  metric_cache_miss_insertion_duration_us_ =
      CreateCounterMetric(Metrics::FamilyCacheMissInsertionDuration(), labels);
}

MetricModelReporter::~MetricModelReporter()
{
  Metrics::FamilyInferenceSuccess().Remove(metric_inf_success_);
  Metrics::FamilyInferenceFailure().Remove(metric_inf_failure_);
  Metrics::FamilyInferenceCount().Remove(metric_inf_count_);
  Metrics::FamilyInferenceExecutionCount().Remove(metric_inf_exec_count_);
  Metrics::FamilyInferenceRequestDuration().Remove(
      metric_inf_request_duration_us_);
  Metrics::FamilyInferenceQueueDuration().Remove(
      metric_inf_queue_duration_us_);
  Metrics::FamilyInferenceComputeInputDuration().Remove(
      metric_inf_compute_input_duration_us_);
  Metrics::FamilyInferenceComputeInferDuration().Remove(
      metric_inf_compute_infer_duration_us_);
  Metrics::FamilyInferenceComputeOutputDuration().Remove(
      metric_inf_compute_output_duration_us_);
  Metrics::FamilyCacheHitCount().Remove(metric_cache_hit_count_);
  Metrics::FamilyCacheHitLookupDuration().Remove(
      metric_cache_hit_lookup_duration_us_);
  Metrics::FamilyCacheMissCount().Remove(metric_cache_miss_count_);
  Metrics::FamilyCacheMissInsertionDuration().Remove(
      metric_cache_miss_insertion_duration_us_);
}

void
MetricModelReporter::GetMetricLabels(
    std::map<std::string, std::string>* labels, const std::string& model_name,
    const int64_t model_version, const int device,
    const triton::common::MetricTagsMap& model_tags)
{
  labels->insert(std::map<std::string, std::string>::value_type(
      std::string(kMetricsLabelModelName), model_name));
  labels->insert(std::map<std::string, std::string>::value_type(
      std::string(kMetricsLabelModelVersion), std::to_string(model_version)));
  for (const auto& tag : model_tags) {
    labels->insert(std::map<std::string, std::string>::value_type(
        "_" + tag.first, tag.second));
  }

  // 'device' can be < 0 to indicate that the GPU is not known. In
  // that case use a metric that doesn't have the gpu_uuid label.
  if (device >= 0) {
    std::string uuid;
    if (Metrics::UUIDForCudaDevice(device, &uuid)) {
      labels->insert(std::map<std::string, std::string>::value_type(
          std::string(kMetricsLabelGpuUuid), uuid));
    }
  }
}

prometheus::Counter*
MetricModelReporter::CreateCounterMetric(
    prometheus::Family<prometheus::Counter>& family,
    const std::map<std::string, std::string>& labels)
{
  return &family.Add(labels);
}

}}  // namespace triton::core

#endif  // TRITON_ENABLE_METRICS
3rdparty/core-r22.12/src/metric_model_reporter.h (new file, mode 100644)
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "status.h"
#include "triton/common/model_config.h"
#ifdef TRITON_ENABLE_METRICS
#include "prometheus/registry.h"
#endif // TRITON_ENABLE_METRICS
namespace triton { namespace core {

//
// Interface for a metric reporter for a given version of a model.
//
class MetricModelReporter {
 public:
#ifdef TRITON_ENABLE_METRICS
  static Status Create(
      const std::string& model_name, const int64_t model_version,
      const int device, const triton::common::MetricTagsMap& model_tags,
      std::shared_ptr<MetricModelReporter>* metric_model_reporter);

  ~MetricModelReporter();

  // Get a metric for the given model, version and GPU index.
  prometheus::Counter& MetricInferenceSuccess() const
  {
    return *metric_inf_success_;
  }
  prometheus::Counter& MetricInferenceFailure() const
  {
    return *metric_inf_failure_;
  }
  prometheus::Counter& MetricInferenceCount() const
  {
    return *metric_inf_count_;
  }
  prometheus::Counter& MetricInferenceExecutionCount() const
  {
    return *metric_inf_exec_count_;
  }
  prometheus::Counter& MetricInferenceRequestDuration() const
  {
    return *metric_inf_request_duration_us_;
  }
  prometheus::Counter& MetricInferenceQueueDuration() const
  {
    return *metric_inf_queue_duration_us_;
  }
  prometheus::Counter& MetricInferenceComputeInputDuration() const
  {
    return *metric_inf_compute_input_duration_us_;
  }
  prometheus::Counter& MetricInferenceComputeInferDuration() const
  {
    return *metric_inf_compute_infer_duration_us_;
  }
  prometheus::Counter& MetricInferenceComputeOutputDuration() const
  {
    return *metric_inf_compute_output_duration_us_;
  }
  prometheus::Counter& MetricCacheHitCount() const
  {
    return *metric_cache_hit_count_;
  }
  prometheus::Counter& MetricCacheHitLookupDuration() const
  {
    return *metric_cache_hit_lookup_duration_us_;
  }
  prometheus::Counter& MetricCacheMissCount() const
  {
    return *metric_cache_miss_count_;
  }
  prometheus::Counter& MetricCacheMissLookupDuration() const
  {
    return *metric_cache_miss_lookup_duration_us_;
  }
  prometheus::Counter& MetricCacheMissInsertionDuration() const
  {
    return *metric_cache_miss_insertion_duration_us_;
  }

 private:
  MetricModelReporter(
      const std::string& model_name, const int64_t model_version,
      const int device, const triton::common::MetricTagsMap& model_tags);

  static void GetMetricLabels(
      std::map<std::string, std::string>* labels,
      const std::string& model_name, const int64_t model_version,
      const int device, const triton::common::MetricTagsMap& model_tags);

  prometheus::Counter* CreateCounterMetric(
      prometheus::Family<prometheus::Counter>& family,
      const std::map<std::string, std::string>& labels);

  prometheus::Counter* metric_inf_success_;
  prometheus::Counter* metric_inf_failure_;
  prometheus::Counter* metric_inf_count_;
  prometheus::Counter* metric_inf_exec_count_;
  prometheus::Counter* metric_inf_request_duration_us_;
  prometheus::Counter* metric_inf_queue_duration_us_;
  prometheus::Counter* metric_inf_compute_input_duration_us_;
  prometheus::Counter* metric_inf_compute_infer_duration_us_;
  prometheus::Counter* metric_inf_compute_output_duration_us_;
  prometheus::Counter* metric_cache_hit_count_;
  prometheus::Counter* metric_cache_hit_lookup_duration_us_;
  prometheus::Counter* metric_cache_miss_count_;
  prometheus::Counter* metric_cache_miss_lookup_duration_us_;
  prometheus::Counter* metric_cache_miss_insertion_duration_us_;
#endif  // TRITON_ENABLE_METRICS
};

}}  // namespace triton::core
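A minimal sketch of the reporter above (a hypothetical caller, not part of this commit). It assumes the RETURN_IF_ERROR convenience macro from status.h and that triton::common::MetricTagsMap is default-constructible; repeated Create() calls with the same labels reuse a cached reporter through the internal weak_ptr map.

// Hypothetical caller of MetricModelReporter, for illustration only.
#ifdef TRITON_ENABLE_METRICS
#include "metric_model_reporter.h"

triton::core::Status
ExampleReporter()
{
  triton::common::MetricTagsMap tags;  // no extra model tags
  std::shared_ptr<triton::core::MetricModelReporter> reporter;
  RETURN_IF_ERROR(triton::core::MetricModelReporter::Create(
      "resnet50", 1 /* version */, -1 /* no GPU */, tags, &reporter));

  // Counters are cumulative; the per-model label set was fixed at creation.
  reporter->MetricInferenceSuccess().Increment(1);
  reporter->MetricInferenceRequestDuration().Increment(1250 /* us */);
  return triton::core::Status::Success;
}
#endif  // TRITON_ENABLE_METRICS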
3rdparty/core-r22.12/src/metrics.cc (new file, mode 100644)
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#ifdef TRITON_ENABLE_METRICS
#include "metrics.h"
#include <thread>
#include "constants.h"
#include "prometheus/detail/utils.h"
#include "triton/common/logging.h"
#ifdef TRITON_ENABLE_METRICS_GPU
#include <cuda_runtime_api.h>
#include <dcgm_agent.h>
#include <cstring>
#include <set>
#include <string>
#endif // TRITON_ENABLE_METRICS_GPU
namespace
triton
{
namespace
core
{
Metrics
::
Metrics
()
:
registry_
(
std
::
make_shared
<
prometheus
::
Registry
>
()),
serializer_
(
new
prometheus
::
TextSerializer
()),
inf_success_family_
(
prometheus
::
BuildCounter
()
.
Name
(
"nv_inference_request_success"
)
.
Help
(
"Number of successful inference requests, all batch sizes"
)
.
Register
(
*
registry_
)),
inf_failure_family_
(
prometheus
::
BuildCounter
()
.
Name
(
"nv_inference_request_failure"
)
.
Help
(
"Number of failed inference requests, all batch sizes"
)
.
Register
(
*
registry_
)),
inf_count_family_
(
prometheus
::
BuildCounter
()
.
Name
(
"nv_inference_count"
)
.
Help
(
"Number of inferences performed (does not "
"include cached requests)"
)
.
Register
(
*
registry_
)),
inf_count_exec_family_
(
prometheus
::
BuildCounter
()
.
Name
(
"nv_inference_exec_count"
)
.
Help
(
"Number of model executions performed "
"(does not include cached requests)"
)
.
Register
(
*
registry_
)),
inf_request_duration_us_family_
(
prometheus
::
BuildCounter
()
.
Name
(
"nv_inference_request_duration_us"
)
.
Help
(
"Cumulative inference request duration in microseconds "
"(includes cached requests)"
)
.
Register
(
*
registry_
)),
inf_queue_duration_us_family_
(
prometheus
::
BuildCounter
()
.
Name
(
"nv_inference_queue_duration_us"
)
.
Help
(
"Cumulative inference queuing duration in microseconds "
"(includes cached requests)"
)
.
Register
(
*
registry_
)),
inf_compute_input_duration_us_family_
(
prometheus
::
BuildCounter
()
.
Name
(
"nv_inference_compute_input_duration_us"
)
.
Help
(
"Cumulative compute input duration in microseconds (does "
"not include cached requests)"
)
.
Register
(
*
registry_
)),
inf_compute_infer_duration_us_family_
(
prometheus
::
BuildCounter
()
.
Name
(
"nv_inference_compute_infer_duration_us"
)
.
Help
(
"Cumulative compute inference duration in microseconds "
"(does not include cached requests)"
)
.
Register
(
*
registry_
)),
inf_compute_output_duration_us_family_
(
prometheus
::
BuildCounter
()
.
Name
(
"nv_inference_compute_output_duration_us"
)
.
Help
(
"Cumulative inference compute output duration in "
"microseconds (does not include cached requests)"
)
.
Register
(
*
registry_
)),
cache_num_entries_family_
(
prometheus
::
BuildGauge
()
.
Name
          ("nv_cache_num_entries")
          .Help("Number of responses stored in response cache")
          .Register(*registry_)),
      cache_num_lookups_family_(
          prometheus::BuildGauge()
              .Name("nv_cache_num_lookups")
              .Help("Number of cache lookups in response cache")
              .Register(*registry_)),
      cache_num_hits_family_(
          prometheus::BuildGauge()
              .Name("nv_cache_num_hits")
              .Help("Number of cache hits in response cache")
              .Register(*registry_)),
      cache_num_misses_family_(
          prometheus::BuildGauge()
              .Name("nv_cache_num_misses")
              .Help("Number of cache misses in response cache")
              .Register(*registry_)),
      cache_num_evictions_family_(
          prometheus::BuildGauge()
              .Name("nv_cache_num_evictions")
              .Help("Number of cache evictions in response cache")
              .Register(*registry_)),
      cache_lookup_duration_us_family_(
          prometheus::BuildGauge()
              .Name("nv_cache_lookup_duration")
              .Help("Total cache lookup duration (hit and miss), in microseconds")
              .Register(*registry_)),
      cache_insertion_duration_us_family_(
          prometheus::BuildGauge()
              .Name("nv_cache_insertion_duration")
              .Help("Total cache insertion duration, in microseconds")
              .Register(*registry_)),
      cache_util_family_(
          prometheus::BuildGauge()
              .Name("nv_cache_util")
              .Help("Cache utilization [0.0 - 1.0]")
              .Register(*registry_)),
      // Per-model cache metric families
      cache_num_hits_model_family_(
          prometheus::BuildCounter()
              .Name("nv_cache_num_hits_per_model")
              .Help("Number of cache hits per model")
              .Register(*registry_)),
      cache_hit_lookup_duration_us_model_family_(
          prometheus::BuildCounter()
              .Name("nv_cache_hit_lookup_duration_per_model")
              .Help("Total cache hit lookup duration per model, in microseconds")
              .Register(*registry_)),
      cache_num_misses_model_family_(
          prometheus::BuildCounter()
              .Name("nv_cache_num_misses_per_model")
              .Help("Number of cache misses per model")
              .Register(*registry_)),
      cache_miss_lookup_duration_us_model_family_(
          prometheus::BuildCounter()
              .Name("nv_cache_miss_lookup_duration_per_model")
              .Help("Total cache miss lookup duration per model, in microseconds")
              .Register(*registry_)),
      cache_miss_insertion_duration_us_model_family_(
          prometheus::BuildCounter()
              .Name("nv_cache_miss_insertion_duration_per_model")
              .Help("Total cache miss insertion duration per model, in "
                    "microseconds")
              .Register(*registry_)),
#ifdef TRITON_ENABLE_METRICS_GPU
      gpu_utilization_family_(
          prometheus::BuildGauge()
              .Name("nv_gpu_utilization")
              .Help("GPU utilization rate [0.0 - 1.0)")
              .Register(*registry_)),
      gpu_memory_total_family_(
          prometheus::BuildGauge()
              .Name("nv_gpu_memory_total_bytes")
              .Help("GPU total memory, in bytes")
              .Register(*registry_)),
      gpu_memory_used_family_(
          prometheus::BuildGauge()
              .Name("nv_gpu_memory_used_bytes")
              .Help("GPU used memory, in bytes")
              .Register(*registry_)),
      gpu_power_usage_family_(
          prometheus::BuildGauge()
              .Name("nv_gpu_power_usage")
              .Help("GPU power usage in watts")
              .Register(*registry_)),
      gpu_power_limit_family_(
          prometheus::BuildGauge()
              .Name("nv_gpu_power_limit")
              .Help("GPU power management limit in watts")
              .Register(*registry_)),
      gpu_energy_consumption_family_(
          prometheus::BuildCounter()
              .Name("nv_energy_consumption")
              .Help("GPU energy consumption in joules since the Triton Server "
                    "started")
              .Register(*registry_)),
#endif  // TRITON_ENABLE_METRICS_GPU
#ifdef TRITON_ENABLE_METRICS_CPU
      cpu_utilization_family_(
          prometheus::BuildGauge()
              .Name("nv_cpu_utilization")
              .Help("CPU utilization rate [0.0 - 1.0]")
              .Register(*registry_)),
      cpu_memory_total_family_(
          prometheus::BuildGauge()
              .Name("nv_cpu_memory_total_bytes")
              .Help("CPU total memory (RAM), in bytes")
              .Register(*registry_)),
      cpu_memory_used_family_(
          prometheus::BuildGauge()
              .Name("nv_cpu_memory_used_bytes")
              .Help("CPU used memory (RAM), in bytes")
              .Register(*registry_)),
#endif  // TRITON_ENABLE_METRICS_CPU
      metrics_enabled_(false), gpu_metrics_enabled_(false),
      cpu_metrics_enabled_(false), cache_metrics_enabled_(false),
      metrics_interval_ms_(2000)
{
}
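// Illustrative note (not part of the upstream file): the families built in the
// initializer list above follow the usual prometheus-cpp pattern, where a
// Family<T> is registered once and concrete time series are created later by
// calling Add() with a label set. A minimal sketch, reusing the registry_
// shared_ptr held by this class; the metric name and labels below are
// hypothetical:
//
//   auto& family = prometheus::BuildGauge()
//                      .Name("nv_example_gauge")
//                      .Help("Example gauge")
//                      .Register(*registry_);
//   prometheus::Gauge& gauge = family.Add({{"gpu_uuid", "GPU-0"}});
//   gauge.Set(0.5);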
static prometheus::detail::LabelHasher label_hasher_;

size_t
Metrics::HashLabels(const std::map<std::string, std::string>& labels)
{
  return label_hasher_(labels);
}
Metrics::~Metrics()
{
  // Signal the cache thread to exit and then wait for it...
  if (poll_thread_ != nullptr) {
    poll_thread_exit_.store(true);
    poll_thread_->join();
#ifdef TRITON_ENABLE_METRICS_GPU
    if (dcgm_metadata_.dcgm_initialized_) {
      dcgmReturn_t derr;
      // Group destroy will return an error if groupId invalid or dcgm not
      // initialized or configured correctly
      derr = dcgmGroupDestroy(
          dcgm_metadata_.dcgm_handle_, dcgm_metadata_.groupId_);
      if (derr != DCGM_ST_OK) {
        LOG_WARNING << "Unable to destroy DCGM group: " << errorString(derr);
      }
      // Stop and shutdown DCGM
      if (dcgm_metadata_.standalone_) {
        derr = dcgmDisconnect(dcgm_metadata_.dcgm_handle_);
      } else {
        derr = dcgmStopEmbedded(dcgm_metadata_.dcgm_handle_);
      }
      if (derr != DCGM_ST_OK) {
        LOG_WARNING << "Unable to stop DCGM: " << errorString(derr);
      }
      derr = dcgmShutdown();
      if (derr != DCGM_ST_OK) {
        LOG_WARNING << "Unable to shutdown DCGM: " << errorString(derr);
      }
    }
#endif  // TRITON_ENABLE_METRICS_GPU
  }
}
bool
Metrics::Enabled()
{
  auto singleton = GetSingleton();
  return singleton->metrics_enabled_;
}

void
Metrics::EnableMetrics()
{
  auto singleton = GetSingleton();
  singleton->metrics_enabled_ = true;
}

void
Metrics::EnableCacheMetrics(
    std::shared_ptr<RequestResponseCache> response_cache)
{
  auto singleton = GetSingleton();
  // Ensure thread-safe enabling of Cache Metrics
  std::lock_guard<std::mutex> lock(singleton->metrics_enabling_);
  if (singleton->cache_metrics_enabled_) {
    return;
  }
  singleton->InitializeCacheMetrics(response_cache);
  singleton->cache_metrics_enabled_ = true;
}

void
Metrics::EnableGPUMetrics()
{
  auto singleton = GetSingleton();
  // Ensure thread-safe enabling of GPU Metrics
  std::lock_guard<std::mutex> lock(singleton->metrics_enabling_);
  if (singleton->gpu_metrics_enabled_) {
    return;
  }
  if (std::getenv("TRITON_SERVER_CPU_ONLY") == nullptr) {
    singleton->InitializeDcgmMetrics();
  }
  singleton->gpu_metrics_enabled_ = true;
}

void
Metrics::EnableCpuMetrics()
{
  auto singleton = GetSingleton();
  // Ensure thread-safe enabling of CPU Metrics
  std::lock_guard<std::mutex> lock(singleton->metrics_enabling_);
  if (singleton->cpu_metrics_enabled_) {
    return;
  }
  singleton->InitializeCpuMetrics();
  singleton->cpu_metrics_enabled_ = true;
}
void
Metrics::SetMetricsInterval(uint64_t metrics_interval_ms)
{
  auto singleton = GetSingleton();
  singleton->metrics_interval_ms_ = metrics_interval_ms;
}

void
Metrics::StartPollingThreadSingleton(
    std::shared_ptr<RequestResponseCache> response_cache)
{
  auto singleton = GetSingleton();
  // Ensure thread-safe start of polling thread
  std::lock_guard<std::mutex> lock(singleton->poll_thread_starting_);
  if (singleton->poll_thread_started_) {
    return;
  }
  // Start thread for polling cache/dcgm metrics
  singleton->StartPollingThread(response_cache);
  // Toggle flag so this function is only executed once
  singleton->poll_thread_started_ = true;
}
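// Illustrative call order (not part of the upstream file): a server that wants
// polled metrics would typically enable the individual metric groups first and
// then start the singleton polling thread once, roughly:
//
//   Metrics::EnableMetrics();
//   Metrics::SetMetricsInterval(2000 /* ms */);
//   Metrics::EnableCacheMetrics(response_cache);
//   Metrics::EnableGPUMetrics();
//   Metrics::EnableCpuMetrics();
//   Metrics::StartPollingThreadSingleton(response_cache);
//
// This is only a sketch of how the public API above fits together; the actual
// call sites live elsewhere in the server and are not part of this commit.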
bool
Metrics::StartPollingThread(
    std::shared_ptr<RequestResponseCache> response_cache)
{
  // Nothing to poll if no polling metrics enabled, don't spawn a thread
  if (!cache_metrics_enabled_ && !gpu_metrics_enabled_ &&
      !cpu_metrics_enabled_) {
    LOG_WARNING << "No polling metrics (CPU, GPU, Cache) are enabled. Will not "
                   "poll for them.";
    return false;
  }
  poll_thread_exit_.store(false);

  // Start a separate thread for polling metrics at specified interval
  poll_thread_.reset(new std::thread([this, response_cache] {
    // Thread will update metrics indefinitely until exit flag set
    while (!poll_thread_exit_.load()) {
      // Sleep for metric interval
      std::this_thread::sleep_for(
          std::chrono::milliseconds(metrics_interval_ms_ / 2));

      // Poll Response Cache metrics
      if (cache_metrics_enabled_ && response_cache != nullptr) {
        PollCacheMetrics(response_cache);
      }

#ifdef TRITON_ENABLE_METRICS_GPU
      // Poll DCGM GPU metrics
      if (gpu_metrics_enabled_ &&
          dcgm_metadata_.available_cuda_gpu_ids_.size() > 0) {
        PollDcgmMetrics();
      }
#endif  // TRITON_ENABLE_METRICS_GPU

#ifdef TRITON_ENABLE_METRICS_CPU
      if (cpu_metrics_enabled_) {
        PollCpuMetrics();
      }
#endif  // TRITON_ENABLE_METRICS_CPU
    }
  }));

  return true;
}
bool
Metrics::PollCacheMetrics(std::shared_ptr<RequestResponseCache> response_cache)
{
  if (response_cache == nullptr) {
    LOG_WARNING << "error polling cache metrics, cache metrics will not be "
                << "available: cache was nullptr";
    return false;
  }

  // Update global cache metrics
  cache_num_entries_global_->Set(response_cache->NumEntries());
  cache_num_lookups_global_->Set(response_cache->NumLookups());
  cache_num_hits_global_->Set(response_cache->NumHits());
  cache_num_misses_global_->Set(response_cache->NumMisses());
  cache_num_evictions_global_->Set(response_cache->NumEvictions());
  cache_lookup_duration_us_global_->Set(
      response_cache->TotalLookupLatencyNs() / 1000);
  cache_insertion_duration_us_global_->Set(
      response_cache->TotalInsertionLatencyNs() / 1000);
  cache_util_global_->Set(response_cache->TotalUtilization());
  return true;
}
#ifdef TRITON_ENABLE_METRICS_CPU
Status
Metrics::ParseCpuInfo(CpuInfo& info)
{
#ifdef _WIN32
  return Status(
      Status::Code::INTERNAL, "CPU metrics not supported on Windows.");
#else
  std::ifstream ifs("/proc/stat");
  if (!ifs.good()) {
    return Status(Status::Code::INTERNAL, "Failed to open /proc/stat.");
  }

  std::string line;
  // Verify first line is aggregate cpu line
  std::getline(ifs, line);
  if (line.rfind("cpu ", 0) == std::string::npos) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to find aggregate CPU info in /proc/stat.");
  }

  std::string _;
  std::istringstream iss(line);
  // Use _ to skip "cpu" at start of line
  if (!(iss >> _ >> info)) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to parse aggregate CPU info in /proc/stat.");
  }
  return Status::Success;
#endif  // OS
}
Status
Metrics::ParseMemInfo(MemInfo& info)
{
#ifdef _WIN32
  return Status(
      Status::Code::INTERNAL, "Memory metrics not supported on Windows.");
#else
  std::ifstream ifs("/proc/meminfo");
  if (!ifs.good()) {
    return Status(Status::Code::INTERNAL, "Failed to open /proc/meminfo.");
  }

  std::string line;
  constexpr uint64_t KB = 1024;
  while (std::getline(ifs, line)) {
    std::istringstream iss(line);
    std::string name;
    uint64_t value = 0;
    if (iss >> name >> value) {
      name.pop_back();
      info[name] = value * KB;
    } else {
      return Status(
          Status::Code::INTERNAL, "Encountered error parsing /proc/meminfo.");
    }
  }

  if (info.find("MemTotal") == info.end() ||
      info.find("MemAvailable") == info.end()) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to find desired values in /proc/meminfo.");
  }

  if (info["MemAvailable"] > info["MemTotal"]) {
    return Status(
        Status::Code::INTERNAL,
        "Available bytes shouldn't be greater than Total bytes");
  }

  // "Used" memory can be defined in many different ways. While many
  // older applications consider "used = total - (free + cached)", a more
  // accurate measure of available memory "MemAvailable" was added,
  // so we choose "used = total - available" for a more accurate measure.
  // This may change in the future if not sufficient for most use cases.
  // See https://stackoverflow.com/a/35019697.
  info["MemUsed"] = info["MemTotal"] - info["MemAvailable"];
  return Status::Success;
#endif  // OS
}
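// Worked example (illustrative, not part of the upstream file): given these
// two /proc/meminfo lines,
//
//   MemTotal:       16384000 kB
//   MemAvailable:    4096000 kB
//
// ParseMemInfo() stores info["MemTotal"] = 16384000 * 1024 bytes and
// info["MemAvailable"] = 4096000 * 1024 bytes, and then derives
// info["MemUsed"] = (16384000 - 4096000) * 1024 bytes, i.e. "used" is total
// minus available as described in the comment above.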
double
Metrics::CpuUtilization(const CpuInfo& info_new, const CpuInfo& info_old)
{
  // Account for overflow
  const auto wrap_sub = [](uint64_t a, uint64_t b) {
    return (a > b) ? (a - b) : 0;
  };
  uint64_t util_diff = wrap_sub(info_new.user, info_old.user) +
                       wrap_sub(info_new.nice, info_old.nice) +
                       wrap_sub(info_new.system, info_old.system) +
                       wrap_sub(info_new.irq, info_old.irq) +
                       wrap_sub(info_new.softirq, info_old.softirq) +
                       wrap_sub(info_new.steal, info_old.steal);
  uint64_t idle_diff = wrap_sub(info_new.idle, info_old.idle) +
                       wrap_sub(info_new.iowait, info_old.iowait);
  double util_ratio = static_cast<double>(util_diff) / (util_diff + idle_diff);
  return util_ratio;
}
#endif  // TRITON_ENABLE_METRICS_CPU
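// Worked example (illustrative, not part of the upstream file): if two
// consecutive /proc/stat samples differ by user+nice+system+irq+softirq+steal
// = 300 jiffies of busy time and idle+iowait = 700 jiffies of idle time, then
// CpuUtilization() returns 300 / (300 + 700) = 0.3, which PollCpuMetrics()
// below publishes on the nv_cpu_utilization gauge.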
bool
Metrics::PollCpuMetrics()
{
#ifndef TRITON_ENABLE_METRICS_CPU
  return false;
#else
  // CPU Utilization
  double cpu_util = 0.0;
  auto cpu_info = CpuInfo();
  auto status = ParseCpuInfo(cpu_info);
  if (status.IsOk()) {
    cpu_util = CpuUtilization(cpu_info, last_cpu_info_);
    last_cpu_info_ = cpu_info;
  }
  cpu_utilization_->Set(cpu_util);  // [0.0, 1.0]

  // RAM / Memory
  double mem_total_bytes = 0.0;
  double mem_used_bytes = 0.0;
  auto mem_info = MemInfo();
  status = ParseMemInfo(mem_info);
  if (status.IsOk()) {
    // MemTotal will usually not change over time, but if something
    // goes wrong when querying memory, we can reflect that by updating.
    mem_total_bytes = mem_info["MemTotal"];
    mem_used_bytes = mem_info["MemUsed"];
  }
  cpu_memory_total_->Set(mem_total_bytes);
  cpu_memory_used_->Set(mem_used_bytes);

  return true;
#endif  // TRITON_ENABLE_METRICS_CPU
}
bool
Metrics::PollDcgmMetrics()
{
#ifndef TRITON_ENABLE_METRICS_GPU
  return false;
#else
  if (dcgm_metadata_.available_cuda_gpu_ids_.size() == 0) {
    LOG_WARNING << "error polling GPU metrics, GPU metrics will not be "
                << "available: no available gpus to poll";
    return false;
  }

  dcgmUpdateAllFields(dcgm_metadata_.dcgm_handle_, 1 /* wait for update*/);
  for (unsigned int didx = 0;
       didx < dcgm_metadata_.available_cuda_gpu_ids_.size(); ++didx) {
    uint32_t cuda_id = dcgm_metadata_.available_cuda_gpu_ids_[didx];
    if (dcgm_metadata_.cuda_ids_to_dcgm_ids_.count(cuda_id) <= 0) {
      LOG_WARNING << "Cannot find DCGM id for CUDA id " << cuda_id;
      continue;
    }
    uint32_t dcgm_id = dcgm_metadata_.cuda_ids_to_dcgm_ids_.at(cuda_id);
    dcgmFieldValue_v1 field_values[dcgm_metadata_.field_count_];
    dcgmReturn_t dcgmerr = dcgmGetLatestValuesForFields(
        dcgm_metadata_.dcgm_handle_, dcgm_id, dcgm_metadata_.fields_.data(),
        dcgm_metadata_.field_count_, field_values);
    if (dcgmerr != DCGM_ST_OK) {
      dcgm_metadata_.power_limit_fail_cnt_[didx]++;
      dcgm_metadata_.power_usage_fail_cnt_[didx]++;
      dcgm_metadata_.energy_fail_cnt_[didx]++;
      dcgm_metadata_.util_fail_cnt_[didx]++;
      dcgm_metadata_.mem_fail_cnt_[didx]++;
      LOG_WARNING << "Unable to get field values for GPU ID " << cuda_id
                  << ": " << errorString(dcgmerr);
    } else {
      // Power limit
      if (dcgm_metadata_.power_limit_fail_cnt_[didx] <
          dcgm_metadata_.fail_threshold_) {
        double power_limit = field_values[0].value.dbl;
        if ((field_values[0].status == DCGM_ST_OK) &&
            (!DCGM_FP64_IS_BLANK(power_limit))) {
          dcgm_metadata_.power_limit_fail_cnt_[didx] = 0;
        } else {
          dcgm_metadata_.power_limit_fail_cnt_[didx]++;
          power_limit = 0;
          dcgmReturn_t status = dcgmReturn_t(field_values[0].status);
          LOG_WARNING << "Unable to get power limit for GPU " << cuda_id
                      << ". Status:" << errorString(status)
                      << ", value:" << dcgmValueToErrorMessage(power_limit);
        }
        gpu_power_limit_[didx]->Set(power_limit);
      }

      // Power usage
      if (dcgm_metadata_.power_usage_fail_cnt_[didx] <
          dcgm_metadata_.fail_threshold_) {
        double power_usage = field_values[1].value.dbl;
        if ((field_values[1].status == DCGM_ST_OK) &&
            (!DCGM_FP64_IS_BLANK(power_usage))) {
          dcgm_metadata_.power_usage_fail_cnt_[didx] = 0;
        } else {
          dcgm_metadata_.power_usage_fail_cnt_[didx]++;
          power_usage = 0;
          dcgmReturn_t status = dcgmReturn_t(field_values[1].status);
          LOG_WARNING << "Unable to get power usage for GPU " << cuda_id
                      << ". Status:" << errorString(status)
                      << ", value:" << dcgmValueToErrorMessage(power_usage);
        }
        gpu_power_usage_[didx]->Set(power_usage);
      }

      // Energy Consumption
      if (dcgm_metadata_.energy_fail_cnt_[didx] <
          dcgm_metadata_.fail_threshold_) {
        int64_t energy = field_values[2].value.i64;
        if ((field_values[2].status == DCGM_ST_OK) &&
            (!DCGM_INT64_IS_BLANK(energy))) {
          dcgm_metadata_.energy_fail_cnt_[didx] = 0;
          if (dcgm_metadata_.last_energy_[didx] == 0) {
            dcgm_metadata_.last_energy_[didx] = energy;
          }
          gpu_energy_consumption_[didx]->Increment(
              (double)(energy - dcgm_metadata_.last_energy_[didx]) * 0.001);
          dcgm_metadata_.last_energy_[didx] = energy;
        } else {
          dcgm_metadata_.energy_fail_cnt_[didx]++;
          energy = 0;
          dcgmReturn_t status = dcgmReturn_t(field_values[2].status);
          LOG_WARNING << "Unable to get energy consumption for "
                      << "GPU " << cuda_id
                      << ". Status:" << errorString(status)
                      << ", value:" << dcgmValueToErrorMessage(energy);
        }
      }

      // Utilization
      if (dcgm_metadata_.util_fail_cnt_[didx] <
          dcgm_metadata_.fail_threshold_) {
        int64_t util = field_values[3].value.i64;
        if ((field_values[3].status == DCGM_ST_OK) &&
            (!DCGM_INT64_IS_BLANK(util))) {
          dcgm_metadata_.util_fail_cnt_[didx] = 0;
        } else {
          dcgm_metadata_.util_fail_cnt_[didx]++;
          util = 0;
          dcgmReturn_t status = dcgmReturn_t(field_values[3].status);
          LOG_WARNING << "Unable to get GPU utilization for GPU " << cuda_id
                      << ". Status:" << errorString(status)
                      << ", value:" << dcgmValueToErrorMessage(util);
        }
        gpu_utilization_[didx]->Set((double)util * 0.01);
      }

      // Memory Usage
      if (dcgm_metadata_.mem_fail_cnt_[didx] <
          dcgm_metadata_.fail_threshold_) {
        int64_t memory_used = field_values[4].value.i64;
        int64_t memory_total = field_values[5].value.i64;
        if ((field_values[4].status == DCGM_ST_OK) &&
            (!DCGM_INT64_IS_BLANK(memory_used)) &&
            (field_values[5].status == DCGM_ST_OK) &&
            (!DCGM_INT64_IS_BLANK(memory_total))) {
          dcgm_metadata_.mem_fail_cnt_[didx] = 0;
        } else {
          memory_total = 0;
          memory_used = 0;
          dcgm_metadata_.mem_fail_cnt_[didx]++;
          dcgmReturn_t usageStatus = dcgmReturn_t(field_values[4].status);
          dcgmReturn_t memoryTotalStatus = dcgmReturn_t(field_values[5].status);
          LOG_WARNING << "Unable to get memory usage for GPU " << cuda_id
                      << ". Memory usage status:" << errorString(usageStatus)
                      << ", value:" << dcgmValueToErrorMessage(memory_used)
                      << ". Memory total status:"
                      << errorString(memoryTotalStatus)
                      << ", value:" << dcgmValueToErrorMessage(memory_total);
        }
        gpu_memory_total_[didx]->Set(memory_total * 1024 * 1024);  // bytes
        gpu_memory_used_[didx]->Set(memory_used * 1024 * 1024);    // bytes
      }
    }
  }
  return true;
#endif  // TRITON_ENABLE_METRICS_GPU
}
bool
Metrics::InitializeCacheMetrics(
    std::shared_ptr<RequestResponseCache> response_cache)
{
  if (response_cache == nullptr) {
    LOG_WARNING
        << "error initializing cache metrics, cache metrics will not be "
        << "available: cache was nullptr";
    return false;
  }

  const std::map<std::string, std::string> cache_labels;
  cache_num_entries_global_ = &cache_num_entries_family_.Add(cache_labels);
  cache_num_lookups_global_ = &cache_num_lookups_family_.Add(cache_labels);
  cache_num_hits_global_ = &cache_num_hits_family_.Add(cache_labels);
  cache_num_misses_global_ = &cache_num_misses_family_.Add(cache_labels);
  cache_num_evictions_global_ = &cache_num_evictions_family_.Add(cache_labels);
  cache_lookup_duration_us_global_ =
      &cache_lookup_duration_us_family_.Add(cache_labels);
  cache_insertion_duration_us_global_ =
      &cache_insertion_duration_us_family_.Add(cache_labels);
  cache_util_global_ = &cache_util_family_.Add(cache_labels);

  LOG_INFO << "Collecting Response Cache metrics";
  return true;
}
bool
Metrics::InitializeCpuMetrics()
{
#ifndef TRITON_ENABLE_METRICS_CPU
  return false;
#else
  const std::map<std::string, std::string> cpu_labels;
  cpu_utilization_ = &cpu_utilization_family_.Add(cpu_labels);
  cpu_memory_total_ = &cpu_memory_total_family_.Add(cpu_labels);
  cpu_memory_used_ = &cpu_memory_used_family_.Add(cpu_labels);

  // Get baseline CPU info for future comparisons
  last_cpu_info_ = CpuInfo();
  auto status = ParseCpuInfo(last_cpu_info_);
  if (!status.IsOk()) {
    LOG_WARNING << "error initializing CPU metrics, CPU utilization may not "
                   "be available: "
                << status.Message();
    return false;
  }

  // Verify memory metrics can be parsed
  auto mem_info = MemInfo();
  status = ParseMemInfo(mem_info);
  if (!status.IsOk()) {
    LOG_WARNING << "error initializing CPU metrics, CPU memory metrics may not "
                   "be available: "
                << status.Message();
    return false;
  }

  LOG_INFO << "Collecting CPU metrics";
  return true;
#endif  // TRITON_ENABLE_METRICS_CPU
}
bool
Metrics::InitializeDcgmMetrics()
{
#ifndef TRITON_ENABLE_METRICS_GPU
  return false;
#else
  dcgmReturn_t dcgmerr = dcgmInit();
  if (dcgmerr != DCGM_ST_OK) {
    LOG_WARNING << "error initializing DCGM, GPU metrics will not be "
                << "available: " << errorString(dcgmerr);
    return false;
  }

  if (dcgm_metadata_.standalone_) {
    char hostIpAddress[16] = {0};
    std::string ipAddress = "127.0.0.1";
    strncpy(hostIpAddress, ipAddress.c_str(), 15);
    dcgmerr = dcgmConnect(hostIpAddress, &dcgm_metadata_.dcgm_handle_);
  } else {
    dcgmerr = dcgmStartEmbedded(
        DCGM_OPERATION_MODE_MANUAL, &dcgm_metadata_.dcgm_handle_);
  }
  if (dcgmerr != DCGM_ST_OK) {
    LOG_WARNING << "DCGM unable to start: " << errorString(dcgmerr);
    return false;
  } else {
    // Set this flag to signal DCGM cleanup in destructor
    dcgm_metadata_.dcgm_initialized_ = true;
  }

  if (dcgm_metadata_.standalone_) {
    dcgmerr = dcgmUpdateAllFields(dcgm_metadata_.dcgm_handle_, 1);
    if (dcgmerr != DCGM_ST_OK) {
      LOG_WARNING << "DCGM unable to update all fields, GPU metrics will "
                     "not be available: "
                  << errorString(dcgmerr);
      return false;
    }
  }

  unsigned int dcgm_gpu_ids[DCGM_MAX_NUM_DEVICES];
  int dcgm_gpu_count;
  dcgmerr = dcgmGetAllDevices(
      dcgm_metadata_.dcgm_handle_, dcgm_gpu_ids, &dcgm_gpu_count);
  if (dcgmerr != DCGM_ST_OK) {
    LOG_WARNING << "DCGM unable to get device info and count, GPU "
                   "metrics will not be available: "
                << errorString(dcgmerr);
    return false;
  }

  // Get PCI Bus ID to DCGM device Id map.
  // Some devices may have problems using DCGM API and
  // these devices needs to be ignored.
  std::map<std::string, size_t> pci_bus_id_to_dcgm_id;
  std::map<std::string, std::map<std::string, std::string>>
      pci_bus_id_to_gpu_labels;
  std::map<std::string, std::string> pci_bus_id_to_device_name;
  dcgmDeviceAttributes_t gpu_attributes[DCGM_MAX_NUM_DEVICES];
  for (int i = 0; i < dcgm_gpu_count; i++) {
    gpu_attributes[i].version = dcgmDeviceAttributes_version;
    dcgmerr = dcgmGetDeviceAttributes(
        dcgm_metadata_.dcgm_handle_, dcgm_gpu_ids[i], &gpu_attributes[i]);
    if (dcgmerr != DCGM_ST_OK) {
      LOG_WARNING << "DCGM unable to get device properties for DCGM device "
                  << dcgm_gpu_ids[i]
                  << ", GPU metrics will not be available for this device: "
                  << errorString(dcgmerr);
    } else {
      std::string pciBusId = gpu_attributes[i].identifiers.pciBusId;
      pci_bus_id_to_dcgm_id[pciBusId] = i;
      pci_bus_id_to_device_name[pciBusId] =
          std::string(gpu_attributes[i].identifiers.deviceName);
      std::map<std::string, std::string> gpu_labels;
      gpu_labels.insert(std::map<std::string, std::string>::value_type(
          kMetricsLabelGpuUuid,
          std::string(gpu_attributes[i].identifiers.uuid)));
      pci_bus_id_to_gpu_labels[pciBusId] = gpu_labels;
    }
  }

  // Get CUDA-visible PCI Bus Ids and get DCGM metrics for each CUDA-visible GPU
  int cuda_gpu_count;
  cudaError_t cudaerr = cudaGetDeviceCount(&cuda_gpu_count);
  if (cudaerr != cudaSuccess) {
    LOG_WARNING
        << "Cannot get CUDA device count, GPU metrics will not be available";
    return false;
  }
  for (int i = 0; i < cuda_gpu_count; ++i) {
    std::string pci_bus_id = "0000";  // pad 0's for uniformity
    char pcibusid_str[64];
    cudaerr = cudaDeviceGetPCIBusId(pcibusid_str, sizeof(pcibusid_str) - 1, i);
    if (cudaerr == cudaSuccess) {
      pci_bus_id.append(pcibusid_str);
      if (pci_bus_id_to_dcgm_id.count(pci_bus_id) <= 0) {
        LOG_INFO << "Skipping GPU:" << i
                 << " since it's not CUDA enabled. This should never happen!";
        continue;
      }
      // Filter out CUDA visible GPUs from GPUs found by DCGM
      LOG_INFO << "Collecting metrics for GPU " << i << ": "
               << pci_bus_id_to_device_name[pci_bus_id];
      auto& gpu_labels = pci_bus_id_to_gpu_labels[pci_bus_id];
      gpu_utilization_.push_back(&gpu_utilization_family_.Add(gpu_labels));
      gpu_memory_total_.push_back(&gpu_memory_total_family_.Add(gpu_labels));
      gpu_memory_used_.push_back(&gpu_memory_used_family_.Add(gpu_labels));
      gpu_power_usage_.push_back(&gpu_power_usage_family_.Add(gpu_labels));
      gpu_power_limit_.push_back(&gpu_power_limit_family_.Add(gpu_labels));
      gpu_energy_consumption_.push_back(
          &gpu_energy_consumption_family_.Add(gpu_labels));
      uint32_t dcgm_id = pci_bus_id_to_dcgm_id[pci_bus_id];
      dcgm_metadata_.cuda_ids_to_dcgm_ids_[i] = dcgm_id;
      dcgm_metadata_.available_cuda_gpu_ids_.emplace_back(i);
    } else {
      LOG_WARNING << "GPU metrics will not be available for device:" << i;
    }
  }

  // create a gpu group
  char groupName[] = "dcgm_group";
  dcgmerr = dcgmGroupCreate(
      dcgm_metadata_.dcgm_handle_, DCGM_GROUP_DEFAULT, groupName,
      &dcgm_metadata_.groupId_);
  if (dcgmerr != DCGM_ST_OK) {
    LOG_WARNING << "Cannot make GPU group: " << errorString(dcgmerr);
  }

  // Initialize tracking vectors
  for (unsigned int didx = 0;
       didx < dcgm_metadata_.available_cuda_gpu_ids_.size(); ++didx) {
    dcgm_metadata_.power_limit_fail_cnt_.push_back(0);
    dcgm_metadata_.power_usage_fail_cnt_.push_back(0);
    dcgm_metadata_.energy_fail_cnt_.push_back(0);
    dcgm_metadata_.util_fail_cnt_.push_back(0);
    dcgm_metadata_.mem_fail_cnt_.push_back(0);
    dcgm_metadata_.last_energy_.push_back(0);
  }

  // Number of fields for DCGM to use from fields_ below
  dcgm_metadata_.field_count_ = 6;
  unsigned short util_flag = dcgm_metadata_.standalone_
                                 ? DCGM_FI_PROF_GR_ENGINE_ACTIVE
                                 : DCGM_FI_DEV_GPU_UTIL;
  dcgm_metadata_.fields_ = {
      DCGM_FI_DEV_POWER_MGMT_LIMIT,          // power limit, watts
      DCGM_FI_DEV_POWER_USAGE,               // power usage, watts
      DCGM_FI_DEV_TOTAL_ENERGY_CONSUMPTION,  // Total energy consumption, mJ
      util_flag,                             // util ratio, 1 = 1%
      DCGM_FI_DEV_FB_USED,                   // Frame buffer used, MiB
      DCGM_FI_DEV_FB_TOTAL,                  // Frame buffer total, MiB
  };

  char fieldName[] = "field_group";
  dcgmFieldGrp_t fieldGroupId;
  dcgmerr = dcgmFieldGroupCreate(
      dcgm_metadata_.dcgm_handle_, dcgm_metadata_.field_count_,
      dcgm_metadata_.fields_.data(), fieldName, &fieldGroupId);
  if (dcgmerr != DCGM_ST_OK) {
    LOG_WARNING << "Cannot make field group: " << errorString(dcgmerr);
  }

  dcgmerr = dcgmWatchFields(
      dcgm_metadata_.dcgm_handle_, dcgm_metadata_.groupId_, fieldGroupId,
      metrics_interval_ms_ * 1000 /*update period, usec*/,
      5.0 /*maxKeepAge, sec*/, 5 /*maxKeepSamples*/);
  if (dcgmerr != DCGM_ST_OK) {
    LOG_WARNING << "Cannot start watching fields: " << errorString(dcgmerr);
    return false;
  }

  return true;
#endif  // TRITON_ENABLE_METRICS_GPU
}
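// Note (illustrative, not part of the upstream file): the order of the six
// field IDs placed into dcgm_metadata_.fields_ above is what gives meaning to
// the fixed indices used in PollDcgmMetrics(): field_values[0] is the power
// limit, [1] power usage, [2] total energy consumption, [3] utilization,
// [4] framebuffer used and [5] framebuffer total. Reordering the fields_
// initializer would therefore also require updating those indices.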
#ifdef TRITON_ENABLE_METRICS_GPU
std::string
Metrics::dcgmValueToErrorMessage(double val)
{
  if (DCGM_FP64_IS_BLANK(val)) {
    if (val == DCGM_FP64_BLANK) {
      return "Not Specified";
    } else if (val == DCGM_FP64_NOT_FOUND) {
      return "Not Found";
    } else if (val == DCGM_FP64_NOT_SUPPORTED) {
      return "Not Supported";
    } else if (val == DCGM_FP64_NOT_PERMISSIONED) {
      return "Insf. Permission";
    } else {
      return "Unknown";
    }
  } else {
    return std::to_string(val);
  }
}

std::string
Metrics::dcgmValueToErrorMessage(int64_t val)
{
  if (DCGM_INT64_IS_BLANK(val)) {
    switch (val) {
      case DCGM_INT64_BLANK:
        return "Not Specified";
      case DCGM_INT64_NOT_FOUND:
        return "Not Found";
      case DCGM_INT64_NOT_SUPPORTED:
        return "Not Supported";
      case DCGM_INT64_NOT_PERMISSIONED:
        return "Insf. Permission";
      default:
        return "Unknown";
    }
  } else {
    return std::to_string(val);
  }
}
#endif  // TRITON_ENABLE_METRICS_GPU
bool
Metrics::UUIDForCudaDevice(int cuda_device, std::string* uuid)
{
  // If metrics were not initialized then just silently fail since
  // with DCGM we can't get the CUDA device (and not worth doing
  // anyway since metrics aren't being reported).
  auto singleton = GetSingleton();
  if (!singleton->gpu_metrics_enabled_) {
    return false;
  }

  // If GPU metrics is not enabled just silently fail.
#ifndef TRITON_ENABLE_METRICS_GPU
  return false;
#else
  dcgmDeviceAttributes_t gpu_attributes;
  gpu_attributes.version = dcgmDeviceAttributes_version;
  dcgmReturn_t dcgmerr = dcgmGetDeviceAttributes(
      singleton->dcgm_metadata_.dcgm_handle_, cuda_device, &gpu_attributes);
  if (dcgmerr != DCGM_ST_OK) {
    LOG_ERROR << "Unable to get device UUID: " << errorString(dcgmerr);
    return false;
  }

  *uuid = gpu_attributes.identifiers.uuid;
  return true;
#endif  // TRITON_ENABLE_METRICS_GPU
}
std::shared_ptr<prometheus::Registry>
Metrics::GetRegistry()
{
  auto singleton = Metrics::GetSingleton();
  return singleton->registry_;
}

const std::string
Metrics::SerializedMetrics()
{
  auto singleton = Metrics::GetSingleton();
  return singleton->serializer_->Serialize(
      singleton->registry_.get()->Collect());
}

Metrics*
Metrics::GetSingleton()
{
  static Metrics singleton;
  return &singleton;
}
}}
// namespace triton::core
#endif // TRITON_ENABLE_METRICS
3rdparty/core-r22.12/src/metrics.h
0 → 100644
View file @
b30f3cdb
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#ifdef TRITON_ENABLE_METRICS
#include <atomic>
#include <mutex>
#include <thread>
#include "prometheus/counter.h"
#include "prometheus/gauge.h"
#include "prometheus/registry.h"
#include "prometheus/serializer.h"
#include "prometheus/text_serializer.h"
#include "response_cache.h"
#ifdef TRITON_ENABLE_METRICS_GPU
#include <dcgm_agent.h>
#endif // TRITON_ENABLE_METRICS_GPU
namespace triton { namespace core {

#ifdef TRITON_ENABLE_METRICS_CPU
using MemInfo = std::unordered_map<std::string, uint64_t>;

// References:
// - htop source: https://stackoverflow.com/a/23376195
// - Linux docs: https://www.kernel.org/doc/Documentation/filesystems/proc.txt
// guest/guestnice values are counted in user/nice so we skip parsing them
struct CpuInfo {
  uint64_t user = 0;     // normal processes executing in user mode
  uint64_t nice = 0;     // niced processes executing in user mode
  uint64_t system = 0;   // processes executing in kernel mode
  uint64_t idle = 0;     // twiddling thumbs
  uint64_t iowait = 0;   // waiting for I/O to complete
  uint64_t irq = 0;      // servicing interrupts
  uint64_t softirq = 0;  // servicing softirqs
  uint64_t steal = 0;    // involuntary wait
};

inline std::istream&
operator>>(std::istream& is, CpuInfo& info)
{
  is >> info.user >> info.nice >> info.system >> info.idle >> info.iowait >>
      info.irq >> info.softirq >> info.steal;
  return is;
}
#endif  // TRITON_ENABLE_METRICS_CPU

#ifdef TRITON_ENABLE_METRICS_GPU
struct DcgmMetadata {
  // DCGM handles for initialization and destruction
  dcgmHandle_t dcgm_handle_ = 0;
  dcgmGpuGrp_t groupId_ = 0;
  // DCGM Flags
  bool standalone_ = false;
  // DCGM Fields
  size_t field_count_ = 0;
  std::vector<unsigned short> fields_;
  // GPU Device Mapping
  std::map<uint32_t, uint32_t> cuda_ids_to_dcgm_ids_;
  std::vector<uint32_t> available_cuda_gpu_ids_;
  // Stop attempting metrics if they fail multiple consecutive
  // times for a device.
  const int fail_threshold_ = 3;
  // DCGM Failure Tracking
  std::vector<int> power_limit_fail_cnt_;
  std::vector<int> power_usage_fail_cnt_;
  std::vector<int> energy_fail_cnt_;
  std::vector<int> util_fail_cnt_;
  std::vector<int> mem_fail_cnt_;
  // DCGM Energy Tracking
  std::vector<unsigned long long> last_energy_;
  // Track if DCGM handle initialized successfully
  bool dcgm_initialized_ = false;
};
#endif  // TRITON_ENABLE_METRICS_GPU

class Metrics {
 public:
  // Return the hash value of the labels
  static size_t HashLabels(const std::map<std::string, std::string>& labels);

  // Are metrics enabled?
  static bool Enabled();

  // Enable reporting of metrics
  static void EnableMetrics();

  // Enable reporting of GPU metrics
  static void EnableGPUMetrics();

  // Enable reporting of CPU metrics
  static void EnableCpuMetrics();

  // Enable reporting of Cache metrics
  static void EnableCacheMetrics(
      std::shared_ptr<RequestResponseCache> response_cache);

  // Start a thread for polling enabled metrics if any
  static void StartPollingThreadSingleton(
      std::shared_ptr<RequestResponseCache> response_cache);

  // Set the time interval in milliseconds at which metrics are collected
  static void SetMetricsInterval(uint64_t metrics_interval_ms);

  // Get the prometheus registry
  static std::shared_ptr<prometheus::Registry> GetRegistry();

  // Get serialized metrics
  static const std::string SerializedMetrics();

  // Get the UUID for a CUDA device. Return true and initialize 'uuid'
  // if a UUID is found, return false if a UUID cannot be returned.
  static bool UUIDForCudaDevice(int cuda_device, std::string* uuid);

  // Metric family counting successful inference requests
  static prometheus::Family<prometheus::Counter>& FamilyInferenceSuccess()
  {
    return GetSingleton()->inf_success_family_;
  }

  // Metric family counting failed inference requests
  static prometheus::Family<prometheus::Counter>& FamilyInferenceFailure()
  {
    return GetSingleton()->inf_failure_family_;
  }

  // Metric family counting inferences performed, where a batch-size
  // 'n' inference request is counted as 'n' inferences
  static prometheus::Family<prometheus::Counter>& FamilyInferenceCount()
  {
    return GetSingleton()->inf_count_family_;
  }

  // Metric family counting inference batch executions, where a batch-size
  // 'n' inference request is counted as one execution
  static prometheus::Family<prometheus::Counter>&
  FamilyInferenceExecutionCount()
  {
    return GetSingleton()->inf_count_exec_family_;
  }

  // Metric family of cumulative inference request duration, in
  // microseconds
  static prometheus::Family<prometheus::Counter>&
  FamilyInferenceRequestDuration()
  {
    return GetSingleton()->inf_request_duration_us_family_;
  }

  // Metric family of cumulative inference queuing duration, in
  // microseconds
  static prometheus::Family<prometheus::Counter>&
  FamilyInferenceQueueDuration()
  {
    return GetSingleton()->inf_queue_duration_us_family_;
  }

  // Metric family of cumulative inference compute durations, in
  // microseconds
  static prometheus::Family<prometheus::Counter>&
  FamilyInferenceComputeInputDuration()
  {
    return GetSingleton()->inf_compute_input_duration_us_family_;
  }
  static prometheus::Family<prometheus::Counter>&
  FamilyInferenceComputeInferDuration()
  {
    return GetSingleton()->inf_compute_infer_duration_us_family_;
  }
  static prometheus::Family<prometheus::Counter>&
  FamilyInferenceComputeOutputDuration()
  {
    return GetSingleton()->inf_compute_output_duration_us_family_;
  }

  // Metric families of per-model response cache metrics
  static prometheus::Family<prometheus::Counter>& FamilyCacheHitCount()
  {
    return GetSingleton()->cache_num_hits_model_family_;
  }
  static prometheus::Family<prometheus::Counter>&
  FamilyCacheHitLookupDuration()
  {
    return GetSingleton()->cache_hit_lookup_duration_us_model_family_;
  }
  static prometheus::Family<prometheus::Counter>& FamilyCacheMissCount()
  {
    return GetSingleton()->cache_num_misses_model_family_;
  }
  static prometheus::Family<prometheus::Counter>&
  FamilyCacheMissLookupDuration()
  {
    return GetSingleton()->cache_miss_lookup_duration_us_model_family_;
  }
  static prometheus::Family<prometheus::Counter>&
  FamilyCacheMissInsertionDuration()
  {
    return GetSingleton()->cache_miss_insertion_duration_us_model_family_;
  }

 private:
  Metrics();
  virtual ~Metrics();
  static Metrics* GetSingleton();
  bool InitializeDcgmMetrics();
  bool InitializeCpuMetrics();
  bool InitializeCacheMetrics(
      std::shared_ptr<RequestResponseCache> response_cache);
  bool StartPollingThread(
      std::shared_ptr<RequestResponseCache> response_cache);
  bool PollCacheMetrics(std::shared_ptr<RequestResponseCache> response_cache);
  bool PollDcgmMetrics();
  bool PollCpuMetrics();
  std::string dcgmValueToErrorMessage(double val);
  std::string dcgmValueToErrorMessage(int64_t val);

  std::shared_ptr<prometheus::Registry> registry_;
  std::unique_ptr<prometheus::Serializer> serializer_;

  prometheus::Family<prometheus::Counter>& inf_success_family_;
  prometheus::Family<prometheus::Counter>& inf_failure_family_;
  prometheus::Family<prometheus::Counter>& inf_count_family_;
  prometheus::Family<prometheus::Counter>& inf_count_exec_family_;
  prometheus::Family<prometheus::Counter>& inf_request_duration_us_family_;
  prometheus::Family<prometheus::Counter>& inf_queue_duration_us_family_;
  prometheus::Family<prometheus::Counter>&
      inf_compute_input_duration_us_family_;
  prometheus::Family<prometheus::Counter>&
      inf_compute_infer_duration_us_family_;
  prometheus::Family<prometheus::Counter>&
      inf_compute_output_duration_us_family_;

  // Global Response Cache metrics
  prometheus::Family<prometheus::Gauge>& cache_num_entries_family_;
  prometheus::Family<prometheus::Gauge>& cache_num_lookups_family_;
  prometheus::Family<prometheus::Gauge>& cache_num_hits_family_;
  prometheus::Family<prometheus::Gauge>& cache_num_misses_family_;
  prometheus::Family<prometheus::Gauge>& cache_num_evictions_family_;
  prometheus::Family<prometheus::Gauge>& cache_lookup_duration_us_family_;
  prometheus::Family<prometheus::Gauge>& cache_insertion_duration_us_family_;
  prometheus::Family<prometheus::Gauge>& cache_util_family_;

  // Gauges for Global Response Cache metrics
  prometheus::Gauge* cache_num_entries_global_;
  prometheus::Gauge* cache_num_lookups_global_;
  prometheus::Gauge* cache_num_hits_global_;
  prometheus::Gauge* cache_num_misses_global_;
  prometheus::Gauge* cache_num_evictions_global_;
  prometheus::Gauge* cache_lookup_duration_us_global_;
  prometheus::Gauge* cache_insertion_duration_us_global_;
  prometheus::Gauge* cache_util_global_;

  // Per-model Response Cache metrics
  prometheus::Family<prometheus::Counter>& cache_num_hits_model_family_;
  prometheus::Family<prometheus::Counter>&
      cache_hit_lookup_duration_us_model_family_;
  prometheus::Family<prometheus::Counter>& cache_num_misses_model_family_;
  prometheus::Family<prometheus::Counter>&
      cache_miss_lookup_duration_us_model_family_;
  prometheus::Family<prometheus::Counter>&
      cache_miss_insertion_duration_us_model_family_;

#ifdef TRITON_ENABLE_METRICS_GPU
  prometheus::Family<prometheus::Gauge>& gpu_utilization_family_;
  prometheus::Family<prometheus::Gauge>& gpu_memory_total_family_;
  prometheus::Family<prometheus::Gauge>& gpu_memory_used_family_;
  prometheus::Family<prometheus::Gauge>& gpu_power_usage_family_;
  prometheus::Family<prometheus::Gauge>& gpu_power_limit_family_;
  prometheus::Family<prometheus::Counter>& gpu_energy_consumption_family_;

  std::vector<prometheus::Gauge*> gpu_utilization_;
  std::vector<prometheus::Gauge*> gpu_memory_total_;
  std::vector<prometheus::Gauge*> gpu_memory_used_;
  std::vector<prometheus::Gauge*> gpu_power_usage_;
  std::vector<prometheus::Gauge*> gpu_power_limit_;
  std::vector<prometheus::Counter*> gpu_energy_consumption_;

  DcgmMetadata dcgm_metadata_;
#endif  // TRITON_ENABLE_METRICS_GPU

#ifdef TRITON_ENABLE_METRICS_CPU
  // Parses "/proc/meminfo" for metrics, currently only supported on Linux.
  Status ParseMemInfo(MemInfo& info);
  // Parses "/proc/stat" for metrics, currently only supported on Linux.
  Status ParseCpuInfo(CpuInfo& info);
  // Computes CPU utilization between "info_new" and "info_old" values
  double CpuUtilization(const CpuInfo& info_new, const CpuInfo& info_old);

  prometheus::Family<prometheus::Gauge>& cpu_utilization_family_;
  prometheus::Family<prometheus::Gauge>& cpu_memory_total_family_;
  prometheus::Family<prometheus::Gauge>& cpu_memory_used_family_;

  prometheus::Gauge* cpu_utilization_;
  prometheus::Gauge* cpu_memory_total_;
  prometheus::Gauge* cpu_memory_used_;

  CpuInfo last_cpu_info_;
#endif  // TRITON_ENABLE_METRICS_CPU

  // Thread for polling cache/gpu metrics periodically
  std::unique_ptr<std::thread> poll_thread_;
  std::atomic<bool> poll_thread_exit_;
  bool metrics_enabled_;
  bool gpu_metrics_enabled_;
  bool cpu_metrics_enabled_;
  bool cache_metrics_enabled_;
  bool poll_thread_started_;
  std::mutex metrics_enabling_;
  std::mutex poll_thread_starting_;
  uint64_t metrics_interval_ms_;
};
}}
// namespace triton::core
#endif // TRITON_ENABLE_METRICS
3rdparty/core-r22.12/src/model.cc
0 → 100644
View file @
b30f3cdb
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "model.h"
#include <chrono>
#include <future>
#include "constants.h"
#include "filesystem.h"
#include "infer_request.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
namespace triton { namespace core {

Status
Model::GetInput(
    const std::string& name, const inference::ModelInput** input) const
{
  const auto itr = input_map_.find(name);
  if (itr == input_map_.end()) {
    return Status(
        Status::Code::INVALID_ARG, "unexpected inference input '" + name +
                                       "' for model '" + Name() + "'");
  }

  *input = &itr->second;
  return Status::Success;
}

Status
Model::GetOutput(
    const std::string& name, const inference::ModelOutput** output) const
{
  const auto itr = output_map_.find(name);
  if (itr == output_map_.end()) {
    return Status(
        Status::Code::INVALID_ARG, "unexpected inference output '" + name +
                                       "' for model '" + Name() + "'");
  }

  *output = &itr->second;
  return Status::Success;
}

Status
Model::SetModelConfig(const inference::ModelConfig& config)
{
  config_ = config;
  set_model_config_ = true;
  return Status::Success;
}

Status
Model::SetScheduler(std::unique_ptr<Scheduler> scheduler)
{
  if (scheduler_ != nullptr) {
    return Status(
        Status::Code::INTERNAL, "Attempt to change scheduler not allowed");
  }

  scheduler_ = std::move(scheduler);
  return Status::Success;
}

Status
Model::Init(const bool is_config_provided)
{
  if (!set_model_config_ && !is_config_provided) {
    return Status(
        Status::Code::NOT_FOUND,
        "model configuration is not provided for model '" + Name() + "'");
  }

  RETURN_IF_ERROR(ValidateModelConfig(config_, min_compute_capability_));
  RETURN_IF_ERROR(ValidateModelIOConfig(config_));

  // Initialize the input map
  for (const auto& io : config_.input()) {
    input_map_.insert(std::make_pair(io.name(), io));
    if (!io.optional()) {
      ++required_input_count_;
    }
  }

  // Initialize the output map and label provider for each output
  label_provider_ = std::make_shared<LabelProvider>();
  for (const auto& io : config_.output()) {
    output_map_.insert(std::make_pair(io.name(), io));
    if (!io.label_filename().empty()) {
      const auto label_path = JoinPath({model_dir_, io.label_filename()});
      RETURN_IF_ERROR(label_provider_->AddLabels(io.name(), label_path));
    }
  }

  if (config_.has_dynamic_batching()) {
    default_priority_level_ =
        config_.dynamic_batching().default_priority_level();
    max_priority_level_ = config_.dynamic_batching().priority_levels();
  } else if (config_.has_ensemble_scheduling()) {
    // For ensemble, allow any priority level to pass through
    default_priority_level_ = 0;
    max_priority_level_ = UINT32_MAX;
  } else {
    default_priority_level_ = 0;
    max_priority_level_ = 0;
  }

  return Status::Success;
}
}}
// namespace triton::core
3rdparty/core-r22.12/src/model.h
0 → 100644
View file @
b30f3cdb
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "infer_stats.h"
#include "label_provider.h"
#include "model_config.pb.h"
#include "scheduler.h"
#include "status.h"
namespace triton { namespace core {

class InferenceRequest;

//
// Interface for models that handle inference requests.
//
class Model {
 public:
  explicit Model(
      const double min_compute_capability, const std::string& model_dir,
      const int64_t version, const inference::ModelConfig& config)
      : config_(config), min_compute_capability_(min_compute_capability),
        version_(version), required_input_count_(0), model_dir_(model_dir),
        set_model_config_(false)
  {
  }
  virtual ~Model() {}

  // Get the name of model being served.
  const std::string& Name() const { return config_.name(); }

  // Get the version of model being served.
  int64_t Version() const { return version_; }

  // Get the configuration of model being served.
  const inference::ModelConfig& Config() const { return config_; }

  // Get the number of required inputs
  size_t RequiredInputCount() const { return required_input_count_; }

  // Get the stats collector for the model being served.
  InferenceStatsAggregator* MutableStatsAggregator()
  {
    return &stats_aggregator_;
  }
  const InferenceStatsAggregator& StatsAggregator() const
  {
    return stats_aggregator_;
  }

  // Get the model configuration for a named input.
  Status GetInput(
      const std::string& name, const inference::ModelInput** input) const;

  // Get the model configuration for a named output.
  Status GetOutput(
      const std::string& name, const inference::ModelOutput** output) const;

  // Get a label provider for the model.
  const std::shared_ptr<LabelProvider>& GetLabelProvider() const
  {
    return label_provider_;
  }

  // Initialize the instance for Triton core usage
  Status Init(const bool is_config_provided);

  // Enqueue a request for execution. If Status::Success is returned
  // then the model has taken ownership of the request object and so
  // 'request' will be nullptr. If non-success is returned then the
  // caller still retains ownership of 'request'.
  Status Enqueue(std::unique_ptr<InferenceRequest>& request)
  {
    return scheduler_->Enqueue(request);
  }

  // Return the number of in-flight inferences.
  size_t InflightInferenceCount()
  {
    return scheduler_->InflightInferenceCount();
  }

  // Stop processing future requests unless they are considered as in-flight.
  void Stop() { scheduler_->Stop(); }

  uint32_t DefaultPriorityLevel() const { return default_priority_level_; }

  uint32_t MaxPriorityLevel() const { return max_priority_level_; }

 protected:
  // Set the configuration of the model being served.
  Status SetModelConfig(const inference::ModelConfig& config);

  // Explicitly set the scheduler to use for inference requests to the
  // model. The scheduler can only be set once for a model.
  Status SetScheduler(std::unique_ptr<Scheduler> scheduler);

  // The scheduler to use for this model.
  std::unique_ptr<Scheduler> scheduler_;

  // Configuration of the model.
  inference::ModelConfig config_;

 private:
  // The minimum supported CUDA compute capability.
  const double min_compute_capability_;

  // Version of the model.
  int64_t version_;

  // The stats collector for the model.
  InferenceStatsAggregator stats_aggregator_;

  // Label provider for this model.
  std::shared_ptr<LabelProvider> label_provider_;

  size_t required_input_count_;

  // Map from input name to the model configuration for that input.
  std::unordered_map<std::string, inference::ModelInput> input_map_;

  // Map from output name to the model configuration for that output.
  std::unordered_map<std::string, inference::ModelOutput> output_map_;

  // Path to model
  std::string model_dir_;

  // The default priority level for the model.
  uint32_t default_priority_level_;

  // The largest priority value for the model.
  uint32_t max_priority_level_;

  // Whether or not model config has been set.
  bool set_model_config_;
};
}}
// namespace triton::core
3rdparty/core-r22.12/src/model_config_cuda.cc
0 → 100644
View file @
b30f3cdb
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "model_config_cuda.h"
#include <cuda_runtime_api.h>
namespace triton { namespace core {

int
GetCudaStreamPriority(
    inference::ModelOptimizationPolicy::ModelPriority priority)
{
  // Default priority is 0
  int cuda_stream_priority = 0;

  int min, max;
  cudaError_t cuerr = cudaDeviceGetStreamPriorityRange(&min, &max);
  if ((cuerr != cudaErrorNoDevice) && (cuerr != cudaSuccess)) {
    return 0;
  }

  switch (priority) {
    case inference::ModelOptimizationPolicy::PRIORITY_MAX:
      cuda_stream_priority = max;
      break;
    case inference::ModelOptimizationPolicy::PRIORITY_MIN:
      cuda_stream_priority = min;
      break;
    default:
      cuda_stream_priority = 0;
      break;
  }

  return cuda_stream_priority;
}
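// Usage sketch (illustrative, not part of the upstream file): a caller that
// wants its CUDA streams scheduled according to the model's configured
// priority could do roughly the following, assuming 'config' is an
// inference::ModelConfig loaded elsewhere and that it exposes the priority
// via config.optimization().priority():
//
//   int priority = GetCudaStreamPriority(config.optimization().priority());
//   cudaStream_t stream;
//   cudaStreamCreateWithPriority(&stream, cudaStreamDefault, priority);
//
// Note that the second value reported by cudaDeviceGetStreamPriorityRange()
// is the greatest (most urgent) priority and is numerically the smallest,
// which is why PRIORITY_MAX maps to 'max' above.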
}}
// namespace triton::core
3rdparty/core-r22.12/src/model_config_cuda.h
0 → 100644
View file @
b30f3cdb
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <stdint.h>
#include "model_config.pb.h"
namespace triton { namespace core {

/// Get the CUDA stream priority for a given ModelPriority
/// \param priority The inference::ModelOptimizationPolicy::ModelPriority
/// priority.
/// \return The CUDA stream priority.
int GetCudaStreamPriority(
    inference::ModelOptimizationPolicy::ModelPriority priority);
}}
// namespace triton::core
3rdparty/core-r22.12/src/model_config_utils.cc
0 → 100644
View file @
b30f3cdb
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "model_config_utils.h"
#include <google/protobuf/util/json_util.h>
#include <deque>
#include <mutex>
#include <set>
#include "constants.h"
#include "cuda_utils.h"
#include "filesystem.h"
#include "triton/common/logging.h"
#define TRITONJSON_STATUSTYPE triton::core::Status
#define TRITONJSON_STATUSRETURN(M) \
return triton::core::Status(triton::core::Status::Code::INTERNAL, (M))
#define TRITONJSON_STATUSSUCCESS triton::core::Status::Success
#include "triton/common/triton_json.h"
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#endif // TRITON_ENABLE_GPU
namespace
triton
{
namespace
core
{
namespace
{
#ifdef TRITON_ENABLE_ENSEMBLE
struct
EnsembleTensor
{
EnsembleTensor
(
bool
isOutput
)
:
ready
(
false
),
isOutput
(
isOutput
)
{}
bool
ready
;
bool
isOutput
;
std
::
vector
<
EnsembleTensor
*>
prev_nodes
;
std
::
vector
<
EnsembleTensor
*>
next_nodes
;
};
/// Build a graph that represents the data flow in the ensemble specified in
/// the given model config. A node (ensemble tensor) in the graph can be looked
/// up using its name as the key.
/// \param config The model configuration that specifies the
/// ensemble_scheduling field.
/// \param keyed_ensemble_graph Returns the ensemble graph.
/// \return The error status. A non-OK status indicates the build fails because
/// the ensemble configuration is not valid.
Status
BuildEnsembleGraph(
    const inference::ModelConfig& config,
    std::unordered_map<std::string, EnsembleTensor>& keyed_ensemble_graph)
{
  keyed_ensemble_graph.clear();
  size_t step_idx = 0;
  for (const auto& element : config.ensemble_scheduling().step()) {
    if (element.model_name().empty()) {
      return Status(
          Status::Code::INVALID_ARG,
          "must specify 'model_name' in step " + std::to_string(step_idx) +
              " of ensemble '" + config.name() + "'");
    }
    if (element.input_map().size() == 0) {
      return Status(
          Status::Code::INVALID_ARG,
          "must specify 'input_map' in step " + std::to_string(step_idx) +
              " of ensemble '" + config.name() + "'");
    }
    if (element.output_map().size() == 0) {
      return Status(
          Status::Code::INVALID_ARG,
          "must specify 'output_map' in step " + std::to_string(step_idx) +
              " of ensemble '" + config.name() + "'");
    }

    // Link ensemble tensors
    std::vector<EnsembleTensor*> tensor_as_output;
    for (const auto& output_map : element.output_map()) {
      auto it = keyed_ensemble_graph.find(output_map.second);
      if (it != keyed_ensemble_graph.end()) {
        if (it->second.isOutput) {
          return Status(
              Status::Code::INVALID_ARG,
              "ensemble tensor '" + it->first +
                  "' can appear in an output map only once for ensemble '" +
                  config.name() + "' step " + std::to_string(step_idx));
        } else {
          it->second.isOutput = true;
        }
      } else {
        it = keyed_ensemble_graph
                 .emplace(
                     std::make_pair(output_map.second, EnsembleTensor(true)))
                 .first;
      }
      tensor_as_output.push_back(&(it->second));
    }

    std::set<std::string> model_inputs;
    for (const auto& input_map : element.input_map()) {
      if (model_inputs.find(input_map.first) != model_inputs.end()) {
        return Status(
            Status::Code::INVALID_ARG,
            "input '" + input_map.first + "' in model '" +
                element.model_name() +
                "' is mapped to multiple ensemble tensors for ensemble '" +
                config.name() + "' step " + std::to_string(step_idx));
      } else {
        model_inputs.emplace(input_map.first);
      }
      auto it = keyed_ensemble_graph.find(input_map.second);
      if (it == keyed_ensemble_graph.end()) {
        it = keyed_ensemble_graph
                 .emplace(
                     std::make_pair(input_map.second, EnsembleTensor(false)))
                 .first;
      }
      for (auto output : tensor_as_output) {
        output->prev_nodes.push_back(&(it->second));
        it->second.next_nodes.push_back(output);
      }
    }

    step_idx++;
  }

  return Status::Success;
}
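// --- Editor's illustration (not part of this commit) --------------------------
// A minimal sketch of how BuildEnsembleGraph() is used. It assumes the
// standard protobuf-generated accessors for model_config.proto (add_step(),
// mutable_input_map(), mutable_output_map()); the model and tensor names below
// are made up.
static Status
ExampleBuildSingleStepGraph()
{
  inference::ModelConfig config;
  config.set_name("my_ensemble");
  auto* step = config.mutable_ensemble_scheduling()->add_step();
  step->set_model_name("preprocess");
  (*step->mutable_input_map())["RAW"] = "ensemble_input";  // model input <- ensemble tensor
  (*step->mutable_output_map())["OUT"] = "preprocessed";   // model output -> ensemble tensor

  std::unordered_map<std::string, EnsembleTensor> graph;
  RETURN_IF_ERROR(BuildEnsembleGraph(config, graph));
  // On success the graph holds "ensemble_input" (isOutput == false) and
  // "preprocessed" (isOutput == true), and "ensemble_input" appears in the
  // prev_nodes of "preprocessed".
  return Status::Success;
}
// -------------------------------------------------------------------------------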
Status
ValidateEnsembleSchedulingConfig(const inference::ModelConfig& config)
{
  if (config.platform() != kEnsemblePlatform) {
    return Status(
        Status::Code::INVALID_ARG,
        "ensemble scheduling cannot be set for model '" + config.name() +
            "' whose platform is not " + kEnsemblePlatform);
  }
  if (config.instance_group().size() != 0) {
    return Status(
        Status::Code::INVALID_ARG,
        "instance group should not be specified for ensemble '" +
            config.name() + "'");
  }
  if (config.has_optimization()) {
    return Status(
        Status::Code::INVALID_ARG,
        "optimization should not be specified for ensemble '" + config.name() +
            "'");
  }
  if (config.model_warmup_size() != 0) {
    return Status(
        Status::Code::INVALID_ARG,
        "model_warmup can not be specified for ensemble '" + config.name() +
            "'");
  }

  // Make sure step is not empty and all fields are set
  if (config.ensemble_scheduling().step_size() == 0) {
    return Status(
        Status::Code::INVALID_ARG,
        "must specify 'step' for ensemble '" + config.name() + "'");
  }

  std::unordered_map<std::string, EnsembleTensor> tensors;
  RETURN_IF_ERROR(BuildEnsembleGraph(config, tensors));

  // check data flow
  std::deque<EnsembleTensor*> ready_queue;
  for (const auto& input : config.input()) {
    auto it = tensors.find(input.name());
    if (it == tensors.end()) {
      return Status(
          Status::Code::INVALID_ARG,
          "ensemble input '" + input.name() + "' for ensemble " +
              config.name() + "' is not used");
    }
    it->second.ready = true;
    ready_queue.push_back(&(it->second));
  }
  while (!ready_queue.empty()) {
    auto& ready_node = ready_queue.front();
    for (auto& next_node : ready_node->next_nodes) {
      if (next_node->ready) {
        continue;
      }
      bool next_node_ready = true;
      for (auto& prev_node : next_node->prev_nodes) {
        if (!prev_node->ready) {
          next_node_ready = false;
          break;
        }
      }
      next_node->ready = next_node_ready;
      if (next_node_ready) {
        ready_queue.push_back(next_node);
      }
    }
    ready_queue.pop_front();
  }

  std::set<std::string> outputs;
  for (const auto& output : config.output()) {
    auto it = tensors.find(output.name());
    if (it == tensors.end()) {
      return Status(
          Status::Code::INVALID_ARG,
          "ensemble output '" + output.name() + "' for ensemble " +
              config.name() + "' is not used");
    }
    if (!it->second.ready) {
      return Status(
          Status::Code::INVALID_ARG,
          "output '" + output.name() + "' for ensemble '" + config.name() +
              "' is not written");
    } else {
      outputs.insert(it->first);
    }
  }

  // Check redundant ensemble tensors
  for (const auto& tensor : tensors) {
    // skip ensemble outputs as they have been checked and can have no
    // next nodes
    if (outputs.find(tensor.first) != outputs.end()) {
      continue;
    }
    if (!tensor.second.ready || (tensor.second.next_nodes.size() == 0)) {
      return Status(
          Status::Code::INVALID_ARG,
          "ensemble tensor '" + tensor.first + "' is unused in ensemble '" +
              config.name() + "'");
    }
  }

  return Status::Success;
}
#endif // TRITON_ENABLE_ENSEMBLE
template <class ModelIO>
Status
ValidateIOShape(
    const ModelIO& io, int32_t max_batch_size,
    const std::string& message_prefix = "")
{
  if (io.name().empty()) {
    return Status(
        Status::Code::INVALID_ARG, message_prefix + "must specify 'name'");
  }

  if (io.data_type() == inference::DataType::TYPE_INVALID) {
    return Status(
        Status::Code::INVALID_ARG, "model output must specify 'data_type'");
  }

  if (io.dims_size() == 0) {
    return Status(
        Status::Code::INVALID_ARG, message_prefix + "must specify 'dims'");
  }

  // If the configuration is non-batching, then no input or output
  // reshape can be empty as that would mean that input or output was
  // always empty (no data).
  if (io.has_reshape() && (io.reshape().shape_size() == 0) &&
      (max_batch_size == 0)) {
    return Status(
        Status::Code::INVALID_ARG,
        message_prefix +
            "cannot have empty reshape for non-batching model as scalar "
            "tensors are not supported");
  }

  for (auto dim : io.dims()) {
    // Dimension cannot be 0.
    if ((dim < 1) && (dim != triton::common::WILDCARD_DIM)) {
      return Status(
          Status::Code::INVALID_ARG,
          message_prefix + "dimension must be integer >= 1, or " +
              std::to_string(triton::common::WILDCARD_DIM) +
              " to indicate a variable-size dimension");
    }
  }

  if (io.has_reshape()) {
    // Zeros are not allowed in reshape.
    for (auto dim : io.reshape().shape()) {
      if ((dim < 1) && (dim != triton::common::WILDCARD_DIM)) {
        return Status(
            Status::Code::INVALID_ARG,
            message_prefix + "reshape dimensions must be integer >= 1, or " +
                std::to_string(triton::common::WILDCARD_DIM) +
                " to indicate a variable-size dimension");
      }
    }

    const int64_t dims_size = triton::common::GetElementCount(io.dims());
    const int64_t reshape_size =
        triton::common::GetElementCount(io.reshape().shape());

    // dims and reshape must both have same element count
    // or both have variable-size dimension.
    // Special case for empty reshape... expect dims to have element
    // count of 1.
    if ((dims_size != reshape_size) &&
        ((reshape_size != 0) || (dims_size != 1))) {
      return Status(
          Status::Code::INVALID_ARG,
          message_prefix + "has different size for dims and reshape");
    }

    // shape contains variable-size dimension, in this case we compare if
    // each pair of the trunks separated by variable-size dimension has
    // the same element count. For instance, from [2, 4, -1, 6] to
    // [8, -1, 1, 6] is valid reshape as 2 * 4 = 8 and 6 = 1 * 6.
    if (dims_size == -1) {
      std::vector<int64_t> dim_element_cnts;
      std::vector<int64_t> reshape_element_cnts;
      int64_t current_cnt = 1;
      for (const auto& dim : io.dims()) {
        if (dim != -1) {
          current_cnt *= dim;
        } else {
          dim_element_cnts.push_back(current_cnt);
          current_cnt = 1;
        }
      }
      dim_element_cnts.push_back(current_cnt);

      current_cnt = 1;
      for (const auto& dim : io.reshape().shape()) {
        if (dim != -1) {
          current_cnt *= dim;
        } else {
          reshape_element_cnts.push_back(current_cnt);
          current_cnt = 1;
        }
      }
      reshape_element_cnts.push_back(current_cnt);

      if (dim_element_cnts.size() != reshape_element_cnts.size()) {
        return Status(
            Status::Code::INVALID_ARG,
            message_prefix +
                "has different number of variable-size dimensions for dims "
                "and reshape");
      }
      for (size_t idx = 0; idx < dim_element_cnts.size(); idx++) {
        if (dim_element_cnts[idx] != reshape_element_cnts[idx]) {
          return Status(
              Status::Code::INVALID_ARG,
              message_prefix + "has different size for dims and reshape");
        }
      }
    }
  }

  return Status::Success;
}

}  // namespace
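// --- Editor's illustration (not part of this commit) --------------------------
// The chunk-by-chunk reshape comparison in ValidateIOShape() is easiest to see
// on a concrete pair of shapes. This standalone restatement over plain vectors
// (the helper name is made up; it assumes <vector> and <cstdint>) shows why
// [2, 4, -1, 6] reshaped to [8, -1, 1, 6] is accepted: both split into the
// fixed-size chunk counts {8, 6}.
static std::vector<int64_t>
ChunkElementCounts(const std::vector<int64_t>& dims)
{
  std::vector<int64_t> counts;
  int64_t current = 1;
  for (int64_t d : dims) {
    if (d != -1) {
      current *= d;  // accumulate the current fixed-size chunk
    } else {
      counts.push_back(current);  // a -1 ends the current chunk
      current = 1;
    }
  }
  counts.push_back(current);
  return counts;
}
// ChunkElementCounts({2, 4, -1, 6}) == {8, 6}
// ChunkElementCounts({8, -1, 1, 6}) == {8, 6}
// Equal chunk lists (element by element) => the reshape is considered compatible.
// -------------------------------------------------------------------------------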
Status
GetModelVersionFromPath(const std::string& path, int64_t* version)
{
  auto version_dir = BaseName(path);

  // Determine the version from the last segment of 'path'
  try {
    *version = std::atoll(version_dir.c_str());
  }
  catch (...) {
    return Status(
        Status::Code::INTERNAL,
        "unable to determine model version from " + path);
  }

  return Status::Success;
}
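// --- Editor's note (not part of this commit) -----------------------------------
// GetModelVersionFromPath() only looks at the final path segment; for a
// hypothetical "/models/resnet50/3" it sets *version == 3. Note that
// std::atoll() never throws, so a non-numeric segment such as ".../latest"
// yields *version == 0 rather than the INTERNAL error in the catch block.
// -------------------------------------------------------------------------------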
Status
GetBooleanSequenceControlProperties
(
const
inference
::
ModelSequenceBatching
&
batcher
,
const
std
::
string
&
model_name
,
const
inference
::
ModelSequenceBatching
::
Control
::
Kind
control_kind
,
const
bool
required
,
std
::
string
*
tensor_name
,
inference
::
DataType
*
tensor_datatype
,
float
*
fp32_false_value
,
float
*
fp32_true_value
,
int32_t
*
int32_false_value
,
int32_t
*
int32_true_value
,
bool
*
bool_false_value
,
bool
*
bool_true_value
)
{
// Make sure same tensor is not configured for multiple controls
std
::
set
<
std
::
string
>
seen_tensors
;
// Make sure the control kind is not mentioned multiple times.
bool
seen_control
=
false
;
for
(
const
auto
&
control_input
:
batcher
.
control_input
())
{
if
(
control_input
.
name
().
empty
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching control tensor must have a name for "
+
model_name
);
}
if
(
seen_tensors
.
find
(
control_input
.
name
())
!=
seen_tensors
.
end
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching control tensor '"
+
control_input
.
name
()
+
"' is specified for multiple control kinds for "
+
model_name
);
}
seen_tensors
.
insert
(
control_input
.
name
());
for
(
const
auto
&
c
:
control_input
.
control
())
{
if
(
c
.
kind
()
==
control_kind
)
{
if
(
seen_control
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching specifies multiple "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" tensors for "
+
model_name
);
}
*
tensor_name
=
control_input
.
name
();
seen_control
=
true
;
// Make sure only one of int, float, or bool type is specified.
if
(
!
((
c
.
int32_false_true_size
()
!=
0
)
||
(
c
.
fp32_false_true_size
()
!=
0
)
||
(
c
.
bool_false_true_size
()
!=
0
)))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching must specify either 'int32_false_true', "
"'fp32_false_true' or 'bool_false_true' for "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" for "
+
model_name
);
}
else
if
(
((
c
.
int32_false_true_size
()
!=
0
)
&&
(
c
.
fp32_false_true_size
()
!=
0
))
||
((
c
.
int32_false_true_size
()
!=
0
)
&&
(
c
.
bool_false_true_size
()
!=
0
))
||
((
c
.
fp32_false_true_size
()
!=
0
)
&&
(
c
.
bool_false_true_size
()
!=
0
)))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching specifies more than one from "
"'int32_false_true', 'fp32_false_true' and 'bool_false_true' "
"for "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" for "
+
model_name
);
}
if
(
c
.
int32_false_true_size
()
>
0
)
{
if
(
c
.
int32_false_true_size
()
!=
2
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching control 'int32_false_true' must have "
"exactly 2 entries for "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" for "
+
model_name
);
}
if
(
tensor_datatype
!=
nullptr
)
{
*
tensor_datatype
=
inference
::
DataType
::
TYPE_INT32
;
}
if
(
int32_false_value
!=
nullptr
)
{
*
int32_false_value
=
c
.
int32_false_true
(
0
);
}
if
(
int32_true_value
!=
nullptr
)
{
*
int32_true_value
=
c
.
int32_false_true
(
1
);
}
}
else
if
(
c
.
fp32_false_true_size
()
>
0
)
{
if
(
c
.
fp32_false_true_size
()
!=
2
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching control 'fp32_false_true' must have exactly "
"2 entries for "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" for "
+
model_name
);
}
if
(
tensor_datatype
!=
nullptr
)
{
*
tensor_datatype
=
inference
::
DataType
::
TYPE_FP32
;
}
if
(
fp32_false_value
!=
nullptr
)
{
*
fp32_false_value
=
c
.
fp32_false_true
(
0
);
}
if
(
fp32_true_value
!=
nullptr
)
{
*
fp32_true_value
=
c
.
fp32_false_true
(
1
);
}
}
else
{
if
(
c
.
bool_false_true_size
()
!=
2
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching control 'bool_false_true' must have exactly "
"2 entries for "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" for "
+
model_name
);
}
if
(
tensor_datatype
!=
nullptr
)
{
*
tensor_datatype
=
inference
::
DataType
::
TYPE_BOOL
;
}
if
(
bool_false_value
!=
nullptr
)
{
*
bool_false_value
=
c
.
bool_false_true
(
0
);
}
if
(
bool_true_value
!=
nullptr
)
{
*
bool_true_value
=
c
.
bool_false_true
(
1
);
}
}
}
}
}
if
(
!
seen_control
)
{
if
(
required
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching control tensor must specify a "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" value for "
+
model_name
);
}
tensor_name
->
clear
();
}
return
Status
::
Success
;
}
Status
GetTypedSequenceControlProperties
(
const
inference
::
ModelSequenceBatching
&
batcher
,
const
std
::
string
&
model_name
,
const
inference
::
ModelSequenceBatching
::
Control
::
Kind
control_kind
,
const
bool
required
,
std
::
string
*
tensor_name
,
inference
::
DataType
*
tensor_datatype
)
{
// Make sure same tensor is not configured for multiple controls
std
::
set
<
std
::
string
>
seen_tensors
;
// Make sure the control kind is not mentioned multiple times.
bool
seen_control
=
false
;
for
(
const
auto
&
control_input
:
batcher
.
control_input
())
{
if
(
control_input
.
name
().
empty
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching control tensor must have a name for "
+
model_name
);
}
if
(
seen_tensors
.
find
(
control_input
.
name
())
!=
seen_tensors
.
end
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching control tensor '"
+
control_input
.
name
()
+
"' is specified for multiple control kinds for "
+
model_name
);
}
seen_tensors
.
insert
(
control_input
.
name
());
for
(
const
auto
&
c
:
control_input
.
control
())
{
if
(
c
.
kind
()
==
control_kind
)
{
if
(
seen_control
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching specifies multiple "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" tensors for "
+
model_name
);
}
*
tensor_name
=
control_input
.
name
();
if
(
tensor_datatype
!=
nullptr
)
{
*
tensor_datatype
=
c
.
data_type
();
}
seen_control
=
true
;
if
((
c
.
int32_false_true_size
()
>
0
)
||
(
c
.
fp32_false_true_size
()
>
0
)
||
(
c
.
bool_false_true_size
()
>
0
))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching must not specify either 'int32_false_true', "
"'fp32_false_true' or 'bool_false_true' for "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" for "
+
model_name
);
}
}
}
}
if
(
!
seen_control
)
{
if
(
required
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching control tensor must specify a "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
control_kind
)
+
" value for "
+
model_name
);
}
tensor_name
->
clear
();
}
return
Status
::
Success
;
}
Status
GetNormalizedModelConfig(
    const std::string& model_name, const std::string& path,
    const double min_compute_capability, inference::ModelConfig* config)
{
  // Server-side autofill only sets certain backend fields for the models that
  // belong to limited backends for backwards-compatibility. See TensorRT
  // backend, ONNX Runtime backend, OpenVINO backend, TensorFlow backend, and
  // PyTorch backend.
  // Extracting detailed information is delegated to the backend implementation
  // to auto-complete.
  RETURN_IF_ERROR(
      AutoCompleteBackendFields(model_name, std::string(path), config));

  LOG_VERBOSE(1) << "Server side auto-completed config: "
                 << config->DebugString();

  RETURN_IF_ERROR(NormalizeModelConfig(min_compute_capability, config));

  return Status::Success;
}
Status
NormalizeModelConfig
(
const
double
min_compute_capability
,
inference
::
ModelConfig
*
config
)
{
// If version_policy is not specified, default to Latest 1 version.
if
(
!
config
->
has_version_policy
())
{
inference
::
ModelVersionPolicy
::
Latest
latest
;
latest
.
set_num_versions
(
1
);
config
->
mutable_version_policy
()
->
mutable_latest
()
->
CopyFrom
(
latest
);
}
// If dynamic batching is specified...
if
(
config
->
has_dynamic_batching
())
{
// If preferred batch size is not specified set it to
// max-batch-size.
if
(
config
->
dynamic_batching
().
preferred_batch_size
().
size
()
==
0
)
{
auto
mutable_preferred_batch_size
=
config
->
mutable_dynamic_batching
()
->
mutable_preferred_batch_size
();
if
(
config
->
max_batch_size
()
>
0
)
{
mutable_preferred_batch_size
->
Add
(
config
->
max_batch_size
());
}
}
}
// If sequence batching is specified...
if
(
config
->
has_sequence_batching
())
{
// Set default idle is not specified.
if
(
config
->
sequence_batching
().
max_sequence_idle_microseconds
()
==
0
)
{
config
->
mutable_sequence_batching
()
->
set_max_sequence_idle_microseconds
(
SEQUENCE_IDLE_DEFAULT_MICROSECONDS
);
}
if
(
config
->
sequence_batching
().
has_oldest
())
{
// If preferred batch size is not specified set it to
// max-batch-size.
if
(
config
->
sequence_batching
().
oldest
().
preferred_batch_size
().
size
()
==
0
)
{
auto
mutable_preferred_batch_size
=
config
->
mutable_sequence_batching
()
->
mutable_oldest
()
->
mutable_preferred_batch_size
();
if
(
config
->
max_batch_size
()
>
0
)
{
mutable_preferred_batch_size
->
Add
(
config
->
max_batch_size
());
}
}
}
}
// If model ensembling is specified, don't attempt to normalize instance_group
// as it is not allowed in ensemble scheduling
if
(
!
config
->
has_ensemble_scheduling
())
{
auto
optimization
=
config
->
mutable_optimization
();
if
(
!
optimization
->
has_input_pinned_memory
())
{
optimization
->
mutable_input_pinned_memory
()
->
set_enable
(
true
);
}
if
(
!
optimization
->
has_output_pinned_memory
())
{
optimization
->
mutable_output_pinned_memory
()
->
set_enable
(
true
);
}
}
return
Status
::
Success
;
}
Status
NormalizeInstanceGroup
(
const
double
min_compute_capability
,
const
std
::
vector
<
inference
::
ModelInstanceGroup
>&
preferred_groups
,
inference
::
ModelConfig
*
config
)
{
// Instance group setting doesn't apply to ensemble
if
(
config
->
has_ensemble_scheduling
())
{
return
Status
::
Success
;
}
// Creates a set of supported GPU device ids
std
::
set
<
int
>
supported_gpus
;
#ifdef TRITON_ENABLE_GPU
// Get the total number of GPUs from the runtime library.
Status
status
=
GetSupportedGPUs
(
&
supported_gpus
,
min_compute_capability
);
if
(
!
status
.
IsOk
())
{
return
status
;
}
#endif // TRITON_ENABLE_GPU
// Make sure there is at least one instance_group.
if
(
config
->
instance_group
().
empty
())
{
inference
::
ModelInstanceGroup
*
group
=
config
->
add_instance_group
();
group
->
set_name
(
config
->
name
());
for
(
const
auto
&
pg
:
preferred_groups
)
{
group
->
set_kind
(
pg
.
kind
());
group
->
set_count
(
pg
.
count
());
// handle preferred GPU setting differently based on kind
if
(
pg
.
kind
()
==
inference
::
ModelInstanceGroup
::
KIND_GPU
)
{
// Don't use preferred group with KIND_GPU if there is no GPU.
if
(
supported_gpus
.
empty
())
{
continue
;
}
// If preferred group sets GPUs, limit deployment onto those that
// are also listed in supported gpus
if
(
!
pg
.
gpus
().
empty
())
{
for
(
const
int32_t
gid
:
pg
.
gpus
())
{
if
(
supported_gpus
.
find
(
gid
)
!=
supported_gpus
.
end
())
{
group
->
add_gpus
(
gid
);
}
}
}
break
;
}
else
if
(
pg
.
kind
()
==
inference
::
ModelInstanceGroup
::
KIND_AUTO
)
{
// if AUTO, then set preferred GPU as is, to align with KIND_AUTO
// deduction specified below
for
(
const
int32_t
gid
:
pg
.
gpus
())
{
group
->
add_gpus
(
gid
);
}
break
;
}
// Other kind should not set GPUs
break
;
}
}
// Assign default name, kind and count to each instance group that
// doesn't give those values explicitly. For KIND_GPU, set GPUs to
// all available if not specified explicitly.
size_t
cnt
=
0
;
for
(
auto
&
group
:
*
config
->
mutable_instance_group
())
{
// Name
if
(
group
.
name
().
empty
())
{
group
.
set_name
(
config
->
name
()
+
"_"
+
std
::
to_string
(
cnt
));
}
cnt
++
;
// For KIND_AUTO... if there are no GPUs or if any of the listed
// 'gpu's are not present, then use KIND_CPU.
if
(
group
.
kind
()
==
inference
::
ModelInstanceGroup
::
KIND_AUTO
)
{
if
(
supported_gpus
.
empty
())
{
group
.
set_kind
(
inference
::
ModelInstanceGroup
::
KIND_CPU
);
}
else
{
for
(
const
int32_t
gid
:
group
.
gpus
())
{
if
(
supported_gpus
.
find
(
gid
)
==
supported_gpus
.
end
())
{
group
.
set_kind
(
inference
::
ModelInstanceGroup
::
KIND_CPU
);
break
;
}
}
}
if
(
group
.
kind
()
==
inference
::
ModelInstanceGroup
::
KIND_AUTO
)
{
group
.
set_kind
(
inference
::
ModelInstanceGroup
::
KIND_GPU
);
}
}
// KIND is resolved at this point
for
(
const
auto
&
pg
:
preferred_groups
)
{
if
(
group
.
kind
()
!=
pg
.
kind
())
{
continue
;
}
// Limit the GPU setting within what is specified in the preferred group,
// if no available GPU then skip to next preferred group
if
((
group
.
kind
()
==
inference
::
ModelInstanceGroup
::
KIND_GPU
)
&&
group
.
gpus
().
empty
()
&&
!
pg
.
gpus
().
empty
())
{
for
(
const
int32_t
gid
:
pg
.
gpus
())
{
if
(
supported_gpus
.
find
(
gid
)
!=
supported_gpus
.
end
())
{
group
.
add_gpus
(
gid
);
}
}
if
(
group
.
gpus
().
empty
())
{
continue
;
}
}
if
((
group
.
count
()
<
1
)
&&
(
pg
.
count
()
>
0
))
{
group
.
set_count
(
pg
.
count
());
}
}
// Set Triton default if the fields are not set from preferred group
// Count
if
(
group
.
count
()
<
1
)
{
RETURN_IF_ERROR
(
SetDefaultInstanceCount
(
&
group
,
config
->
backend
()));
}
// GPUs
if
((
group
.
kind
()
==
inference
::
ModelInstanceGroup
::
KIND_GPU
)
&&
(
group
.
gpus
().
size
()
==
0
))
{
for
(
auto
d
:
supported_gpus
)
{
group
.
add_gpus
(
d
);
}
}
}
return
Status
::
Success
;
}
Status
LocalizePythonBackendExecutionEnvironmentPath
(
const
std
::
string
&
model_path
,
inference
::
ModelConfig
*
config
,
std
::
shared_ptr
<
LocalizedPath
>*
localized_model_dir
)
{
if
(
config
->
backend
()
==
"python"
)
{
if
(
config
->
parameters
().
contains
(
"EXECUTION_ENV_PATH"
))
{
// Read EXECUTION_ENV_PATH
std
::
string
exec_env_path
=
config
->
parameters
().
at
(
"EXECUTION_ENV_PATH"
).
string_value
();
// Replace model directory variable with model_path
std
::
string
model_dir_var
=
"$$TRITON_MODEL_DIRECTORY"
;
if
(
exec_env_path
.
substr
(
0
,
model_dir_var
.
size
())
==
model_dir_var
)
{
exec_env_path
.
replace
(
0
,
model_dir_var
.
size
(),
model_path
);
}
// Collapse any .. in the path
std
::
string
abs_exec_env_path
;
std
::
size_t
prev_pos
=
exec_env_path
.
size
();
std
::
size_t
pos
=
exec_env_path
.
find_last_of
(
'/'
,
prev_pos
-
1
);
int
skip
=
0
;
while
(
pos
!=
std
::
string
::
npos
&&
prev_pos
>
0
)
{
if
(
!
skip
)
{
abs_exec_env_path
=
exec_env_path
.
substr
(
pos
,
prev_pos
-
pos
)
+
abs_exec_env_path
;
}
skip
=
skip
>
0
?
skip
-
1
:
skip
;
if
(
pos
>=
3
&&
exec_env_path
.
substr
(
pos
-
3
,
3
)
==
"/.."
)
{
skip
+=
2
;
}
prev_pos
=
pos
;
pos
=
exec_env_path
.
find_last_of
(
'/'
,
prev_pos
-
1
);
}
abs_exec_env_path
=
exec_env_path
.
substr
(
0
,
prev_pos
)
+
abs_exec_env_path
;
// Localize iff abs_exec_env_path is outside the model directory
std
::
string
model_path_slash
=
model_path
.
back
()
==
'/'
?
model_path
:
model_path
+
"/"
;
if
(
abs_exec_env_path
.
substr
(
0
,
model_path_slash
.
size
())
!=
model_path_slash
)
{
// Localize the file
std
::
shared_ptr
<
LocalizedPath
>
localized_exec_env_path
;
RETURN_IF_ERROR
(
LocalizePath
(
abs_exec_env_path
,
&
localized_exec_env_path
));
// Persist the localized temporary path
(
*
localized_model_dir
)
->
other_localized_path
.
push_back
(
localized_exec_env_path
);
// Rewrite EXECUTION_ENV_PATH
config
->
mutable_parameters
()
->
at
(
"EXECUTION_ENV_PATH"
)
.
set_string_value
(
localized_exec_env_path
->
Path
());
}
}
}
return
Status
::
Success
;
}
Status
SetDefaultInstanceCount(
    inference::ModelInstanceGroup* group, const std::string& backend)
{
  group->set_count(1);

  // Backends opt into the default_cpu_instance_count since
  // some backends (pytorch, OpenVINO) don't perform well/have high overhead
  // when using multiple instances.
  const int default_cpu_instance_count = 2;
  bool use_default_cpu_instance_count =
      (backend == kTensorFlowBackend) || (backend == kOnnxRuntimeBackend);
  if (group->kind() == inference::ModelInstanceGroup::KIND_CPU &&
      use_default_cpu_instance_count) {
    group->set_count(default_cpu_instance_count);
  }

  return Status::Success;
}
Status
AutoCompleteBackendFields
(
const
std
::
string
&
model_name
,
const
std
::
string
&
model_path
,
inference
::
ModelConfig
*
config
)
{
std
::
set
<
std
::
string
>
version_dirs
;
RETURN_IF_ERROR
(
GetDirectorySubdirs
(
model_path
,
&
version_dirs
));
// There must be at least one version directory that we can inspect to
// attempt to determine the platform. If not, we skip autofill with file name.
// For now we allow multiple versions and only inspect the first verison
// directory to ensure it is valid. We can add more aggressive checks later.
const
bool
has_version
=
(
version_dirs
.
size
()
!=
0
);
const
auto
version_path
=
has_version
?
JoinPath
({
model_path
,
*
(
version_dirs
.
begin
())})
:
""
;
std
::
set
<
std
::
string
>
version_dir_content
;
if
(
has_version
)
{
RETURN_IF_ERROR
(
GetDirectoryContents
(
version_path
,
&
version_dir_content
));
}
// If the model name is not given in the configuration, set if based
// on the model path.
if
(
config
->
name
().
empty
())
{
config
->
set_name
(
model_name
);
}
// Trying to fill the 'backend', 'default_model_filename' field.
// TensorFlow
// For TF backend, the platform is required
if
(
config
->
platform
().
empty
())
{
// Check 'backend', 'default_model_filename', and the actual directory
// to determine the platform
if
(
config
->
backend
().
empty
()
||
(
config
->
backend
()
==
kTensorFlowBackend
))
{
if
(
config
->
default_model_filename
()
==
kTensorFlowSavedModelFilename
)
{
config
->
set_platform
(
kTensorFlowSavedModelPlatform
);
}
else
if
(
config
->
default_model_filename
()
==
kTensorFlowGraphDefFilename
)
{
config
->
set_platform
(
kTensorFlowGraphDefPlatform
);
}
else
if
(
config
->
default_model_filename
().
empty
()
&&
has_version
)
{
bool
is_dir
=
false
;
if
(
version_dir_content
.
find
(
kTensorFlowSavedModelFilename
)
!=
version_dir_content
.
end
())
{
RETURN_IF_ERROR
(
IsDirectory
(
JoinPath
({
version_path
,
kTensorFlowSavedModelFilename
}),
&
is_dir
));
if
(
is_dir
)
{
config
->
set_platform
(
kTensorFlowSavedModelPlatform
);
}
}
if
(
version_dir_content
.
find
(
kTensorFlowGraphDefFilename
)
!=
version_dir_content
.
end
())
{
RETURN_IF_ERROR
(
IsDirectory
(
JoinPath
({
version_path
,
kTensorFlowGraphDefFilename
}),
&
is_dir
));
if
(
!
is_dir
)
{
config
->
set_platform
(
kTensorFlowGraphDefPlatform
);
}
}
}
}
}
// Fill 'backend' and 'default_model_filename' if missing
if
((
config
->
platform
()
==
kTensorFlowSavedModelPlatform
)
||
(
config
->
platform
()
==
kTensorFlowGraphDefPlatform
))
{
if
(
config
->
backend
().
empty
())
{
config
->
set_backend
(
kTensorFlowBackend
);
}
if
(
config
->
default_model_filename
().
empty
())
{
if
(
config
->
platform
()
==
kTensorFlowSavedModelPlatform
)
{
config
->
set_default_model_filename
(
kTensorFlowSavedModelFilename
);
}
else
{
config
->
set_default_model_filename
(
kTensorFlowGraphDefFilename
);
}
}
return
Status
::
Success
;
}
// TensorRT
if
(
config
->
backend
().
empty
())
{
if
((
config
->
platform
()
==
kTensorRTPlanPlatform
)
||
(
config
->
default_model_filename
()
==
kTensorRTPlanFilename
))
{
config
->
set_backend
(
kTensorRTBackend
);
}
else
if
(
config
->
platform
().
empty
()
&&
config
->
default_model_filename
().
empty
()
&&
has_version
)
{
bool
is_dir
=
false
;
if
(
version_dir_content
.
find
(
kTensorRTPlanFilename
)
!=
version_dir_content
.
end
())
{
RETURN_IF_ERROR
(
IsDirectory
(
JoinPath
({
version_path
,
kTensorRTPlanFilename
}),
&
is_dir
));
if
(
!
is_dir
)
{
config
->
set_backend
(
kTensorRTBackend
);
}
}
}
}
if
(
config
->
backend
()
==
kTensorRTBackend
)
{
if
(
config
->
platform
().
empty
())
{
config
->
set_platform
(
kTensorRTPlanPlatform
);
}
if
(
config
->
default_model_filename
().
empty
())
{
config
->
set_default_model_filename
(
kTensorRTPlanFilename
);
}
return
Status
::
Success
;
}
// ONNXRuntime
if
(
config
->
backend
().
empty
())
{
if
((
config
->
platform
()
==
kOnnxRuntimeOnnxPlatform
)
||
(
config
->
default_model_filename
()
==
kOnnxRuntimeOnnxFilename
))
{
config
->
set_backend
(
kOnnxRuntimeBackend
);
}
else
if
(
config
->
platform
().
empty
()
&&
config
->
default_model_filename
().
empty
()
&&
has_version
)
{
if
(
version_dir_content
.
find
(
kOnnxRuntimeOnnxFilename
)
!=
version_dir_content
.
end
())
{
// ONNX model can be a file or a directory in the case of large model
config
->
set_backend
(
kOnnxRuntimeBackend
);
}
}
}
if
(
config
->
backend
()
==
kOnnxRuntimeBackend
)
{
if
(
config
->
platform
().
empty
())
{
config
->
set_platform
(
kOnnxRuntimeOnnxPlatform
);
}
if
(
config
->
default_model_filename
().
empty
())
{
config
->
set_default_model_filename
(
kOnnxRuntimeOnnxFilename
);
}
return
Status
::
Success
;
}
// OpenVINO
if
(
config
->
backend
().
empty
())
{
if
(
config
->
default_model_filename
()
==
kOpenVINORuntimeOpenVINOFilename
)
{
config
->
set_backend
(
kOpenVINORuntimeBackend
);
}
else
if
(
config
->
platform
().
empty
()
&&
config
->
default_model_filename
().
empty
()
&&
has_version
)
{
if
(
version_dir_content
.
find
(
kOpenVINORuntimeOpenVINOFilename
)
!=
version_dir_content
.
end
())
{
config
->
set_backend
(
kOpenVINORuntimeBackend
);
}
}
}
if
(
config
->
backend
()
==
kOpenVINORuntimeBackend
)
{
if
(
config
->
default_model_filename
().
empty
())
{
config
->
set_default_model_filename
(
kOpenVINORuntimeOpenVINOFilename
);
}
return
Status
::
Success
;
}
// PyTorch (TorchScript, LibTorch)
if
(
config
->
backend
().
empty
())
{
if
((
config
->
platform
()
==
kPyTorchLibTorchPlatform
)
||
(
config
->
default_model_filename
()
==
kPyTorchLibTorchFilename
))
{
config
->
set_backend
(
kPyTorchBackend
);
}
else
if
(
config
->
platform
().
empty
()
&&
config
->
default_model_filename
().
empty
()
&&
has_version
)
{
bool
is_dir
=
false
;
if
(
version_dir_content
.
find
(
kPyTorchLibTorchFilename
)
!=
version_dir_content
.
end
())
{
RETURN_IF_ERROR
(
IsDirectory
(
JoinPath
({
version_path
,
kPyTorchLibTorchFilename
}),
&
is_dir
));
if
(
!
is_dir
)
{
config
->
set_backend
(
kPyTorchBackend
);
}
}
}
}
if
(
config
->
backend
()
==
kPyTorchBackend
)
{
if
(
config
->
platform
().
empty
())
{
config
->
set_platform
(
kPyTorchLibTorchPlatform
);
}
if
(
config
->
default_model_filename
().
empty
())
{
config
->
set_default_model_filename
(
kPyTorchLibTorchFilename
);
}
return
Status
::
Success
;
}
// Python
if
(
config
->
backend
().
empty
())
{
if
(
config
->
default_model_filename
()
==
kPythonFilename
)
{
config
->
set_backend
(
kPythonBackend
);
}
else
if
(
config
->
platform
().
empty
()
&&
config
->
default_model_filename
().
empty
()
&&
has_version
)
{
if
(
version_dir_content
.
find
(
kPythonFilename
)
!=
version_dir_content
.
end
())
{
config
->
set_backend
(
kPythonBackend
);
}
}
}
if
(
config
->
backend
()
==
kPythonBackend
)
{
if
(
config
->
default_model_filename
().
empty
())
{
config
->
set_default_model_filename
(
kPythonFilename
);
}
return
Status
::
Success
;
}
// Custom Backend
// For now, only do the narrowest case, where no info is given in the config.
if
(
config
->
backend
().
empty
()
&&
config
->
platform
().
empty
()
&&
config
->
default_model_filename
().
empty
())
{
LOG_VERBOSE
(
1
)
<<
"Could not infer supported backend, so attempting "
"autofill of custom backend."
;
// Since we lazily load the backends, we let the model tell us what backend
// to load. We must assume that if the model name conforms to the required
// shape, we parse the backend name out of the model file name. i.e.
// model.identity will set the backend to "identity".
const
std
::
string
delimiter
=
"."
;
size_t
pos
=
model_name
.
find
(
delimiter
,
0
);
if
(
pos
==
std
::
string
::
npos
)
{
return
Status
(
triton
::
common
::
Error
::
Code
::
INVALID_ARG
,
(
"Invalid model name: Could not determine backend for model '"
+
model_name
+
"' with no backend in model configuration. Expected model name of "
"the form 'model.<backend_name>'."
));
}
const
std
::
string
backend_name
=
model_name
.
substr
(
pos
+
1
,
std
::
string
::
npos
);
config
->
set_backend
(
backend_name
);
config
->
set_default_model_filename
(
(
std
::
string
(
"model."
)
+
backend_name
).
c_str
());
return
Status
::
Success
;
}
return
Status
::
Success
;
}
Status
ValidateModelIOConfig
(
const
inference
::
ModelConfig
&
config
)
{
Status
status
;
for
(
const
auto
&
io
:
config
.
input
())
{
status
=
ValidateModelInput
(
io
,
config
.
max_batch_size
(),
config
.
platform
());
if
(
!
status
.
IsOk
())
{
return
Status
(
status
.
StatusCode
(),
status
.
Message
()
+
" for "
+
config
.
name
());
}
}
for
(
const
auto
&
io
:
config
.
output
())
{
status
=
ValidateModelOutput
(
io
,
config
.
max_batch_size
(),
config
.
platform
());
if
(
!
status
.
IsOk
())
{
return
Status
(
status
.
StatusCode
(),
status
.
Message
()
+
" for "
+
config
.
name
());
}
}
status
=
ValidateBatchIO
(
config
);
if
(
!
status
.
IsOk
())
{
return
Status
(
status
.
StatusCode
(),
status
.
Message
()
+
" for "
+
config
.
name
());
}
return
Status
::
Success
;
}
Status
ValidateBatchIO
(
const
inference
::
ModelConfig
&
config
)
{
std
::
set
<
std
::
string
>
input_names
;
std
::
set
<
std
::
string
>
output_names
;
for
(
const
auto
&
io
:
config
.
input
())
{
input_names
.
emplace
(
io
.
name
());
}
for
(
const
auto
&
io
:
config
.
output
())
{
output_names
.
emplace
(
io
.
name
());
}
for
(
const
auto
&
batch_io
:
config
.
batch_input
())
{
switch
(
batch_io
.
kind
())
{
case
inference
::
BatchInput
::
BATCH_ELEMENT_COUNT
:
case
inference
::
BatchInput
::
BATCH_ACCUMULATED_ELEMENT_COUNT
:
case
inference
::
BatchInput
::
BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO
:
case
inference
::
BatchInput
::
BATCH_MAX_ELEMENT_COUNT_AS_SHAPE
:
case
inference
::
BatchInput
::
BATCH_ITEM_SHAPE
:
case
inference
::
BatchInput
::
BATCH_ITEM_SHAPE_FLATTEN
:
{
if
(
batch_io
.
source_input_size
()
!=
1
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"batch input kind '"
+
inference
::
BatchInput
::
Kind_Name
(
batch_io
.
kind
())
+
"' expects 1 source input, got "
+
std
::
to_string
(
batch_io
.
source_input_size
()));
}
break
;
}
default:
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unknown batch input kind '"
+
inference
::
BatchInput
::
Kind_Name
(
batch_io
.
kind
())
+
"'"
);
}
if
((
batch_io
.
data_type
()
!=
inference
::
DataType
::
TYPE_INT32
)
&&
(
batch_io
.
data_type
()
!=
inference
::
DataType
::
TYPE_FP32
))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"batch input data type must be TYPE_INT32 or TYPE_FP32"
);
}
for
(
const
auto
&
source_name
:
batch_io
.
source_input
())
{
if
(
input_names
.
find
(
source_name
)
==
input_names
.
end
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unknown source input name '"
+
source_name
+
"'"
);
}
}
}
for
(
const
auto
&
batch_io
:
config
.
batch_output
())
{
switch
(
batch_io
.
kind
())
{
case
inference
::
BatchOutput
::
BATCH_SCATTER_WITH_INPUT_SHAPE
:
{
if
(
batch_io
.
source_input_size
()
!=
1
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"batch output kind '"
+
inference
::
BatchOutput
::
Kind_Name
(
batch_io
.
kind
())
+
"' expects 1 source input, got "
+
std
::
to_string
(
batch_io
.
source_input_size
()));
}
break
;
}
default:
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unknown batch output kind '"
+
inference
::
BatchOutput
::
Kind_Name
(
batch_io
.
kind
())
+
"'"
);
}
for
(
const
auto
&
source_name
:
batch_io
.
source_input
())
{
if
(
input_names
.
find
(
source_name
)
==
input_names
.
end
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unknown source input name '"
+
source_name
+
"'"
);
}
}
std
::
set
<
std
::
string
>
target_names
;
for
(
const
auto
&
target_name
:
batch_io
.
target_name
())
{
if
(
output_names
.
find
(
target_name
)
==
output_names
.
end
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unknown target output name '"
+
target_name
+
"'"
);
}
if
(
target_names
.
emplace
(
target_name
).
second
==
false
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"target output name '"
+
target_name
+
"' can only be specified once"
);
}
}
}
return
Status
::
Success
;
}
Status
ValidateModelConfig
(
const
inference
::
ModelConfig
&
config
,
const
double
min_compute_capability
)
{
if
(
config
.
name
().
empty
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"model configuration must specify 'name'"
);
}
if
(
config
.
backend
().
empty
())
{
// Expect backend is not empty unless it is ensemble platform.
#ifdef TRITON_ENABLE_ENSEMBLE
if
(
config
.
platform
()
!=
kEnsemblePlatform
)
#endif // TRITON_ENABLE_ENSEMBLE
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unexpected platform type '"
+
config
.
platform
()
+
"' for "
+
config
.
name
());
}
#ifdef TRITON_ENABLE_ENSEMBLE
else
if
(
config
.
platform
()
==
kEnsemblePlatform
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Ensemble model '"
+
config
.
name
()
+
"' must have platform type '"
+
config
.
platform
()
+
"' and empty backend type"
);
}
#endif // TRITON_ENABLE_ENSEMBLE
if
(
config
.
platform
().
empty
()
&&
config
.
backend
().
empty
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"must specify 'platform' or 'backend' for '"
+
config
.
name
()
+
"'"
);
}
// Ensure both platform and backend are referring to known backend,
// or both referring to unknown backend for user-provided backend.
if
(
GetBackendTypeFromPlatform
(
config
.
platform
())
!=
GetBackendType
(
config
.
backend
()))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unexpected 'platform' and 'backend' pair, got:"
+
config
.
platform
()
+
", "
+
config
.
backend
());
}
if
(
config
.
max_batch_size
()
<
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"'max_batch_size' must be non-negative value for "
+
config
.
name
());
}
if
(
!
config
.
has_version_policy
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"must specify 'version policy' for "
+
config
.
name
());
}
// If dynamic batching is specified make sure the preferred batch
// sizes are positive and don't exceed maximum batch size.
if
(
config
.
has_dynamic_batching
())
{
for
(
const
auto
size
:
config
.
dynamic_batching
().
preferred_batch_size
())
{
if
(
size
<=
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"dynamic batching preferred size must be positive for "
+
config
.
name
());
}
if
(
size
>
config
.
max_batch_size
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"dynamic batching preferred size must be <= max batch size for "
+
config
.
name
());
}
}
// Priority queue is specified
const
auto
priority_levels
=
config
.
dynamic_batching
().
priority_levels
();
if
(
priority_levels
!=
0
)
{
if
((
config
.
dynamic_batching
().
default_priority_level
()
==
0
)
||
(
config
.
dynamic_batching
().
default_priority_level
()
>
priority_levels
))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"default priority level must be in range [1, "
+
std
::
to_string
(
priority_levels
)
+
"] for "
+
config
.
name
());
}
for
(
const
auto
&
queue_policy
:
config
.
dynamic_batching
().
priority_queue_policy
())
{
if
((
queue_policy
.
first
==
0
)
||
(
queue_policy
.
first
>
priority_levels
))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"priority queue policy must have priority level in range [1, "
+
std
::
to_string
(
priority_levels
)
+
"] for "
+
config
.
name
());
}
}
}
// preserve ordering option will conflict with priorities and delay policy
if
(
config
.
dynamic_batching
().
preserve_ordering
())
{
if
(
priority_levels
>
1
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Only one priority level is allowed when 'preserve_ordering' is "
"true for "
+
config
.
name
());
}
const
auto
&
default_policy
=
config
.
dynamic_batching
().
default_queue_policy
();
if
((
default_policy
.
default_timeout_microseconds
()
!=
0
)
&&
(
default_policy
.
timeout_action
()
==
inference
::
ModelQueuePolicy
::
DELAY
))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Queue policy can not have DELAY as timeout action when "
"'preserve_ordering' is true for "
+
config
.
name
());
}
// Also need to check policy in 'priority_queue_policy'
// for single priority case
for
(
const
auto
&
policy
:
config
.
dynamic_batching
().
priority_queue_policy
())
{
if
((
policy
.
second
.
default_timeout_microseconds
()
!=
0
)
&&
(
policy
.
second
.
timeout_action
()
==
inference
::
ModelQueuePolicy
::
DELAY
))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Queue policy can not have DELAY as timeout action when "
"'preserve_ordering' is true for "
+
config
.
name
());
}
}
}
}
// If sequence batching is specified make sure the control is
// specified correctly.
if
(
config
.
has_sequence_batching
())
{
const
auto
&
batcher
=
config
.
sequence_batching
();
// Check boolean controls...
std
::
string
tensor_name
;
RETURN_IF_ERROR
(
GetBooleanSequenceControlProperties
(
batcher
,
config
.
name
(),
inference
::
ModelSequenceBatching
::
Control
::
CONTROL_SEQUENCE_START
,
false
/* required */
,
&
tensor_name
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
));
RETURN_IF_ERROR
(
GetBooleanSequenceControlProperties
(
batcher
,
config
.
name
(),
inference
::
ModelSequenceBatching
::
Control
::
CONTROL_SEQUENCE_END
,
false
/* required */
,
&
tensor_name
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
));
RETURN_IF_ERROR
(
GetBooleanSequenceControlProperties
(
batcher
,
config
.
name
(),
inference
::
ModelSequenceBatching
::
Control
::
CONTROL_SEQUENCE_READY
,
false
/* required */
,
&
tensor_name
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
));
// Check CORRID control and make sure it is one of the allowed types.
inference
::
DataType
tensor_datatype
;
RETURN_IF_ERROR
(
GetTypedSequenceControlProperties
(
batcher
,
config
.
name
(),
inference
::
ModelSequenceBatching
::
Control
::
CONTROL_SEQUENCE_CORRID
,
false
/* required */
,
&
tensor_name
,
&
tensor_datatype
));
if
(
!
tensor_name
.
empty
())
{
if
((
tensor_datatype
!=
inference
::
DataType
::
TYPE_UINT64
)
&&
(
tensor_datatype
!=
inference
::
DataType
::
TYPE_INT64
)
&&
(
tensor_datatype
!=
inference
::
DataType
::
TYPE_UINT32
)
&&
(
tensor_datatype
!=
inference
::
DataType
::
TYPE_INT32
)
&&
(
tensor_datatype
!=
inference
::
DataType
::
TYPE_STRING
))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unexpected data type for control "
+
inference
::
ModelSequenceBatching_Control_Kind_Name
(
inference
::
ModelSequenceBatching
::
Control
::
CONTROL_SEQUENCE_CORRID
)
+
" for "
+
config
.
name
()
+
". Allowed data types are TYPE_UINT64, TYPE_INT64, "
"TYPE_UINT32, "
"TYPE_INT32 and TYPE_STRING"
);
}
}
// If oldest-first strategy is enabled make sure the preferred
// batch sizes are positive and don't exceed maximum batch size.
if
(
config
.
sequence_batching
().
has_oldest
())
{
for
(
const
auto
size
:
config
.
sequence_batching
().
oldest
().
preferred_batch_size
())
{
if
(
size
<=
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching preferred batch size must be positive for "
+
config
.
name
());
}
if
(
size
>
config
.
max_batch_size
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching preferred batch size must be <= max batch "
"size for "
+
config
.
name
());
}
}
}
// If direct strategy is enabled make sure the minimum slot utilization is
// in range (0.0, 1.0]
if
(
config
.
sequence_batching
().
has_direct
())
{
if
((
config
.
sequence_batching
().
direct
().
minimum_slot_utilization
()
<
0.0
)
||
(
config
.
sequence_batching
().
direct
().
minimum_slot_utilization
()
>
1.0
))
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"sequence batching minimum slot utilization must be in range "
"(0.0, 1.0] for "
+
config
.
name
());
}
}
}
// If ensemble scheduling is specified, validate it. Otherwise,
// must validate platform and instance_group
if
(
config
.
has_ensemble_scheduling
())
{
#ifdef TRITON_ENABLE_ENSEMBLE
RETURN_IF_ERROR
(
ValidateEnsembleSchedulingConfig
(
config
));
#else
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"ensemble scheduling not supported"
);
#endif // TRITON_ENABLE_ENSEMBLE
}
#ifdef TRITON_ENABLE_ENSEMBLE
else
if
(
config
.
platform
()
==
kEnsemblePlatform
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"ensemble scheduling must be set for ensemble "
+
config
.
name
()
+
" whose platform is "
+
kEnsemblePlatform
);
}
#endif // TRITON_ENABLE_ENSEMBLE
// FIXME: DLIS-3916 - Response Cache does not yet support decoupled models
if
(
config
.
model_transaction_policy
().
decoupled
()
&&
config
.
response_cache
().
enable
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Response Cache does not currently support model "
+
config
.
name
()
+
" with 'decoupled' transaction policy. Please disable the response"
" cache."
);
}
return
Status
::
Success
;
}
Status
ValidateInstanceGroup
(
const
inference
::
ModelConfig
&
config
,
const
double
min_compute_capability
)
{
// Instance group setting doesn't apply to ensemble
if
(
config
.
has_ensemble_scheduling
())
{
return
Status
::
Success
;
}
if
(
config
.
instance_group
().
size
()
==
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"must specify one or more 'instance group's for "
+
config
.
name
());
}
// Make sure KIND_GPU instance group specifies at least one GPU and
// doesn't specify a non-existent GPU. Make sure non-KIND_GPU does
// not specify any GPUs.
#ifdef TRITON_ENABLE_GPU
std
::
set
<
int
>
supported_gpus
;
Status
status
=
GetSupportedGPUs
(
&
supported_gpus
,
min_compute_capability
);
if
(
!
status
.
IsOk
())
{
return
status
;
}
#endif // TRITON_ENABLE_GPU
for
(
const
auto
&
group
:
config
.
instance_group
())
{
if
(
group
.
kind
()
==
inference
::
ModelInstanceGroup
::
KIND_MODEL
)
{
if
(
group
.
gpus
().
size
()
>
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"instance group "
+
group
.
name
()
+
" of model "
+
config
.
name
()
+
" has kind KIND_MODEL but specifies one or more GPUs"
);
}
}
else
if
(
group
.
kind
()
==
inference
::
ModelInstanceGroup
::
KIND_GPU
)
{
#if !defined(TRITON_ENABLE_GPU) && !defined(TRITON_ENABLE_MALI_GPU)
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"instance group "
+
group
.
name
()
+
" of model "
+
config
.
name
()
+
" has kind KIND_GPU but server does not support GPUs"
);
#elif defined(TRITON_ENABLE_GPU)
if
(
group
.
gpus
().
size
()
==
0
)
{
if
(
supported_gpus
.
size
()
==
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"instance group "
+
group
.
name
()
+
" of model "
+
config
.
name
()
+
" has kind KIND_GPU but no GPUs are available"
);
}
else
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"instance group "
+
group
.
name
()
+
" of model "
+
config
.
name
()
+
" has kind KIND_GPU but specifies no GPUs"
);
}
}
for
(
const
int32_t
gid
:
group
.
gpus
())
{
if
(
supported_gpus
.
find
(
gid
)
==
supported_gpus
.
end
())
{
std
::
string
supported_gpus_str
;
for
(
const
auto
&
cc
:
supported_gpus
)
{
if
(
!
supported_gpus_str
.
empty
())
{
supported_gpus_str
+=
", "
;
}
supported_gpus_str
+=
std
::
to_string
(
cc
);
}
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"instance group "
+
group
.
name
()
+
" of model "
+
config
.
name
()
+
" specifies invalid or unsupported gpu id "
+
std
::
to_string
(
gid
)
+
". GPUs with at least the minimum required CUDA compute "
"compatibility of "
+
std
::
to_string
(
min_compute_capability
)
+
" are: "
+
supported_gpus_str
);
}
}
#endif // ! TRITON_ENABLE_GPU && ! TRITON_ENABLE_MALI_GPU
}
else
if
(
group
.
kind
()
==
inference
::
ModelInstanceGroup
::
KIND_CPU
)
{
if
(
group
.
gpus
().
size
()
>
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"instance group "
+
group
.
name
()
+
" of model "
+
config
.
name
()
+
" has kind KIND_CPU but specifies one or more GPUs"
);
}
}
else
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"instance group "
+
group
.
name
()
+
" of model "
+
config
.
name
()
+
" has unexpected kind KIND_AUTO"
);
}
if
((
config
.
platform
()
!=
kTensorRTPlanPlatform
)
&&
!
group
.
profile
().
empty
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"instance group "
+
group
.
name
()
+
" of model "
+
config
.
name
()
+
" and platform "
+
config
.
platform
()
+
"specifies profile field which is only supported for "
"TensorRT models"
);
}
else
if
(
!
group
.
profile
().
empty
())
{
for
(
const
auto
&
profile
:
group
.
profile
())
{
int
profile_index
;
RETURN_IF_ERROR
(
GetProfileIndex
(
profile
,
&
profile_index
));
if
(
profile_index
<
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"instance group "
+
group
.
name
()
+
" of model "
+
config
.
name
()
+
" and platform "
+
config
.
platform
()
+
" specifies invalid profile "
+
profile
+
". The field should contain the string representation of a "
"non-negative integer."
);
}
}
}
}
return
Status
::
Success
;
}
Status
ValidateModelInput(
    const inference::ModelInput& io, int32_t max_batch_size,
    const std::string& platform)
{
  RETURN_IF_ERROR(ValidateIOShape(io, max_batch_size, "model input "));

  if (((io.format() == inference::ModelInput::FORMAT_NHWC) ||
       (io.format() == inference::ModelInput::FORMAT_NCHW)) &&
      (io.dims_size() != 3)) {
    return Status(
        Status::Code::INVALID_ARG, "model input NHWC/NCHW require 3 dims");
  }

  if ((platform != kTensorRTPlanPlatform) && io.is_shape_tensor()) {
    return Status(
        Status::Code::INVALID_ARG,
        "shape tensors are only supported for TensorRT platform");
  }

  return Status::Success;
}
Status
CheckAllowedModelInput
(
const
inference
::
ModelInput
&
io
,
const
std
::
set
<
std
::
string
>&
allowed
)
{
if
(
allowed
.
find
(
io
.
name
())
==
allowed
.
end
())
{
std
::
string
astr
;
for
(
const
auto
&
a
:
allowed
)
{
if
(
!
astr
.
empty
())
{
astr
.
append
(
", "
);
}
astr
.
append
(
a
);
}
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unexpected inference input '"
+
io
.
name
()
+
"', allowed inputs are: "
+
astr
);
}
return
Status
::
Success
;
}
Status
ValidateModelOutput(
    const inference::ModelOutput& io, int32_t max_batch_size,
    const std::string& platform)
{
  RETURN_IF_ERROR(ValidateIOShape(io, max_batch_size, "model output "));

  if ((platform != kTensorRTPlanPlatform) && io.is_shape_tensor()) {
    return Status(
        Status::Code::INVALID_ARG,
        "shape tensors are only supported for TensorRT platform");
  }

  return Status::Success;
}
Status
CheckAllowedModelOutput
(
const
inference
::
ModelOutput
&
io
,
const
std
::
set
<
std
::
string
>&
allowed
)
{
if
(
allowed
.
find
(
io
.
name
())
==
allowed
.
end
())
{
std
::
string
astr
;
for
(
const
auto
&
a
:
allowed
)
{
if
(
!
astr
.
empty
())
{
astr
.
append
(
", "
);
}
astr
.
append
(
a
);
}
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unexpected inference output '"
+
io
.
name
()
+
"', allowed outputs are: "
+
astr
);
}
return
Status
::
Success
;
}
Status
ParseBoolParameter(
    const std::string& key, std::string value, bool* parsed_value)
{
  std::transform(
      value.begin(), value.end(), value.begin(),
      [](unsigned char c) { return std::tolower(c); });
  if ((value == "true") || (value == "1")) {
    *parsed_value = true;
  } else if ((value == "false") || (value == "0")) {
    *parsed_value = false;
  } else {
    return Status(
        Status::Code::INVALID_ARG,
        "failed to convert " + key + " '" + value + "' to boolean value");
  }

  return Status::Success;
}
Status
ParseLongLongParameter(
    const std::string& key, const std::string& value, int64_t* parsed_value)
{
  try {
    *parsed_value = std::stoll(value);
  }
  catch (const std::invalid_argument& ia) {
    return Status(
        Status::Code::INVALID_ARG,
        "failed to convert " + key + " '" + value + "' to integral number");
  }

  return Status::Success;
}
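// --- Editor's illustration (not part of this commit) --------------------------
// Brief usage sketch for the two parameter parsers above; the parameter names
// are made up for illustration.
static Status
ExampleParseParameters()
{
  bool enable_cache = false;
  RETURN_IF_ERROR(ParseBoolParameter("enable_cache", "True", &enable_cache));
  // enable_cache == true ("True" is lower-cased before comparison).

  int64_t timeout_us = 0;
  RETURN_IF_ERROR(ParseLongLongParameter("timeout_us", "5000", &timeout_us));
  // timeout_us == 5000; a non-numeric value such as "fast" returns INVALID_ARG.
  // Note: std::stoll can also throw std::out_of_range for huge values, which
  // ParseLongLongParameter does not catch and would therefore propagate.
  return Status::Success;
}
// -------------------------------------------------------------------------------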
Status
GetProfileIndex(const std::string& profile_name, int* profile_index)
{
  if (profile_name.empty()) {
    return Status(
        Status::Code::INVALID_ARG, "profile name must not be empty");
  }

  try {
    *profile_index = stoi(profile_name);
  }
  catch (const std::invalid_argument& ia) {
    return Status(
        Status::Code::INVALID_ARG,
        "unable to parse '" + profile_name + "': " + ia.what());
  }

  return Status::Success;
}
namespace
{
Status
CollectInt64Fields
(
google
::
protobuf
::
Message
*
message
,
const
std
::
string
&
prefix
,
std
::
set
<
std
::
string
>*
int64_fields
)
{
const
google
::
protobuf
::
Descriptor
*
desc
=
message
->
GetDescriptor
();
const
google
::
protobuf
::
Reflection
*
refl
=
message
->
GetReflection
();
for
(
int
i
=
0
;
i
<
desc
->
field_count
();
++
i
)
{
const
google
::
protobuf
::
FieldDescriptor
*
field
=
desc
->
field
(
i
);
const
std
::
string
fullname
=
prefix
+
"::"
+
field
->
name
();
switch
(
field
->
type
())
{
case
google
::
protobuf
::
FieldDescriptor
::
TYPE_MESSAGE
:
{
if
(
field
->
is_repeated
())
{
int
rsize
=
refl
->
FieldSize
(
*
message
,
field
);
if
(
rsize
==
0
)
{
refl
->
AddMessage
(
message
,
field
);
}
rsize
=
refl
->
FieldSize
(
*
message
,
field
);
for
(
int
r
=
0
;
r
<
rsize
;
++
r
)
{
RETURN_IF_ERROR
(
CollectInt64Fields
(
refl
->
MutableRepeatedMessage
(
message
,
field
,
r
),
fullname
,
int64_fields
));
}
}
else
{
RETURN_IF_ERROR
(
CollectInt64Fields
(
refl
->
MutableMessage
(
message
,
field
),
fullname
,
int64_fields
));
}
}
break
;
case
google
::
protobuf
::
FieldDescriptor
::
TYPE_INT64
:
case
google
::
protobuf
::
FieldDescriptor
::
TYPE_UINT64
:
case
google
::
protobuf
::
FieldDescriptor
::
TYPE_SINT64
:
case
google
::
protobuf
::
FieldDescriptor
::
TYPE_FIXED64
:
case
google
::
protobuf
::
FieldDescriptor
::
TYPE_SFIXED64
:
int64_fields
->
insert
(
fullname
);
break
;
default:
break
;
}
}
return
Status
::
Success
;
}
Status
ValidateModelConfigInt64()
{
  // Must initialize a dummy ModelConfig so that all fields are
  // visited.
  inference::ModelConfig config;

  std::set<std::string> int64_fields;
  RETURN_IF_ERROR(CollectInt64Fields(&config, "ModelConfig", &int64_fields));

  LOG_VERBOSE(1) << "ModelConfig 64-bit fields:";
  for (const auto& f : int64_fields) {
    LOG_VERBOSE(1) << "\t" << f;
  }

  // We expect to find exactly the following fields. If we get an
  // error from this code ModelConfig has added or removed a 64-bit
  // field and we need to adjust here and in ModelConfigToJson below.
  std::set<std::string> expected{
      "ModelConfig::input::dims",
      "ModelConfig::input::reshape::shape",
      "ModelConfig::output::dims",
      "ModelConfig::output::reshape::shape",
      "ModelConfig::version_policy::specific::versions",
      "ModelConfig::dynamic_batching::max_queue_delay_microseconds",
      "ModelConfig::dynamic_batching::default_queue_policy::default_timeout_"
      "microseconds",
      "ModelConfig::dynamic_batching::priority_queue_policy::value::default_"
      "timeout_microseconds",
      "ModelConfig::sequence_batching::direct::max_queue_delay_microseconds",
      "ModelConfig::sequence_batching::state::dims",
      "ModelConfig::sequence_batching::state::initial_state::dims",
      "ModelConfig::sequence_batching::oldest::max_queue_delay_microseconds",
      "ModelConfig::sequence_batching::max_sequence_idle_microseconds",
      "ModelConfig::ensemble_scheduling::step::model_version",
      "ModelConfig::model_warmup::inputs::value::dims",
      "ModelConfig::optimization::cuda::graph_spec::input::value::dim",
      "ModelConfig::optimization::cuda::graph_spec::graph_lower_bound::input::"
      "value::dim",
      "ModelConfig::instance_group::secondary_devices::device_id"};

  if (int64_fields != expected) {
    return Status(
        Status::Code::INTERNAL, "ModelConfig 64-bit field needs update");
  }

  return Status::Success;
}
Status
FixInt(
    triton::common::TritonJson::Value& document,
    triton::common::TritonJson::Value& io, const std::string& name)
{
  triton::common::TritonJson::Value str_value;
  if (!io.Find(name.c_str(), &str_value)) {
    return Status::Success;
  }

  std::string str;
  RETURN_IF_ERROR(str_value.AsString(&str));

  int64_t d;
  try {
    d = std::atoll(str.c_str());
  }
  catch (...) {
    return Status(
        Status::Code::INTERNAL,
        (std::string("unable to convert '") + str + "' to integer"));
  }

  str_value.SetInt(d);

  return Status::Success;
}
Status
FixIntArray(
    triton::common::TritonJson::Value& document,
    triton::common::TritonJson::Value& io, const std::string& name)
{
  triton::common::TritonJson::Value fixed_shape_array(
      document, triton::common::TritonJson::ValueType::ARRAY);

  if (!io.Find(name.c_str())) {
    return Status::Success;
  }

  triton::common::TritonJson::Value shape_array;
  RETURN_IF_ERROR(io.MemberAsArray(name.c_str(), &shape_array));
  for (size_t i = 0; i < shape_array.ArraySize(); ++i) {
    std::string str;
    RETURN_IF_ERROR(shape_array.IndexAsString(i, &str));

    int64_t d;
    try {
      d = std::atoll(str.c_str());
    }
    catch (...) {
      return Status(
          Status::Code::INTERNAL,
          (std::string("unable to convert '") + str + "' to integer"));
    }

    RETURN_IF_ERROR(fixed_shape_array.AppendInt(d));
  }

  shape_array.Swap(fixed_shape_array);
  fixed_shape_array.Release();

  return Status::Success;
}
Status
FixObjectArray(
    triton::common::TritonJson::Value& document,
    triton::common::TritonJson::Value& arr, const std::string& name)
{
  for (size_t i = 0; i < arr.ArraySize(); ++i) {
    triton::common::TritonJson::Value obj;
    RETURN_IF_ERROR(arr.IndexAsObject(i, &obj));
    RETURN_IF_ERROR(FixInt(document, obj, name));
  }

  return Status::Success;
}

}  // namespace
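// Illustrative sketch (not part of the original source) of what the fix-up
// helpers above do. Protobuf's JSON printer emits 64-bit integers as strings,
// so for a hypothetical input object
//
//   { "dims": ["8", "16"], "max_queue_delay_microseconds": "100" }
//
// FixIntArray(document, io, "dims") and
// FixInt(document, io, "max_queue_delay_microseconds") rewrite it in place to
//
//   { "dims": [8, 16], "max_queue_delay_microseconds": 100 }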
Status
ModelConfigToJson(
    const inference::ModelConfig& config, const uint32_t config_version,
    std::string* json_str)
{
  // Currently only support 'config_version' 1, which is the json
  // representation of the ModelConfig protobuf with the int64 fields
  // fixed to be actual numbers instead of the string madness done by
  // protobuf.
  if (config_version != 1) {
    return Status(
        Status::Code::INVALID_ARG,
        std::string("model configuration version ") +
            std::to_string(config_version) +
            " not supported, supported versions are: 1");
  }

  // Config will have 0 byte size if all fields have their default value,
  // in other words the config is empty.
  if (config.ByteSizeLong() == 0) {
    json_str->clear();
    return Status::Success;
  }

  std::string config_json_str;
  ::google::protobuf::util::JsonPrintOptions options;
  options.preserve_proto_field_names = true;
  options.always_print_primitive_fields = true;
  ::google::protobuf::util::MessageToJsonString(
      config, &config_json_str, options);

  // We need to verify that every 64-bit field in the
  // ModelConfig protobuf is being handled. We hardcode the known
  // fields and check just once to make sure everything has been
  // handled. We could have this check in a separately compiled CI
  // test but it is convenient to keep it here close to the code below
  // that actually fixes the 64-bit fields.
  {
    static std::once_flag fonce;
    Status status = Status::Success;
    std::call_once(fonce, [&status] { status = ValidateModelConfigInt64(); });
    RETURN_IF_ERROR(status);
  }

  // In the json produced by protobuf, int64 and uint64 values are
  // represented as strings. Protobuf doesn't provide an option to
  // disable this (sigh) so we need to fix it up here as we want the
  // json representation of the config to be reasonable json...
  triton::common::TritonJson::Value config_json;
  config_json.Parse(config_json_str);

  // Fix input::dims, input::reshape::shape, output::dims,
  // output::reshape::shape
  for (std::string name : {"input", "output"}) {
    triton::common::TritonJson::Value ios;
    RETURN_IF_ERROR(config_json.MemberAsArray(name.c_str(), &ios));
    for (size_t i = 0; i < ios.ArraySize(); ++i) {
      triton::common::TritonJson::Value io;
      RETURN_IF_ERROR(ios.IndexAsObject(i, &io));
      RETURN_IF_ERROR(FixIntArray(config_json, io, "dims"));

      triton::common::TritonJson::Value reshape;
      if (io.Find("reshape", &reshape)) {
        RETURN_IF_ERROR(FixIntArray(config_json, reshape, "shape"));
      }
    }
  }

  // Fix version_policy::specific::versions
  {
    triton::common::TritonJson::Value vp;
    if (config_json.Find("version_policy", &vp)) {
      triton::common::TritonJson::Value specific;
      if (vp.Find("specific", &specific)) {
        RETURN_IF_ERROR(FixIntArray(config_json, specific, "versions"));
      }
    }
  }

  // Fix dynamic_batching::max_queue_delay_microseconds,
  // dynamic_batching::default_queue_policy::default_timeout_microseconds,
  // dynamic_batching::priority_queue_policy::value::default_timeout_microseconds
  {
    triton::common::TritonJson::Value db;
    if (config_json.Find("dynamic_batching", &db)) {
      RETURN_IF_ERROR(
          FixInt(config_json, db, "max_queue_delay_microseconds"));
      triton::common::TritonJson::Value dqp;
      if (db.Find("default_queue_policy", &dqp)) {
        RETURN_IF_ERROR(
            FixInt(config_json, dqp, "default_timeout_microseconds"));
      }
      triton::common::TritonJson::Value pqp;
      if (db.Find("priority_queue_policy", &pqp)) {
        // Iterate over each member in 'pqp' and fix...
        std::vector<std::string> members;
        RETURN_IF_ERROR(pqp.Members(&members));
        for (const auto& m : members) {
          triton::common::TritonJson::Value el;
          RETURN_IF_ERROR(pqp.MemberAsObject(m.c_str(), &el));
          RETURN_IF_ERROR(
              FixInt(config_json, el, "default_timeout_microseconds"));
        }
      }
    }
  }

  // Fix sequence_batching::oldest::max_queue_delay_microseconds,
  // sequence_batching::direct::max_queue_delay_microseconds,
  // sequence_batching::max_sequence_idle_microseconds
  {
    triton::common::TritonJson::Value sb;
    if (config_json.Find("sequence_batching", &sb)) {
      RETURN_IF_ERROR(
          FixInt(config_json, sb, "max_sequence_idle_microseconds"));
      triton::common::TritonJson::Value oldest;
      if (sb.Find("oldest", &oldest)) {
        RETURN_IF_ERROR(
            FixInt(config_json, oldest, "max_queue_delay_microseconds"));
      }
      triton::common::TritonJson::Value direct;
      if (sb.Find("direct", &direct)) {
        RETURN_IF_ERROR(
            FixInt(config_json, direct, "max_queue_delay_microseconds"));
      }
      triton::common::TritonJson::Value states;
      if (sb.Find("state", &states)) {
        for (size_t i = 0; i < states.ArraySize(); ++i) {
          triton::common::TritonJson::Value state;
          RETURN_IF_ERROR(states.IndexAsObject(i, &state));
          RETURN_IF_ERROR(FixIntArray(config_json, state, "dims"));

          triton::common::TritonJson::Value initial_state;
          if (sb.Find("initial_state", &initial_state)) {
            RETURN_IF_ERROR(FixIntArray(config_json, initial_state, "dims"));
          }
        }
      }
    }
  }

  // Fix ensemble_scheduling::step::model_version.
  {
    triton::common::TritonJson::Value ens;
    if (config_json.Find("ensemble_scheduling", &ens)) {
      triton::common::TritonJson::Value step;
      if (ens.Find("step", &step)) {
        RETURN_IF_ERROR(FixObjectArray(config_json, step, "model_version"));
      }
    }
  }

  // Fix model_warmup::inputs::value::dims.
  {
    triton::common::TritonJson::Value warmups;
    if (config_json.Find("model_warmup", &warmups)) {
      for (size_t i = 0; i < warmups.ArraySize(); ++i) {
        triton::common::TritonJson::Value warmup;
        RETURN_IF_ERROR(warmups.IndexAsObject(i, &warmup));
        triton::common::TritonJson::Value inputs;
        if (warmup.Find("inputs", &inputs)) {
          std::vector<std::string> members;
          RETURN_IF_ERROR(inputs.Members(&members));
          for (const auto& m : members) {
            triton::common::TritonJson::Value input;
            RETURN_IF_ERROR(inputs.MemberAsObject(m.c_str(), &input));
            RETURN_IF_ERROR(FixIntArray(config_json, input, "dims"));
          }
        }
      }
    }
  }

  // Convert the fixed json back to a string...
  triton::common::TritonJson::WriteBuffer buffer;
  RETURN_IF_ERROR(config_json.Write(&buffer));
  *json_str = std::move(buffer.MutableContents());

  return Status::Success;
}
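// Illustrative usage sketch (not part of the original source). Assuming an
// already-populated inference::ModelConfig 'config', the JSON form with
// numeric (not string) 64-bit fields could be obtained like this:
//
//   std::string json;
//   Status status = ModelConfigToJson(config, 1 /* config_version */, &json);
//   if (!status.IsOk()) {
//     LOG_ERROR << status.AsString();
//   }
//   // 'json' is left empty if 'config' only holds default values.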
Status
JsonToModelConfig(
    const std::string& json_config, const uint32_t config_version,
    inference::ModelConfig* protobuf_config)
{
  // Currently only support 'config_version' 1, where the json
  // representation of the ModelConfig protobuf matches the representation in
  // ModelConfigToJson().
  if (config_version != 1) {
    return Status(
        Status::Code::INVALID_ARG,
        std::string("model configuration version ") +
            std::to_string(config_version) +
            " not supported, supported versions are: 1");
  }

  ::google::protobuf::util::JsonParseOptions options;
  options.case_insensitive_enum_parsing = true;
  options.ignore_unknown_fields = false;
  auto err = ::google::protobuf::util::JsonStringToMessage(
      json_config, protobuf_config, options);
  if (!err.ok()) {
    return Status(Status::Code::INVALID_ARG, std::string(err.message()));
  }

  return Status::Success;
}
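// Illustrative round-trip sketch (not part of the original source): the JSON
// produced by ModelConfigToJson() can be parsed back with JsonToModelConfig(),
// assuming the same 'config_version' of 1:
//
//   inference::ModelConfig parsed;
//   Status status = JsonToModelConfig(json, 1 /* config_version */, &parsed);
//   // Unknown JSON fields are rejected because 'ignore_unknown_fields' is
//   // false above.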
BackendType
GetBackendTypeFromPlatform(const std::string& platform_name)
{
  if ((platform_name == kTensorFlowGraphDefPlatform) ||
      (platform_name == kTensorFlowSavedModelPlatform)) {
    return BackendType::BACKEND_TYPE_TENSORFLOW;
  }

  if (platform_name == kTensorRTPlanPlatform) {
    return BackendType::BACKEND_TYPE_TENSORRT;
  }

  if (platform_name == kOnnxRuntimeOnnxPlatform) {
    return BackendType::BACKEND_TYPE_ONNXRUNTIME;
  }

  if (platform_name == kPyTorchLibTorchPlatform) {
    return BackendType::BACKEND_TYPE_PYTORCH;
  }

  return BackendType::BACKEND_TYPE_UNKNOWN;
}
/// Get the BackendType value for a backend name.
/// \param backend_name The backend name.
/// \return The BackendType or BackendType::UNKNOWN if the backend string
/// is not recognized.
BackendType
GetBackendType(const std::string& backend_name)
{
  if (backend_name == kTensorFlowBackend) {
    return BackendType::BACKEND_TYPE_TENSORFLOW;
  }

  if (backend_name == kTensorRTBackend) {
    return BackendType::BACKEND_TYPE_TENSORRT;
  }

  if (backend_name == kOnnxRuntimeBackend) {
    return BackendType::BACKEND_TYPE_ONNXRUNTIME;
  }

  if (backend_name == kPyTorchBackend) {
    return BackendType::BACKEND_TYPE_PYTORCH;
  }

  return BackendType::BACKEND_TYPE_UNKNOWN;
}
TRITONSERVER_DataType
DataTypeToTriton(const inference::DataType dtype)
{
  switch (dtype) {
    case inference::DataType::TYPE_BOOL:
      return TRITONSERVER_TYPE_BOOL;
    case inference::DataType::TYPE_UINT8:
      return TRITONSERVER_TYPE_UINT8;
    case inference::DataType::TYPE_UINT16:
      return TRITONSERVER_TYPE_UINT16;
    case inference::DataType::TYPE_UINT32:
      return TRITONSERVER_TYPE_UINT32;
    case inference::DataType::TYPE_UINT64:
      return TRITONSERVER_TYPE_UINT64;
    case inference::DataType::TYPE_INT8:
      return TRITONSERVER_TYPE_INT8;
    case inference::DataType::TYPE_INT16:
      return TRITONSERVER_TYPE_INT16;
    case inference::DataType::TYPE_INT32:
      return TRITONSERVER_TYPE_INT32;
    case inference::DataType::TYPE_INT64:
      return TRITONSERVER_TYPE_INT64;
    case inference::DataType::TYPE_FP16:
      return TRITONSERVER_TYPE_FP16;
    case inference::DataType::TYPE_FP32:
      return TRITONSERVER_TYPE_FP32;
    case inference::DataType::TYPE_FP64:
      return TRITONSERVER_TYPE_FP64;
    case inference::DataType::TYPE_STRING:
      return TRITONSERVER_TYPE_BYTES;
    case inference::DataType::TYPE_BF16:
      return TRITONSERVER_TYPE_BF16;
    default:
      break;
  }

  return TRITONSERVER_TYPE_INVALID;
}
inference::DataType
TritonToDataType(const TRITONSERVER_DataType dtype)
{
  switch (dtype) {
    case TRITONSERVER_TYPE_BOOL:
      return inference::DataType::TYPE_BOOL;
    case TRITONSERVER_TYPE_UINT8:
      return inference::DataType::TYPE_UINT8;
    case TRITONSERVER_TYPE_UINT16:
      return inference::DataType::TYPE_UINT16;
    case TRITONSERVER_TYPE_UINT32:
      return inference::DataType::TYPE_UINT32;
    case TRITONSERVER_TYPE_UINT64:
      return inference::DataType::TYPE_UINT64;
    case TRITONSERVER_TYPE_INT8:
      return inference::DataType::TYPE_INT8;
    case TRITONSERVER_TYPE_INT16:
      return inference::DataType::TYPE_INT16;
    case TRITONSERVER_TYPE_INT32:
      return inference::DataType::TYPE_INT32;
    case TRITONSERVER_TYPE_INT64:
      return inference::DataType::TYPE_INT64;
    case TRITONSERVER_TYPE_FP16:
      return inference::DataType::TYPE_FP16;
    case TRITONSERVER_TYPE_FP32:
      return inference::DataType::TYPE_FP32;
    case TRITONSERVER_TYPE_FP64:
      return inference::DataType::TYPE_FP64;
    case TRITONSERVER_TYPE_BYTES:
      return inference::DataType::TYPE_STRING;
    case TRITONSERVER_TYPE_BF16:
      return inference::DataType::TYPE_BF16;
    default:
      break;
  }

  return inference::DataType::TYPE_INVALID;
}
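// Illustrative sketch (not part of the original source): the two conversion
// helpers above are inverses for the supported types, e.g.
//
//   TRITONSERVER_DataType t = DataTypeToTriton(inference::DataType::TYPE_FP16);
//   // t == TRITONSERVER_TYPE_FP16
//   inference::DataType d = TritonToDataType(t);
//   // d == inference::DataType::TYPE_FP16
//
// Note the one asymmetric name: TYPE_STRING maps to TRITONSERVER_TYPE_BYTES.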
}}
// namespace triton::core
3rdparty/core-r22.12/src/model_config_utils.h
0 → 100644
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "model_config.pb.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
#include "filesystem.h"
namespace triton { namespace core {

/// Enumeration for the different backend types.
enum BackendType {
  BACKEND_TYPE_UNKNOWN = 0,
  BACKEND_TYPE_TENSORRT = 1,
  BACKEND_TYPE_TENSORFLOW = 2,
  BACKEND_TYPE_ONNXRUNTIME = 3,
  BACKEND_TYPE_PYTORCH = 4
};
/// Get version of a model from the path containing the model
/// definition file.
/// \param path The path to the model definition file.
/// \param version Returns the version.
/// \return The error status.
Status GetModelVersionFromPath(const std::string& path, int64_t* version);
/// Get the tensor name, false value, and true value for a boolean
/// sequence batcher control kind. If 'required' is true then must
/// find a tensor for the control. If 'required' is false, return
/// 'tensor_name' as empty-string if the control is not mapped to any
/// tensor.
Status GetBooleanSequenceControlProperties(
    const inference::ModelSequenceBatching& batcher,
    const std::string& model_name,
    const inference::ModelSequenceBatching::Control::Kind control_kind,
    const bool required, std::string* tensor_name,
    inference::DataType* tensor_datatype, float* fp32_false_value,
    float* fp32_true_value, int32_t* int32_false_value,
    int32_t* int32_true_value, bool* bool_false_value,
    bool* bool_true_value);

/// Get the tensor name and datatype for a non-boolean sequence
/// batcher control kind. If 'required' is true then must find a
/// tensor for the control. If 'required' is false, return
/// 'tensor_name' as empty-string if the control is not mapped to any
/// tensor. 'tensor_datatype' returns the required datatype for the
/// control.
Status GetTypedSequenceControlProperties(
    const inference::ModelSequenceBatching& batcher,
    const std::string& model_name,
    const inference::ModelSequenceBatching::Control::Kind control_kind,
    const bool required, std::string* tensor_name,
    inference::DataType* tensor_datatype);
/// Read a ModelConfig and normalize it as expected by model backends.
/// \param path The full-path to the directory containing the
/// model configuration.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \param config Returns the normalized model configuration.
/// \return The error status.
Status GetNormalizedModelConfig(
    const std::string& model_name, const std::string& path,
    const double min_compute_capability, inference::ModelConfig* config);

/// Auto-complete backend related fields (platform, backend and default model
/// filename) if not set; note that only Triton recognized backends will be
/// checked.
/// \param model_name The name of the model.
/// \param model_path The full-path to the directory containing the
/// model configuration.
/// \param config Returns the auto-completed model configuration.
/// \return The error status.
Status AutoCompleteBackendFields(
    const std::string& model_name, const std::string& model_path,
    inference::ModelConfig* config);

/// Detects and adds missing fields in the model configuration.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \param config The model configuration.
/// \return The error status.
Status NormalizeModelConfig(
    const double min_compute_capability, inference::ModelConfig* config);

/// [FIXME] better formalize config normalization / validation
/// Detects and adds missing fields in instance group setting.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \param config The model configuration.
/// \return The error status.
Status NormalizeInstanceGroup(
    const double min_compute_capability,
    const std::vector<inference::ModelInstanceGroup>& preferred_groups,
    inference::ModelConfig* config);

/// [FIXME] Remove once a more permanent solution is implemented (DLIS-4211)
/// Localize EXECUTION_ENV_PATH in python backend.
/// \param model_path The full-path to the directory containing the model
/// configuration, before localization.
/// \param config The model configuration.
/// \param localized_model_dir The localized model directory.
/// \return The error status.
Status LocalizePythonBackendExecutionEnvironmentPath(
    const std::string& model_path, inference::ModelConfig* config,
    std::shared_ptr<LocalizedPath>* localized_model_dir);

/// Auto-complete the instance count based on instance kind and backend name.
/// \param group The instance group to set the count for.
/// \param backend The backend name to check against.
/// \return The error status.
Status SetDefaultInstanceCount(
    inference::ModelInstanceGroup* group, const std::string& backend);
/// Validate that a model is specified correctly, except for model inputs
/// and outputs. ValidateModelIOConfig() should be called to
/// validate model inputs and outputs.
/// \param config The model configuration to validate.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateModelConfig(
    const inference::ModelConfig& config,
    const double min_compute_capability);

/// [FIXME] better formalize config normalization / validation
/// Validate instance group setting.
/// \param config The model configuration to validate.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateInstanceGroup(
    const inference::ModelConfig& config,
    const double min_compute_capability);

/// Validate that a model's inputs and outputs are specified correctly.
/// \param config The model configuration to validate.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateModelIOConfig(const inference::ModelConfig& config);

/// Validate that an input is specified correctly in a model
/// configuration.
/// \param io The model input.
/// \param max_batch_size The max batch size specified in model configuration.
/// \param platform The platform name.
/// \return The error status. A non-OK status indicates the input
/// is not valid.
Status ValidateModelInput(
    const inference::ModelInput& io, int32_t max_batch_size,
    const std::string& platform);

/// Validate that an input matches one of the allowed input names.
/// \param io The model input.
/// \param allowed The set of allowed input names.
/// \return The error status. A non-OK status indicates the input
/// is not valid.
Status CheckAllowedModelInput(
    const inference::ModelInput& io, const std::set<std::string>& allowed);

/// Validate that an output is specified correctly in a model
/// configuration.
/// \param io The model output.
/// \param max_batch_size The max batch size specified in model configuration.
/// \param platform The platform name.
/// \return The error status. A non-OK status indicates the output
/// is not valid.
Status ValidateModelOutput(
    const inference::ModelOutput& io, int32_t max_batch_size,
    const std::string& platform);

/// Validate that an output matches one of the allowed output names.
/// \param io The model output.
/// \param allowed The set of allowed output names.
/// \return The error status. A non-OK status indicates the output
/// is not valid.
Status CheckAllowedModelOutput(
    const inference::ModelOutput& io, const std::set<std::string>& allowed);

/// Validate that a model's batch inputs and batch outputs are specified
/// correctly.
/// \param config The model configuration to validate.
/// \return The error status. A non-OK status indicates the batch inputs or
/// batch outputs are not valid.
Status ValidateBatchIO(const inference::ModelConfig& config);
/// Parse the 'value' of the parameter 'key' into a boolean value.
/// \param key The name of the parameter.
/// \param value The value of the parameter in string.
/// \param parsed_value Returns the boolean value of the parameter.
/// \return The error status. A non-OK status indicates failure on parsing the
/// value.
Status ParseBoolParameter(
    const std::string& key, std::string value, bool* parsed_value);

/// Parse the 'value' of the parameter 'key' into a long long integer value.
/// \param key The name of the parameter.
/// \param value The value of the parameter in string.
/// \param parsed_value Returns the numerical value of the parameter.
/// \return The error status. A non-OK status indicates failure on parsing the
/// value.
Status ParseLongLongParameter(
    const std::string& key, const std::string& value, int64_t* parsed_value);

/// Obtain the 'profile_index' of the 'profile_name'.
/// \param profile_name The name of the profile.
/// \param profile_index Returns the index of the profile.
/// \return The error status. A non-OK status indicates failure on getting the
/// value.
Status GetProfileIndex(const std::string& profile_name, int* profile_index);

/// Convert a model configuration protobuf to the equivalent json.
/// \param config The protobuf model configuration.
/// \param config_version The model configuration will be returned in
/// a format matching this version. If the configuration cannot be
/// represented in the requested version's format then an error will
/// be returned.
/// \param json_str Returns the equivalent JSON.
/// \return The error status.
Status ModelConfigToJson(
    const inference::ModelConfig& config, const uint32_t config_version,
    std::string* json_str);

/// Convert a model configuration JSON to the equivalent protobuf.
/// \param json_config The JSON model configuration.
/// \param config_version The model configuration will be returned in
/// a format matching this version. If the configuration cannot be
/// represented in the requested version's format then an error will
/// be returned.
/// \param protobuf_config Returns the equivalent protobuf.
/// \return The error status.
Status JsonToModelConfig(
    const std::string& json_config, const uint32_t config_version,
    inference::ModelConfig* protobuf_config);

/// Get the BackendType value for a platform name.
/// \param platform_name The platform name.
/// \return The BackendType or BackendType::UNKNOWN if the platform string
/// is not recognized.
BackendType GetBackendTypeFromPlatform(const std::string& platform_name);

/// Get the BackendType value for a backend name.
/// \param backend_name The backend name.
/// \return The BackendType or BackendType::UNKNOWN if the backend string
/// is not recognized.
BackendType GetBackendType(const std::string& backend_name);

/// Get the Triton server data type corresponding to a data type.
/// \param dtype The data type.
/// \return The Triton server data type.
TRITONSERVER_DataType DataTypeToTriton(const inference::DataType dtype);

/// Get the data type corresponding to a Triton server data type.
/// \param dtype The Triton server data type.
/// \return The data type.
inference::DataType TritonToDataType(const TRITONSERVER_DataType dtype);

}}  // namespace triton::core
3rdparty/core-r22.12/src/model_lifecycle.cc
0 → 100644
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_lifecycle.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "repo_agent.h"
#include "triton/common/logging.h"
#include "triton/common/thread_pool.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace triton { namespace core {

const std::string&
ModelReadyStateString(ModelReadyState state)
{
  switch (state) {
    case ModelReadyState::UNKNOWN: {
      static std::string m("UNKNOWN");
      return m;
    }
    case ModelReadyState::READY: {
      static std::string m("READY");
      return m;
    }
    case ModelReadyState::UNAVAILABLE: {
      static std::string m("UNAVAILABLE");
      return m;
    }
    case ModelReadyState::LOADING: {
      static std::string m("LOADING");
      return m;
    }
    case ModelReadyState::UNLOADING: {
      static std::string m("UNLOADING");
      return m;
    }
  }

  static std::string m("<unknown>");
  return m;
}
namespace {

Status
VersionsToLoad(
    const std::string model_path, const std::string& name,
    const inference::ModelConfig& model_config, std::set<int64_t>* versions)
{
  versions->clear();

  // Get the integral number of each version directory
  std::set<std::string> subdirs;
  RETURN_IF_ERROR(GetDirectorySubdirs(model_path, &subdirs));
  std::set<int64_t, std::greater<int64_t>> existing_versions;
  for (const auto& subdir : subdirs) {
    if (subdir == kWarmupDataFolder || subdir == kInitialStateFolder) {
      continue;
    }
    if ((subdir.length() > 1) && (subdir.front() == '0')) {
      LOG_WARNING << "ignore version directory '" << subdir
                  << "' which contains leading zeros in its directory name";
      continue;
    }
    try {
      int64_t version = std::stoll(subdir);
      existing_versions.insert(version);
    }
    catch (const std::invalid_argument& ia) {
      LOG_WARNING << "ignore version directory '" << subdir
                  << "' which fails to convert to integral number";
    }
  }

  if (model_config.version_policy().has_specific()) {
    for (const auto& v : model_config.version_policy().specific().versions()) {
      // Only load the specific versions that are present in the model
      // directory
      bool version_not_exist = existing_versions.insert(v).second;
      if (!version_not_exist) {
        versions->emplace(v);
      } else {
        LOG_ERROR << "version " << v << " is specified for model '" << name
                  << "', but the version directory is not present";
      }
    }
  } else {
    if (model_config.version_policy().has_latest()) {
      // std::set is sorted with std::greater
      for (const auto& v : existing_versions) {
        if (versions->size() >=
            model_config.version_policy().latest().num_versions()) {
          break;
        }
        versions->emplace(v);
      }
    } else {
      // all
      versions->insert(existing_versions.begin(), existing_versions.end());
    }
  }

  return Status::Success;
}
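// Illustrative sketch (not part of the original source) of how
// VersionsToLoad() resolves a version policy. Assuming a model directory with
// subdirectories "1", "2" and "3":
//
//   version_policy { latest { num_versions: 2 } }    -> loads {3, 2}
//   version_policy { specific { versions: [1, 5] } } -> loads {1} and logs an
//                                                        error for missing 5
//   (no policy)                                      -> loads {1, 2, 3}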
// Use smart pointer with custom deleter so that model state will be updated
// to UNAVAILABLE once all smart pointer copies are out of scope
struct ModelDeleter {
  ModelDeleter(std::function<void()> OnDestroyModel)
      : OnDestroyModel_(std::move(OnDestroyModel))
  {
  }

  void operator()(Model* model)
  {
    // The actual model object must be destroyed in a different
    // thread. This thread could have a callstack that includes the
    // model itself because this deleter could be triggered by
    // a request release or response send in the model. Following
    // delete will lead to the model destructor which may wait on this
    // same thread... so deadlock if we don't use a different thread
    // here.
    std::function<void()> destroy_fn = OnDestroyModel_;
    std::thread dthd([model, destroy_fn]() {
      delete model;
      destroy_fn();
    });

    dthd.detach();
  }

  // Used to inform the ModelLifeCycle that the model handle is destroyed
  std::function<void()> OnDestroyModel_;
};

}  // namespace
Status
ModelLifeCycle::Create(
    InferenceServer* server, const ModelLifeCycleOptions& options,
    std::unique_ptr<ModelLifeCycle>* life_cycle)
{
  std::unique_ptr<ModelLifeCycle> local_life_cycle(
      new ModelLifeCycle(server, options));

  *life_cycle = std::move(local_life_cycle);
  return Status::Success;
}
const ModelStateMap
ModelLifeCycle::LiveModelStates(bool strict_readiness)
{
  LOG_VERBOSE(2) << "LiveModelStates()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  ModelStateMap live_model_states;
  for (auto& model_version : map_) {
    bool live = false;
    VersionStateMap version_map;

    for (auto& version_model : model_version.second) {
      std::lock_guard<std::mutex> lock(version_model.second->mtx_);
      if (strict_readiness &&
          version_model.second->state_ != ModelReadyState::READY) {
        continue;
      }

      // At least one version is live (ready / loading / unloading)
      if ((version_model.second->state_ != ModelReadyState::UNKNOWN) &&
          (version_model.second->state_ != ModelReadyState::UNAVAILABLE)) {
        live = true;
        version_map[version_model.first] = std::make_pair(
            version_model.second->state_,
            version_model.second->state_reason_);
      }
    }

    if (live) {
      live_model_states[model_version.first] = std::move(version_map);
    }
  }
  return live_model_states;
}
Status
ModelLifeCycle::StopAllModels()
{
  LOG_VERBOSE(2) << "StopAllModels()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  for (auto& model_version : map_) {
    for (auto& version_model : model_version.second) {
      if (version_model.second != nullptr) {
        std::lock_guard<std::mutex> lock(version_model.second->mtx_);
        if (version_model.second->model_ != nullptr) {
          version_model.second->model_->Stop();
        }
      }
    }
  }
  return Status::Success;
}
const std::set<std::tuple<std::string, int64_t, size_t>>
ModelLifeCycle::InflightStatus()
{
  LOG_VERBOSE(2) << "InflightStatus()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  std::set<std::tuple<std::string, int64_t, size_t>> inflight_status;
  for (auto& model_version : map_) {
    for (auto& version_model : model_version.second) {
      if (version_model.second != nullptr) {
        std::lock_guard<std::mutex> lock(version_model.second->mtx_);
        if (version_model.second->model_ != nullptr) {
          const auto cnt =
              version_model.second->model_->InflightInferenceCount();
          if (cnt != 0) {
            inflight_status.emplace(
                model_version.first, version_model.first, cnt);
          }
        }
      }
    }
  }
  return inflight_status;
}
const ModelStateMap
ModelLifeCycle::ModelStates()
{
  LOG_VERBOSE(2) << "ModelStates()";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  ModelStateMap model_states;
  for (auto& model_version : map_) {
    VersionStateMap version_map;

    for (auto& version_model : model_version.second) {
      std::lock_guard<std::mutex> lock(version_model.second->mtx_);
      version_map[version_model.first] = std::make_pair(
          version_model.second->state_, version_model.second->state_reason_);
    }

    model_states[model_version.first] = std::move(version_map);
  }

  return model_states;
}
const VersionStateMap
ModelLifeCycle::VersionStates(const std::string& model_name)
{
  LOG_VERBOSE(2) << "VersionStates() '" << model_name << "'";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  VersionStateMap version_map;
  auto mit = map_.find(model_name);
  if (mit != map_.end()) {
    for (auto& version_model : mit->second) {
      std::lock_guard<std::mutex> lock(version_model.second->mtx_);
      version_map[version_model.first] = std::make_pair(
          version_model.second->state_, version_model.second->state_reason_);
    }
  }

  return version_map;
}
Status
ModelLifeCycle::ModelState(
    const std::string& model_name, const int64_t model_version,
    ModelReadyState* state)
{
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto mit = map_.find(model_name);
  if (mit != map_.end()) {
    auto vit = mit->second.find(model_version);
    if (vit != mit->second.end()) {
      std::lock_guard<std::mutex> lock(vit->second->mtx_);
      *state = vit->second->state_;
      return Status::Success;
    }
  }

  return Status(
      Status::Code::NOT_FOUND, "model '" + model_name + "', version " +
                                   std::to_string(model_version) +
                                   " is not found");
}
Status
ModelLifeCycle::GetModel(
    const std::string& model_name, const int64_t version,
    std::shared_ptr<Model>* model)
{
  LOG_VERBOSE(2) << "GetModel() '" << model_name << "' version " << version;
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto mit = map_.find(model_name);
  if (mit == map_.end()) {
    return Status(
        Status::Code::NOT_FOUND, "'" + model_name + "' is not found");
  }

  auto vit = mit->second.find(version);
  if (vit == mit->second.end()) {
    if (version != -1) {
      return Status(
          Status::Code::NOT_FOUND, "'" + model_name + "' version " +
                                       std::to_string(version) +
                                       " is not found");
    }

    // The case where the request is asking for the latest version
    int64_t latest = -1;
    for (auto& version_model : mit->second) {
      if (version_model.first > latest) {
        std::lock_guard<std::mutex> lock(version_model.second->mtx_);
        if (version_model.second->state_ == ModelReadyState::READY) {
          latest = version_model.first;
          // Tedious, but have to set handle for any "latest" version
          // at the moment to avoid edge case like the following:
          // "versions : 1 3 2", version 3 is latest but is requested
          // to be unloaded when the iterator is examining version 2,
          // then 'model' will ensure version 3 is still valid
          *model = version_model.second->model_;
        }
      }
    }

    if (latest == -1) {
      return Status(
          Status::Code::NOT_FOUND,
          "'" + model_name + "' has no available versions");
    }
  } else {
    std::lock_guard<std::mutex> lock(vit->second->mtx_);
    if (vit->second->state_ == ModelReadyState::READY) {
      *model = vit->second->model_;
    } else {
      return Status(
          Status::Code::UNAVAILABLE, "'" + model_name + "' version " +
                                         std::to_string(version) +
                                         " is not at ready state");
    }
  }

  return Status::Success;
}
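// Illustrative usage sketch (not part of the original source). Assuming a
// ModelLifeCycle instance 'life_cycle' and a hypothetical model name,
// requesting version -1 returns the latest READY version:
//
//   std::shared_ptr<Model> model;
//   Status status = life_cycle->GetModel("my_model", -1, &model);
//   if (status.IsOk()) {
//     // The shared_ptr keeps the version usable even if it is asked to
//     // unload concurrently; the custom deleter above handles the rest.
//   }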
Status
ModelLifeCycle::AsyncUnload(const std::string& model_name)
{
  LOG_VERBOSE(2) << "AsyncUnload() '" << model_name << "'";
  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto it = map_.find(model_name);
  if (it == map_.end()) {
    return Status(
        Status::Code::INVALID_ARG,
        "Model to be unloaded has not been served");
  }

  // Get the existing agent models and notify the unload action
  const uint64_t now_ns =
      std::chrono::duration_cast<std::chrono::nanoseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();
  for (auto& version : it->second) {
    auto& model_info = version.second;
    std::lock_guard<std::mutex> lock(model_info->mtx_);
    model_info->last_update_ns_ = now_ns;
    // Unload the serving model. For a model that is in LOADING state,
    // the updated timestamp will be recognized as a newer update on the
    // model info and the load will be aborted.
    if (model_info->state_ == ModelReadyState::READY) {
      if (model_info->agent_model_list_ != nullptr) {
        // Only log the error because the model should be unloaded regardless
        auto status = model_info->agent_model_list_->InvokeAgentModels(
            TRITONREPOAGENT_ACTION_UNLOAD);
        if (!status.IsOk()) {
          LOG_ERROR << "Agent model returns error on "
                       "TRITONREPOAGENT_ACTION_UNLOAD: "
                    << status.AsString();
        }
      }
      // unload
      model_info->Release();
    }
  }

  return Status::Success;
}
Status
ModelLifeCycle::AsyncLoad(
    const std::string& model_name, const std::string& model_path,
    const inference::ModelConfig& model_config, const bool is_config_provided,
    const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
    std::function<void(Status)>&& OnComplete)
{
  LOG_VERBOSE(2) << "AsyncLoad() '" << model_name << "'";

  std::lock_guard<std::mutex> map_lock(map_mtx_);
  auto it = map_.find(model_name);
  if (it == map_.end()) {
    it = map_.emplace(std::make_pair(model_name, VersionMap())).first;
  }

  std::set<int64_t> versions;
  RETURN_IF_ERROR(
      VersionsToLoad(model_path, model_name, model_config, &versions));
  if (versions.empty()) {
    return Status(
        Status::Code::INVALID_ARG,
        "at least one version must be available under the version policy of "
        "model '" +
            model_name + "'");
  }

  const uint64_t now_ns =
      std::chrono::duration_cast<std::chrono::nanoseconds>(
          std::chrono::steady_clock::now().time_since_epoch())
          .count();

  std::shared_ptr<LoadTracker> load_tracker(
      new LoadTracker(versions.size(), now_ns));
  for (const auto& version : versions) {
    std::unique_ptr<ModelInfo> linfo(
        new ModelInfo(model_path, model_config, now_ns));
    ModelInfo* model_info = linfo.get();

    LOG_INFO << "loading: " << model_name << ":" << version;
    model_info->state_ = ModelReadyState::LOADING;
    model_info->state_reason_.clear();
    model_info->agent_model_list_ = agent_model_list;

    auto res = it->second.emplace(
        std::make_pair(version, std::unique_ptr<ModelInfo>()));
    if (res.second) {
      res.first->second = std::move(linfo);
    } else {
      // There is already a record of this model version. Check if the version
      // model is being served, if so, the re-load of the version
      // should be performed in background to avoid version downtime.
      // Otherwise, swap and monitor state for the newly loading model.
      auto& serving_model = res.first->second;
      std::lock_guard<std::mutex> lock(serving_model->mtx_);
      if (serving_model->state_ == ModelReadyState::READY) {
        background_models_[(uintptr_t)model_info] = std::move(linfo);
      } else {
        // swap the monitoring model info
        serving_model.swap(linfo);
        // further check the state, put to 'background_models_' to keep
        // the object valid if the model is LOADING / UNLOADING, because
        // the model info will be accessed by a different thread once the
        // operation is completed
        if ((linfo->state_ == ModelReadyState::LOADING) ||
            (linfo->state_ == ModelReadyState::UNLOADING)) {
          ModelInfo* key = linfo.get();
          background_models_[(uintptr_t)key] = std::move(linfo);
        }
      }
    }

    // Load model asynchronously via thread pool
    load_pool_->Enqueue([this, model_name, version, model_info, OnComplete,
                         load_tracker, is_config_provided]() {
      CreateModel(model_name, version, model_info, is_config_provided);
      OnLoadComplete(
          model_name, version, model_info, OnComplete, load_tracker);
    });
  }

  return Status::Success;
}
void
ModelLifeCycle::CreateModel(
    const std::string& model_name, const int64_t version,
    ModelInfo* model_info, const bool is_config_provided)
{
  LOG_VERBOSE(2) << "CreateModel() '" << model_name << "' version "
                 << version;

  const auto& model_config = model_info->model_config_;

  // Create model
  Status status;
  std::unique_ptr<Model> is;

  // If 'backend' is specified in the config then use the new triton
  // backend.
  if (!model_config.backend().empty()) {
    std::unique_ptr<TritonModel> model;
    status = TritonModel::Create(
        server_, model_info->model_path_, cmdline_config_map_,
        host_policy_map_, model_name, version, model_config,
        is_config_provided, &model);
    is.reset(model.release());
  } else {
#ifdef TRITON_ENABLE_ENSEMBLE
    if (model_info->is_ensemble_) {
      status = EnsembleModel::Create(
          server_, model_info->model_path_, version, model_config,
          is_config_provided, min_compute_capability_, &is);
      // Complete the label provider with label information from the involved
      // models. Must be done here because the involved models may not be
      // obtainable from the server, as this may happen during server
      // initialization.
      if (status.IsOk()) {
        std::set<std::string> no_label_outputs;
        const auto& label_provider = is->GetLabelProvider();
        for (const auto& output : model_config.output()) {
          if (label_provider->GetLabel(output.name(), 0).empty()) {
            no_label_outputs.emplace(output.name());
          }
        }
        for (const auto& element :
             model_config.ensemble_scheduling().step()) {
          for (const auto& pair : element.output_map()) {
            // Found a model that produces one of the missing outputs
            if (no_label_outputs.find(pair.second) !=
                no_label_outputs.end()) {
              std::shared_ptr<Model> model;
              // Safe to obtain the model because the ensemble can't be loaded
              // until the involved models are ready
              GetModel(element.model_name(), element.model_version(), &model);
              label_provider->AddLabels(
                  pair.second,
                  model->GetLabelProvider()->GetLabels(pair.first));
            }
          }
        }
      }
    } else
#endif  // TRITON_ENABLE_ENSEMBLE
    {
      status = Status(
          Status::Code::INVALID_ARG,
          "unknown platform '" + model_config.platform() + "'");
    }
  }

  std::lock_guard<std::mutex> lock(model_info->mtx_);
  if (status.IsOk()) {
    // [FIXME] better way to manage agent model lifecycle
    // Let the deleter also hold a shared pointer copy of the agent model
    // list, because the reference in ModelInfo can be cleared before the
    // Model object is destroyed, and we want the agent model to be valid for
    // receiving the UNLOAD_COMPLETE signal (see ~TritonRepoAgentModelList for
    // detail)
    auto agent_model_list = model_info->agent_model_list_;
    model_info->model_.reset(
        is.release(),
        ModelDeleter([this, model_name, version, model_info,
                      agent_model_list]() mutable {
          LOG_VERBOSE(2) << "OnDestroy callback() '" << model_name
                         << "' version " << version;
          LOG_INFO << "successfully unloaded '" << model_name << "' version "
                   << version;
          // Update model state as it is fully unloaded
          {
            std::lock_guard<std::mutex> lock(model_info->mtx_);
            model_info->state_ = ModelReadyState::UNAVAILABLE;
            model_info->state_reason_ = "unloaded";
          }
          // Check if the model info is in background, if so, remove it from
          // the map
          std::lock_guard<std::mutex> lk(this->map_mtx_);
          auto it = this->background_models_.find((uintptr_t)model_info);
          if (it != this->background_models_.end()) {
            this->background_models_.erase(it);
          }
        }));
  } else {
    LOG_ERROR << "failed to load '" << model_name << "' version " << version
              << ": " << status.AsString();
    model_info->state_ = ModelReadyState::UNAVAILABLE;
    model_info->state_reason_ = status.AsString();
  }
}
void
ModelLifeCycle::OnLoadComplete(
    const std::string& model_name, const int64_t version,
    ModelInfo* model_info, std::function<void(Status)> OnComplete,
    std::shared_ptr<LoadTracker> load_tracker)
{
  std::lock_guard<std::mutex> tracker_lock(load_tracker->mtx_);
  ++load_tracker->completed_version_cnt_;
  load_tracker->load_set_[version] = model_info;
  // A version will not be marked ready until all versions are
  // ready. This simplifies the unloading when one version fails to load, as
  // all other versions won't have inflight requests.
  if (model_info->state_ != ModelReadyState::LOADING) {
    load_tracker->load_failed_ = true;
    load_tracker->reason_ +=
        ("version " + std::to_string(version) + " is at " +
         ModelReadyStateString(model_info->state_) +
         " state: " + model_info->state_reason_ + ";");
  }
  // Check if all versions are completed and finish the load
  if (load_tracker->completed_version_cnt_ ==
      load_tracker->affected_version_cnt_) {
    // hold 'map_mtx_' as there will be changes to the model info map
    std::lock_guard<std::mutex> map_lock(map_mtx_);
    auto it = map_.find(model_name);
    // Check if the load is the latest foreground action on the model
    for (const auto& version_info : it->second) {
      if (version_info.second->last_update_ns_ >
          load_tracker->last_update_ns_) {
        load_tracker->load_failed_ = true;
        load_tracker->reason_ =
            "Newer operation has been applied to the model lifecycle, current "
            "load operation is out-dated.";
        break;
      }
    }
    if (load_tracker->load_failed_) {
      // Move the agent list out of ModelInfo as it needs to be invoked
      // after all ModelInfos are reset
      std::shared_ptr<TritonRepoAgentModelList> lagent_list;
      if (model_info->agent_model_list_) {
        lagent_list = std::move(model_info->agent_model_list_);
      }
      // If any of the versions fails to load, abort the load and unload
      // all newly loaded versions
      for (auto& loaded : load_tracker->load_set_) {
        // Unload directly, the object is being managed either in foreground
        // or background
        std::lock_guard<std::mutex> lock(loaded.second->mtx_);
        if (loaded.second->model_ != nullptr) {
          loaded.second->Release();
        }
      }
      if (lagent_list) {
        auto status =
            lagent_list->InvokeAgentModels(TRITONREPOAGENT_ACTION_LOAD_FAIL);
        if (!status.IsOk()) {
          LOG_ERROR << "Agent model returns error on "
                       "TRITONREPOAGENT_ACTION_LOAD_FAIL: "
                    << status.AsString();
        }
      }
    } else {
      // Unload any previously loaded versions that are still available
      for (auto& version_info : it->second) {
        auto& mi = version_info.second;
        std::lock_guard<std::mutex> info_lk(mi->mtx_);
        if ((mi->state_ == ModelReadyState::READY) &&
            (mi->last_update_ns_ < load_tracker->last_update_ns_)) {
          if (mi->agent_model_list_ != nullptr) {
            auto status = mi->agent_model_list_->InvokeAgentModels(
                TRITONREPOAGENT_ACTION_UNLOAD);
            if (!status.IsOk()) {
              LOG_ERROR << "Agent model returns error on "
                           "TRITONREPOAGENT_ACTION_UNLOAD: "
                        << status.AsString();
            }
          }
          mi->Release();
        }
      }
      // Mark the current versions ready and track their info in foreground
      for (auto& loaded : load_tracker->load_set_) {
        std::lock_guard<std::mutex> curr_info_lk(loaded.second->mtx_);
        loaded.second->state_ = ModelReadyState::READY;
        model_info->state_reason_.clear();
        LOG_INFO << "successfully loaded '" << model_name << "' version "
                 << version;
        auto bit = background_models_.find((uintptr_t)loaded.second);
        // Check if the version model is loaded in background, if so,
        // replace and unload the current serving version
        if (bit != background_models_.end()) {
          auto vit = it->second.find(loaded.first);
          // Need to lock the previous model info in case the model is
          // loading / unloading; this ensures the model state is consistent
          // even when the load / unload is completed.
          std::lock_guard<std::mutex> prev_info_lk(vit->second->mtx_);
          // swap the previous info into a local unique pointer
          auto linfo = std::move(bit->second);
          vit->second.swap(linfo);
          background_models_.erase(bit);
          // if the previous info is under change, put it into
          // 'background_models_'
          if ((linfo->state_ == ModelReadyState::LOADING) ||
              (linfo->state_ == ModelReadyState::UNLOADING)) {
            ModelInfo* key = linfo.get();
            background_models_[(uintptr_t)key] = std::move(linfo);
          }
        }
      }
      if (model_info->agent_model_list_) {
        auto status = model_info->agent_model_list_->InvokeAgentModels(
            TRITONREPOAGENT_ACTION_LOAD_COMPLETE);
        if (!status.IsOk()) {
          LOG_ERROR << "Agent model returns error on "
                       "TRITONREPOAGENT_ACTION_LOAD_COMPLETE: "
                    << status.AsString();
        }
      }
    }
    if (OnComplete != nullptr) {
      OnComplete(
          load_tracker->load_failed_
              ? Status(Status::Code::INVALID_ARG, load_tracker->reason_)
              : Status::Success);
    }
  }
}

}}  // namespace triton::core
3rdparty/core-r22.12/src/model_lifecycle.h
0 → 100644
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <functional>
#include <map>
#include <mutex>
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "repo_agent.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "triton/common/thread_pool.h"
namespace triton { namespace core {

struct ModelLifeCycleOptions {
  explicit ModelLifeCycleOptions(
      const double min_compute_capability,
      const triton::common::BackendCmdlineConfigMap&
          backend_cmdline_config_map,
      const triton::common::HostPolicyCmdlineConfigMap& host_policy_map,
      const unsigned int model_load_thread_count)
      : min_compute_capability_(min_compute_capability),
        backend_cmdline_config_map_(backend_cmdline_config_map),
        host_policy_map_(host_policy_map),
        model_load_thread_count_(model_load_thread_count)
  {
  }
  // The minimum supported CUDA compute capability.
  const double min_compute_capability_;
  // The backend configuration settings specified on the command-line
  const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map_;
  // The host policy setting used when loading models.
  const triton::common::HostPolicyCmdlineConfigMap& host_policy_map_;
  // Number of threads to use for concurrently loading models
  const unsigned int model_load_thread_count_;
};
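// Illustrative construction sketch (not part of the original source). The
// options struct only holds references, so the maps must outlive it:
//
//   triton::common::BackendCmdlineConfigMap backend_config;
//   triton::common::HostPolicyCmdlineConfigMap host_policy;
//   ModelLifeCycleOptions options(
//       6.0 /* min_compute_capability */, backend_config, host_policy,
//       4 /* model_load_thread_count */);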
/// Readiness status for models.
enum class ModelReadyState {
  // The model is in an unknown state. The model is not available for
  // inferencing.
  UNKNOWN,

  // The model is ready and available for inferencing.
  READY,

  // The model is unavailable, indicating that the model failed to
  // load or has been implicitly or explicitly unloaded. The model is
  // not available for inferencing.
  UNAVAILABLE,

  // The model is being loaded by the inference server. The model is
  // not available for inferencing.
  LOADING,

  // The model is being unloaded by the inference server. The model is
  // not available for inferencing.
  UNLOADING
};

/// Get the string representation for a ModelReadyState
const std::string& ModelReadyStateString(ModelReadyState state);

using VersionStateMap =
    std::map<int64_t, std::pair<ModelReadyState, std::string>>;
using ModelStateMap = std::map<std::string, VersionStateMap>;
// Helper class to manage the lifecycle of a list of associated agent models
class TritonRepoAgentModelList {
 public:
  TritonRepoAgentModelList()
      : last_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE){};
  ~TritonRepoAgentModelList()
  {
    // Using destructor to finish the unload lifecycle without
    // explicitly managing the last step in ModelLifecycle.
    if (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD) {
      InvokeAgentModels(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE);
    }
  }
  Status AddAgentModel(std::unique_ptr<TritonRepoAgentModel>&& agent_model)
  {
    agent_models_.emplace_back(std::move(agent_model));
    return Status::Success;
  }

  size_t Size() { return agent_models_.size(); }

  TritonRepoAgentModel* Back() { return agent_models_.back().get(); }

  Status InvokeAgentModels(const TRITONREPOAGENT_ActionType action_type)
  {
    // Special handling for the current model lifecycle implementation,
    // the repo agent may be asked to perform UNLOAD action multiple times,
    // and the requests after the first should be ignored.
    const bool first_unload =
        (action_type == TRITONREPOAGENT_ACTION_UNLOAD) &&
        (last_action_type_ != TRITONREPOAGENT_ACTION_UNLOAD);
    if (!first_unload) {
      return Status::Success;
    }
    last_action_type_ = action_type;
    switch (action_type) {
      case TRITONREPOAGENT_ACTION_LOAD:
      case TRITONREPOAGENT_ACTION_UNLOAD: {
        for (size_t idx = 0; idx < agent_models_.size(); ++idx) {
          RETURN_IF_ERROR(agent_models_[idx]->InvokeAgent(action_type));
        }
        break;
      }
      case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
      case TRITONREPOAGENT_ACTION_LOAD_FAIL:
      case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: {
        // reverse order
        for (size_t one_pass_idx = agent_models_.size(); one_pass_idx > 0;
             --one_pass_idx) {
          RETURN_IF_ERROR(
              agent_models_[one_pass_idx - 1]->InvokeAgent(action_type));
        }
        break;
      }
    }
    return Status::Success;
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModelList);
  std::vector<std::unique_ptr<TritonRepoAgentModel>> agent_models_;
  TRITONREPOAGENT_ActionType last_action_type_;
};
class
InferenceServer
;
class
Model
;
class
ModelLifeCycle
{
public:
static
Status
Create
(
InferenceServer
*
server
,
const
ModelLifeCycleOptions
&
options
,
std
::
unique_ptr
<
ModelLifeCycle
>*
life_cycle
);
~
ModelLifeCycle
()
{
// Explicitly clean up thread pool first to clean up any pending callbacks
// that may modify model lifecycle members
load_pool_
.
reset
();
map_
.
clear
();
}
// Start loading model with specified versions asynchronously.
// All versions that are being served will be unloaded only after
// the load is finished sucessfully.
Status
AsyncLoad
(
const
std
::
string
&
model_name
,
const
std
::
string
&
model_path
,
const
inference
::
ModelConfig
&
model_config
,
const
bool
is_config_provided
,
const
std
::
shared_ptr
<
TritonRepoAgentModelList
>&
agent_model_list
,
std
::
function
<
void
(
Status
)
>&&
OnComplete
);
// Unload model asynchronously.
Status
AsyncUnload
(
const
std
::
string
&
model_name
);
// Get specified version of the model. Latest ready version will
// be retrieved if 'version' is -1. Return error if the version specified is
// not found or it is not ready.
Status
GetModel
(
const
std
::
string
&
model_name
,
const
int64_t
version
,
std
::
shared_ptr
<
Model
>*
model
);
// Get the ModelStateMap representation of the live models. A model is
// live if at least one of the versions is not unknown nor unavailable.
// If 'strict_readiness' is true, a model is only live if
// at least one of the versions is ready.
const
ModelStateMap
LiveModelStates
(
bool
strict_readiness
=
false
);
// Get the ModelStateMap representation of the models.
const
ModelStateMap
ModelStates
();
// Get the VersionStateMap representation of the specified model.
const
VersionStateMap
VersionStates
(
const
std
::
string
&
model_name
);
// Get the state of a specific model version.
Status
ModelState
(
const
std
::
string
&
model_name
,
const
int64_t
model_version
,
ModelReadyState
*
state
);
// Instruct the model to stop accepting new inference requests.
Status
StopAllModels
();
// Return the number of in-flight inference if any, model versions
// that don't have in-flight inferences will not be included.
const
std
::
set
<
std
::
tuple
<
std
::
string
,
int64_t
,
size_t
>>
InflightStatus
();
private:
struct
ModelInfo
{
ModelInfo
(
const
std
::
string
&
model_path
,
const
inference
::
ModelConfig
&
model_config
,
const
uint64_t
last_update_ns
)
:
model_config_
(
model_config
),
model_path_
(
model_path
),
#ifdef TRITON_ENABLE_ENSEMBLE
is_ensemble_
(
model_config
.
platform
()
==
kEnsemblePlatform
),
#else
is_ensemble_
(
false
),
#endif // TRITON_ENABLE_ENSEMBLE
last_update_ns_
(
last_update_ns
),
state_
(
ModelReadyState
::
UNKNOWN
)
{
}
// Release the flyweight in ModelInfo object, reflect as 'UNLOADING' in
// model state. Note that 'mtx_' should be acquired before invoking this
// function to prevent possible data race.
void
Release
()
{
state_
=
ModelReadyState
::
UNLOADING
;
state_reason_
.
clear
();
agent_model_list_
.
reset
();
model_
.
reset
();
}
const
inference
::
ModelConfig
model_config_
;
const
std
::
string
model_path_
;
const
bool
is_ensemble_
;
std
::
mutex
mtx_
;
uint64_t
last_update_ns_
;
ModelReadyState
state_
;
std
::
string
state_reason_
;
// flyweight
std
::
shared_ptr
<
TritonRepoAgentModelList
>
agent_model_list_
;
std
::
shared_ptr
<
Model
>
model_
;
};
struct
LoadTracker
{
LoadTracker
(
const
size_t
affected_version_cnt
,
const
uint64_t
last_update_ns
)
:
last_update_ns_
(
last_update_ns
),
affected_version_cnt_
(
affected_version_cnt
),
load_failed_
(
false
),
completed_version_cnt_
(
0
)
{
}
const
uint64_t
last_update_ns_
;
const
size_t
affected_version_cnt_
;
std
::
mutex
mtx_
;
bool
load_failed_
;
std
::
string
reason_
;
size_t
completed_version_cnt_
;
std
::
map
<
int64_t
,
ModelInfo
*>
load_set_
;
};
ModelLifeCycle
(
InferenceServer
*
server
,
const
ModelLifeCycleOptions
&
options
)
:
server_
(
server
),
min_compute_capability_
(
options
.
min_compute_capability_
),
cmdline_config_map_
(
options
.
backend_cmdline_config_map_
),
host_policy_map_
(
options
.
host_policy_map_
)
{
load_pool_
.
reset
(
new
triton
::
common
::
ThreadPool
(
std
::
max
(
1u
,
options
.
model_load_thread_count_
)));
}
void
CreateModel
(
const
std
::
string
&
model_name
,
const
int64_t
version
,
ModelInfo
*
model_info
,
const
bool
is_config_provided
);
// Callback function template for model load.
// 'OnComplete' needs to be passed by value for now as there can be
// multiple versions to be loaded and each holds a copy of
// the 'OnComplete' callback.
void
OnLoadComplete
(
const
std
::
string
&
model_name
,
const
int64_t
version
,
ModelInfo
*
model_info
,
std
::
function
<
void
(
Status
)
>
OnComplete
,
std
::
shared_ptr
<
LoadTracker
>
load_tracker
);
// Mutex for 'map_' and 'background_models_'
std
::
mutex
map_mtx_
;
using
VersionMap
=
std
::
map
<
int64_t
,
std
::
unique_ptr
<
ModelInfo
>>
;
using
ModelMap
=
std
::
map
<
std
::
string
,
VersionMap
>
;
ModelMap
map_
;
// Models that are being loaded / unloaded in background
std
::
map
<
uintptr_t
,
std
::
unique_ptr
<
ModelInfo
>>
background_models_
;
InferenceServer
*
server_
;
const
double
min_compute_capability_
;
const
triton
::
common
::
BackendCmdlineConfigMap
cmdline_config_map_
;
const
triton
::
common
::
HostPolicyCmdlineConfigMap
host_policy_map_
;
// Fixed-size thread pool to load models at specified concurrency
std
::
unique_ptr
<
triton
::
common
::
ThreadPool
>
load_pool_
;
};
}}
// namespace triton::core
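The header above only declares the asynchronous lifecycle API. As a rough illustration of how a caller is expected to drive it, here is a minimal usage sketch (hypothetical caller code, not part of this commit) that blocks on AsyncLoad's completion callback with a promise/future, the same pattern model_repository_manager.cc uses below. It assumes 'life_cycle' was obtained from ModelLifeCycle::Create() and 'config' is a parsed inference::ModelConfig; the model name and path are placeholders.

// Hypothetical usage sketch -- not part of the Triton sources in this commit.
std::promise<Status> done;
std::future<Status> done_future = done.get_future();
Status status = life_cycle->AsyncLoad(
    "example_model" /* placeholder name */,
    "/models/example_model" /* placeholder path */, config,
    true /* is_config_provided */, nullptr /* agent_model_list */,
    [&done](Status load_status) { done.set_value(load_status); });
if (status.IsOk()) {
  // Block until the load completes or fails; the callback may run on a
  // thread from 'load_pool_'.
  status = done_future.get();
}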
3rdparty/core-r22.12/src/model_repository_manager.cc
0 → 100644
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_repository_manager.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "ensemble_utils.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace
triton
{
namespace
core
{
namespace
{
static
std
::
string
file_prefix
=
"file:"
;
// Internal repo agent used for model file override
class
LocalizeRepoAgent
:
public
TritonRepoAgent
{
public:
LocalizeRepoAgent
()
:
TritonRepoAgent
(
"ModelRepositoryManager::LocalizeRepoAgent"
)
{
// Callbacks below interact with TritonRepoAgentModel directly knowing that
// it is the internal implementation of TRITONREPOAGENT_AgentModel
model_action_fn_
=
[](
TRITONREPOAGENT_Agent
*
agent
,
TRITONREPOAGENT_AgentModel
*
model
,
const
TRITONREPOAGENT_ActionType
action_type
)
->
TRITONSERVER_Error
*
{
auto
agent_model
=
reinterpret_cast
<
TritonRepoAgentModel
*>
(
model
);
switch
(
action_type
)
{
case
TRITONREPOAGENT_ACTION_LOAD
:
{
// Localize the override files for model loading,
// as currently the model is expected to load from a local directory
const
char
*
temp_dir_cstr
=
nullptr
;
RETURN_TRITONSERVER_ERROR_IF_ERROR
(
agent_model
->
AcquireMutableLocation
(
TRITONREPOAGENT_ARTIFACT_FILESYSTEM
,
&
temp_dir_cstr
));
const
std
::
string
temp_dir
=
temp_dir_cstr
;
const
auto
&
files
=
*
reinterpret_cast
<
std
::
vector
<
const
InferenceParameter
*>*>
(
agent_model
->
State
());
bool
found_config
=
false
;
for
(
const
auto
&
file
:
files
)
{
if
(
file
->
Name
()
==
"config"
)
{
if
(
file
->
Type
()
!=
TRITONSERVER_PARAMETER_STRING
)
{
return
TRITONSERVER_ErrorNew
(
TRITONSERVER_ERROR_INVALID_ARG
,
"Config parameter 'config' must have string type for its "
"value"
);
}
inference
::
ModelConfig
config
;
RETURN_TRITONSERVER_ERROR_IF_ERROR
(
JsonToModelConfig
(
file
->
ValueString
(),
1
/* config_version */
,
&
config
));
RETURN_TRITONSERVER_ERROR_IF_ERROR
(
WriteTextProto
(
JoinPath
({
temp_dir
,
kModelConfigPbTxt
}),
config
));
found_config
=
true
;
}
else
if
(
file
->
Name
().
rfind
(
file_prefix
,
0
)
==
0
)
{
if
(
file
->
Type
()
!=
TRITONSERVER_PARAMETER_BYTES
)
{
return
TRITONSERVER_ErrorNew
(
TRITONSERVER_ERROR_INVALID_ARG
,
(
std
::
string
(
"File parameter '"
)
+
file
->
Name
()
+
"' must have bytes type for its value"
)
.
c_str
());
}
// Save model file to the instructed directory
// mkdir
const
std
::
string
file_path
=
JoinPath
({
temp_dir
,
file
->
Name
().
substr
(
file_prefix
.
size
())});
const
std
::
string
dir
=
DirName
(
file_path
);
bool
dir_exist
=
false
;
RETURN_TRITONSERVER_ERROR_IF_ERROR
(
FileExists
(
dir
,
&
dir_exist
));
if
(
dir_exist
)
{
bool
is_dir
=
false
;
RETURN_TRITONSERVER_ERROR_IF_ERROR
(
IsDirectory
(
dir
,
&
is_dir
));
if
(
!
is_dir
)
{
return
TRITONSERVER_ErrorNew
(
TRITONSERVER_ERROR_INVALID_ARG
,
(
std
::
string
(
"Invalid file parameter '"
)
+
file
->
Name
()
+
"', directory has been created as a file"
)
.
c_str
());
}
}
else
{
RETURN_TRITONSERVER_ERROR_IF_ERROR
(
MakeDirectory
(
dir
,
true
/* recursive */
));
}
// write
RETURN_TRITONSERVER_ERROR_IF_ERROR
(
WriteBinaryFile
(
file_path
,
reinterpret_cast
<
const
char
*>
(
file
->
ValuePointer
()),
file
->
ValueByteSize
()));
}
}
if
(
!
found_config
)
{
return
TRITONSERVER_ErrorNew
(
TRITONSERVER_ERROR_INVALID_ARG
,
"Load parameter 'config' must be specified for model file "
"override"
);
}
// Commit the temporary directory
RETURN_TRITONSERVER_ERROR_IF_ERROR
(
agent_model
->
SetLocation
(
TRITONREPOAGENT_ARTIFACT_FILESYSTEM
,
temp_dir_cstr
));
break
;
}
default:
break
;
}
return
nullptr
;
// success
};
model_fini_fn_
=
[](
TRITONREPOAGENT_Agent
*
agent
,
TRITONREPOAGENT_AgentModel
*
model
)
->
TRITONSERVER_Error
*
{
auto
agent_model
=
reinterpret_cast
<
TritonRepoAgentModel
*>
(
model
);
RETURN_TRITONSERVER_ERROR_IF_ERROR
(
agent_model
->
DeleteMutableLocation
());
return
nullptr
;
// success
};
}
};
Status
CreateAgentModelListWithLoadAction
(
const
inference
::
ModelConfig
&
original_model_config
,
const
std
::
string
&
original_model_path
,
std
::
shared_ptr
<
TritonRepoAgentModelList
>*
agent_model_list
)
{
if
(
original_model_config
.
has_model_repository_agents
())
{
// Trick to append user specified repo agent on top of internal ones
std
::
shared_ptr
<
TritonRepoAgentModelList
>
lagent_model_list
;
if
(
*
agent_model_list
!=
nullptr
)
{
lagent_model_list
=
std
::
move
(
*
agent_model_list
);
}
else
{
lagent_model_list
.
reset
(
new
TritonRepoAgentModelList
());
}
FileSystemType
filesystem_type
;
RETURN_IF_ERROR
(
GetFileSystemType
(
original_model_path
,
&
filesystem_type
));
TRITONREPOAGENT_ArtifactType
artifact_type
=
TRITONREPOAGENT_ARTIFACT_FILESYSTEM
;
if
(
filesystem_type
!=
FileSystemType
::
LOCAL
)
{
artifact_type
=
TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM
;
}
const
char
*
location
=
original_model_path
.
c_str
();
inference
::
ModelConfig
model_config
=
original_model_config
;
for
(
const
auto
&
agent_config
:
original_model_config
.
model_repository_agents
().
agents
())
{
std
::
shared_ptr
<
TritonRepoAgent
>
agent
;
RETURN_IF_ERROR
(
TritonRepoAgentManager
::
CreateAgent
(
agent_config
.
name
(),
&
agent
));
TritonRepoAgent
::
Parameters
agent_params
;
for
(
const
auto
&
parameter
:
agent_config
.
parameters
())
{
agent_params
.
emplace_back
(
parameter
.
first
,
parameter
.
second
);
}
std
::
unique_ptr
<
TritonRepoAgentModel
>
agent_model
;
if
(
lagent_model_list
->
Size
()
!=
0
)
{
lagent_model_list
->
Back
()
->
Location
(
&
artifact_type
,
&
location
);
const
auto
config_path
=
JoinPath
({
location
,
kModelConfigPbTxt
});
if
(
!
ReadTextProto
(
config_path
,
&
model_config
).
IsOk
())
{
model_config
.
Clear
();
}
}
RETURN_IF_ERROR
(
TritonRepoAgentModel
::
Create
(
artifact_type
,
location
,
model_config
,
agent
,
agent_params
,
&
agent_model
));
RETURN_IF_ERROR
(
agent_model
->
InvokeAgent
(
TRITONREPOAGENT_ACTION_LOAD
));
lagent_model_list
->
AddAgentModel
(
std
::
move
(
agent_model
));
}
*
agent_model_list
=
std
::
move
(
lagent_model_list
);
}
return
Status
::
Success
;
}
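// Note on the loop above: repo agents are chained. After the first agent,
// each subsequent agent model is created with the location reported by the
// previous agent model (Back()->Location()), and the model config is re-read
// from that location (and cleared if no config.pbtxt is present there) before
// the TRITONREPOAGENT_ACTION_LOAD action is invoked on it.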
int64_t
GetModifiedTime
(
const
std
::
string
&
path
)
{
// If there is an error in any step the fall-back default
// modification time is 0. This means that in error cases 'path'
// will show as not modified. This is the safe fall-back to avoid
// assuming a model is constantly being modified.
bool
path_is_dir
;
Status
status
=
IsDirectory
(
path
,
&
path_is_dir
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Failed to determine modification time for '"
<<
path
<<
"': "
<<
status
.
AsString
();
return
0
;
}
// If 'path' is a file return its mtime. Otherwise, using the modification
// time of the directory as baseline in case of file deletion
int64_t
mtime
=
0
;
status
=
FileModificationTime
(
path
,
&
mtime
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Failed to determine modification time for '"
<<
path
<<
"': "
<<
status
.
AsString
();
return
0
;
}
if
(
!
path_is_dir
)
{
return
mtime
;
}
// 'path' is a directory. Return the most recent mtime of the
// contents of the directory.
std
::
set
<
std
::
string
>
contents
;
status
=
GetDirectoryContents
(
path
,
&
contents
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Failed to determine modification time for '"
<<
path
<<
"': "
<<
status
.
AsString
();
return
0
;
}
for
(
const
auto
&
child
:
contents
)
{
const
auto
full_path
=
JoinPath
({
path
,
child
});
mtime
=
std
::
max
(
mtime
,
GetModifiedTime
(
full_path
));
}
return
mtime
;
}
// Return true if any file in the subdirectory root at 'path' has been
// modified more recently than 'last'. Return the most-recent modified
// time in 'last'.
bool
IsModified
(
const
std
::
string
&
path
,
int64_t
*
last_ns
)
{
const
int64_t
repo_ns
=
GetModifiedTime
(
path
);
bool
modified
=
repo_ns
>
*
last_ns
;
*
last_ns
=
repo_ns
;
return
modified
;
}
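// Example: for a model directory '/models/resnet' (illustrative path) that
// contains 'config.pbtxt' and '1/model.onnx', GetModifiedTime() returns the
// newest modification time among the directory itself and everything under it
// (recursively), and IsModified() reports true only when that value is newer
// than the timestamp recorded at the previous poll, updating '*last_ns' in
// the same call.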
}
// namespace
struct
ModelRepositoryManager
::
ModelInfo
{
ModelInfo
(
const
int64_t
mtime_nsec
,
const
int64_t
prev_mtime_ns
,
const
std
::
string
&
model_path
)
:
mtime_nsec_
(
mtime_nsec
),
prev_mtime_ns_
(
prev_mtime_ns
),
explicitly_load_
(
true
),
model_path_
(
model_path
),
is_config_provided_
(
false
)
{
}
ModelInfo
()
:
mtime_nsec_
(
0
),
prev_mtime_ns_
(
0
),
explicitly_load_
(
true
),
is_config_provided_
(
false
)
{
}
int64_t
mtime_nsec_
;
int64_t
prev_mtime_ns_
;
bool
explicitly_load_
;
inference
::
ModelConfig
model_config_
;
std
::
string
model_path_
;
// Temporary location to hold agent model list before creating the model
// the ownership must transfer to ModelLifeCycle to ensure
// the agent model life cycle is handled properly.
std
::
shared_ptr
<
TritonRepoAgentModelList
>
agent_model_list_
;
bool
is_config_provided_
;
};
ModelRepositoryManager
::
ModelRepositoryManager
(
const
std
::
set
<
std
::
string
>&
repository_paths
,
const
bool
autofill
,
const
bool
polling_enabled
,
const
bool
model_control_enabled
,
const
double
min_compute_capability
,
std
::
unique_ptr
<
ModelLifeCycle
>
life_cycle
)
:
repository_paths_
(
repository_paths
),
autofill_
(
autofill
),
polling_enabled_
(
polling_enabled
),
model_control_enabled_
(
model_control_enabled
),
min_compute_capability_
(
min_compute_capability
),
model_life_cycle_
(
std
::
move
(
life_cycle
))
{
}
ModelRepositoryManager
::~
ModelRepositoryManager
()
{}
Status
ModelRepositoryManager
::
Create
(
InferenceServer
*
server
,
const
std
::
string
&
server_version
,
const
std
::
set
<
std
::
string
>&
repository_paths
,
const
std
::
set
<
std
::
string
>&
startup_models
,
const
bool
strict_model_config
,
const
bool
polling_enabled
,
const
bool
model_control_enabled
,
const
ModelLifeCycleOptions
&
life_cycle_options
,
std
::
unique_ptr
<
ModelRepositoryManager
>*
model_repository_manager
)
{
// The rest only matters if repository path is valid directory
for
(
const
auto
&
path
:
repository_paths
)
{
bool
path_is_dir
;
RETURN_IF_ERROR
(
IsDirectory
(
path
,
&
path_is_dir
));
if
(
!
path_is_dir
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"repository path is not a valid directory"
);
}
}
if
(
polling_enabled
&&
model_control_enabled
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"cannot enable both polling and explicit model control"
);
}
std
::
unique_ptr
<
ModelLifeCycle
>
life_cycle
;
RETURN_IF_ERROR
(
ModelLifeCycle
::
Create
(
server
,
life_cycle_options
,
&
life_cycle
));
// Not setting the smart pointer directly to simplify clean up
std
::
unique_ptr
<
ModelRepositoryManager
>
local_manager
(
new
ModelRepositoryManager
(
repository_paths
,
!
strict_model_config
,
polling_enabled
,
model_control_enabled
,
life_cycle_options
.
min_compute_capability_
,
std
::
move
(
life_cycle
)));
*
model_repository_manager
=
std
::
move
(
local_manager
);
// Support loading all models on startup in explicit model control mode with
// special startup_model name "*". This does not imply support for pattern
// matching in model names.
bool
load_all_models_on_startup
=
false
;
if
((
startup_models
.
find
(
"*"
)
!=
startup_models
.
end
())
&&
model_control_enabled
)
{
if
(
startup_models
.
size
()
>
1
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Wildcard model name '*' must be the ONLY startup model "
"if specified at all."
);
}
load_all_models_on_startup
=
true
;
}
bool
all_models_polled
=
true
;
if
(
!
model_control_enabled
||
load_all_models_on_startup
)
{
// only errors that happen before model load / unload will be returned;
// model loading / unloading errors will be printed but ignored
RETURN_IF_ERROR
(
(
*
model_repository_manager
)
->
PollAndUpdateInternal
(
&
all_models_polled
));
}
else
{
// Load each specified startup_model
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>
models
;
for
(
const
auto
&
model_name
:
startup_models
)
{
models
[
model_name
];
}
RETURN_IF_ERROR
(
(
*
model_repository_manager
)
->
LoadUnloadModels
(
models
,
ActionType
::
LOAD
,
false
,
&
all_models_polled
));
}
if
(
!
all_models_polled
)
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to load all models"
);
}
// Some models may have failed to load after the model manager is created;
// return a proper error and let the caller decide whether to proceed.
for
(
const
auto
&
model
:
(
*
model_repository_manager
)
->
infos_
)
{
const
auto
version_states
=
(
*
model_repository_manager
)
->
model_life_cycle_
->
VersionStates
(
model
.
first
);
// Return general error message, detail of each model's loading state
// is logged separately.
if
(
version_states
.
empty
())
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to load all models"
);
}
for
(
const
auto
&
state
:
version_states
)
{
if
(
state
.
second
.
first
!=
ModelReadyState
::
READY
)
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to load all models"
);
}
}
}
return
Status
::
Success
;
}
Status
ModelRepositoryManager
::
PollAndUpdate
()
{
if
(
!
polling_enabled_
)
{
return
Status
(
Status
::
Code
::
UNAVAILABLE
,
"polling is disabled"
);
}
bool
all_models_polled
;
return
PollAndUpdateInternal
(
&
all_models_polled
);
}
Status
ModelRepositoryManager
::
PollAndUpdateInternal
(
bool
*
all_models_polled
)
{
// Serialize all operations that change model state
std
::
lock_guard
<
std
::
mutex
>
lock
(
poll_mu_
);
std
::
set
<
std
::
string
>
added
,
deleted
,
modified
,
unmodified
;
// We don't modify 'infos_' in place to minimize how long we need to
// hold the lock, and also to prevent partial changes if an error occurs
// during processing.
ModelInfoMap
new_infos
;
// Each subdirectory of repository path is a model directory from
// which we read the model configuration.
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>
subdirs
;
RETURN_IF_ERROR
(
Poll
(
subdirs
,
&
added
,
&
deleted
,
&
modified
,
&
unmodified
,
&
new_infos
,
all_models_polled
));
// Anything in 'infos_' that is not in "added", "modified", or
// "unmodified" is deleted.
for
(
const
auto
&
pr
:
infos_
)
{
if
((
added
.
find
(
pr
.
first
)
==
added
.
end
())
&&
(
modified
.
find
(
pr
.
first
)
==
modified
.
end
())
&&
(
unmodified
.
find
(
pr
.
first
)
==
unmodified
.
end
()))
{
deleted
.
insert
(
pr
.
first
);
}
}
// Nothing to do if there are no model additions, deletions or modifications.
if
(
added
.
empty
()
&&
deleted
.
empty
()
&&
modified
.
empty
())
{
return
Status
::
Success
;
}
infos_
.
swap
(
new_infos
);
UpdateDependencyGraph
(
added
,
deleted
,
modified
);
for
(
const
auto
&
name
:
deleted
)
{
model_life_cycle_
->
AsyncUnload
(
name
);
}
// Model loading / unloading errors will be printed but ignored.
LoadModelByDependency
();
return
Status
::
Success
;
}
std
::
map
<
std
::
string
,
Status
>
ModelRepositoryManager
::
LoadModelByDependency
()
{
std
::
map
<
std
::
string
,
Status
>
res
;
struct
ModelState
{
ModelState
(
DependencyNode
*
node
)
:
node_
(
node
),
status_
(
Status
::
Success
)
{}
DependencyNode
*
node_
;
Status
status_
;
std
::
promise
<
void
>
ready_
;
};
NodeSet
loaded_models
;
auto
set_pair
=
ModelsToLoadUnload
(
loaded_models
);
// Loop until all models are loaded / unloaded
while
((
!
set_pair
.
first
.
empty
())
||
(
!
set_pair
.
second
.
empty
()))
{
loaded_models
.
clear
();
// Unload invalid models first
for
(
auto
&
invalid_model
:
set_pair
.
second
)
{
model_life_cycle_
->
AsyncUnload
(
invalid_model
->
model_name_
);
LOG_ERROR
<<
invalid_model
->
status_
.
AsString
();
invalid_model
->
loaded_versions_
=
std
::
set
<
int64_t
>
();
loaded_models
.
emplace
(
invalid_model
);
}
// load valid models and wait for load results
std
::
vector
<
std
::
unique_ptr
<
ModelState
>>
model_states
;
for
(
auto
&
valid_model
:
set_pair
.
first
)
{
model_states
.
emplace_back
(
new
ModelState
(
valid_model
));
auto
model_state
=
model_states
.
back
().
get
();
const
auto
itr
=
infos_
.
find
(
valid_model
->
model_name_
);
auto
status
=
model_life_cycle_
->
AsyncLoad
(
valid_model
->
model_name_
,
itr
->
second
->
model_path_
,
valid_model
->
model_config_
,
itr
->
second
->
is_config_provided_
,
itr
->
second
->
agent_model_list_
,
[
model_state
](
Status
load_status
)
{
model_state
->
status_
=
load_status
;
model_state
->
ready_
.
set_value
();
});
if
(
!
status
.
IsOk
())
{
model_state
->
status_
=
status
;
model_state
->
ready_
.
set_value
();
LOG_ERROR
<<
"failed to load model '"
<<
valid_model
->
model_name_
<<
"': "
<<
status
.
Message
();
}
loaded_models
.
emplace
(
valid_model
);
}
for
(
auto
&
model_state
:
model_states
)
{
model_state
->
ready_
.
get_future
().
wait
();
res
[
model_state
->
node_
->
model_name_
]
=
model_state
->
status_
;
const
auto
version_state
=
model_life_cycle_
->
VersionStates
(
model_state
->
node_
->
model_name_
);
model_state
->
node_
->
loaded_versions_
.
clear
();
for
(
const
auto
&
vs
:
version_state
)
{
if
(
vs
.
second
.
first
==
ModelReadyState
::
READY
)
{
model_state
->
node_
->
loaded_versions_
.
emplace
(
vs
.
first
);
}
}
// If the model failed to load, revert the timestamp so that the next
// load request will attempt to load the model again, for operation
// consistency.
if
(
!
model_state
->
status_
.
IsOk
())
{
auto
&
model_info
=
infos_
.
find
(
model_state
->
node_
->
model_name_
)
->
second
;
model_info
->
mtime_nsec_
=
model_info
->
prev_mtime_ns_
;
}
}
set_pair
=
ModelsToLoadUnload
(
loaded_models
);
}
// Clear the temporarily stored agent model list after all loads are triggered
for
(
auto
&
info
:
infos_
)
{
info
.
second
->
agent_model_list_
.
reset
();
}
return
res
;
}
Status
ModelRepositoryManager
::
LoadUnloadModel
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>&
models
,
const
ActionType
type
,
const
bool
unload_dependents
)
{
if
(
!
model_control_enabled_
)
{
return
Status
(
Status
::
Code
::
UNAVAILABLE
,
"explicit model load / unload is not allowed if polling is enabled"
);
}
if
(
models
.
size
()
>
1
)
{
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"explicit load / unload multiple models is not currently supported"
);
}
// Serialize all operations that change model state
std
::
lock_guard
<
std
::
mutex
>
lock
(
poll_mu_
);
bool
polled
=
true
;
RETURN_IF_ERROR
(
LoadUnloadModels
(
models
,
type
,
unload_dependents
,
&
polled
));
// Check if model is loaded / unloaded properly
const
auto
&
model_name
=
models
.
begin
()
->
first
;
if
(
!
polled
)
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to load '"
+
model_name
+
"', failed to poll from model repository"
);
}
const
auto
version_states
=
model_life_cycle_
->
VersionStates
(
model_name
);
if
(
type
==
ActionType
::
LOAD
)
{
if
(
version_states
.
empty
())
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to load '"
+
model_name
+
"', no version is available"
);
}
auto
it
=
infos_
.
find
(
model_name
);
if
(
it
==
infos_
.
end
())
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to load '"
+
model_name
+
"', failed to poll from model repository"
);
}
}
else
{
std
::
string
ready_version_str
;
for
(
const
auto
&
version_state
:
version_states
)
{
if
(
version_state
.
second
.
first
==
ModelReadyState
::
READY
)
{
ready_version_str
+=
std
::
to_string
(
version_state
.
first
);
ready_version_str
+=
","
;
}
}
if
(
!
ready_version_str
.
empty
())
{
ready_version_str
.
pop_back
();
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to unload '"
+
model_name
+
"', versions that are still available: "
+
ready_version_str
);
}
}
return
Status
::
Success
;
}
Status
ModelRepositoryManager
::
LoadUnloadModels
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>&
models
,
const
ActionType
type
,
const
bool
unload_dependents
,
bool
*
all_models_polled
)
{
auto
status
=
Status
::
Success
;
*
all_models_polled
=
true
;
// Update ModelInfo related to file system accordingly
std
::
set
<
std
::
string
>
added
,
deleted
,
modified
,
unmodified
;
{
if
(
type
==
ActionType
::
UNLOAD
)
{
for
(
const
auto
&
model
:
models
)
{
deleted
.
insert
(
model
.
first
);
}
}
// ActionType::LOAD and in model control mode
else
{
std
::
set
<
std
::
string
>
checked_models
;
auto
current_models
=
models
;
for
(
const
auto
&
model
:
models
)
{
checked_models
.
emplace
(
model
.
first
);
}
ModelInfoMap
new_infos
;
#ifdef TRITON_ENABLE_ENSEMBLE
bool
first_iteration
=
true
;
#endif // TRITON_ENABLE_ENSEMBLE
while
(
!
current_models
.
empty
())
{
bool
polled
=
true
;
RETURN_IF_ERROR
(
Poll
(
current_models
,
&
added
,
&
deleted
,
&
modified
,
&
unmodified
,
&
new_infos
,
&
polled
));
*
all_models_polled
&=
polled
;
// More models should be polled if the polled models are ensembles
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>
next_models
;
#ifdef TRITON_ENABLE_ENSEMBLE
for
(
const
auto
&
model
:
current_models
)
{
auto
it
=
new_infos
.
find
(
model
.
first
);
// Some models may be marked as deleted and not in 'new_infos'
if
(
it
!=
new_infos
.
end
())
{
it
->
second
->
explicitly_load_
=
first_iteration
;
const
auto
&
config
=
it
->
second
->
model_config_
;
if
(
config
.
has_ensemble_scheduling
())
{
for
(
const
auto
&
step
:
config
.
ensemble_scheduling
().
step
())
{
bool
need_poll
=
checked_models
.
emplace
(
step
.
model_name
()).
second
;
if
(
need_poll
)
{
next_models
[
step
.
model_name
()];
}
}
}
}
}
first_iteration
=
false
;
#endif // TRITON_ENABLE_ENSEMBLE
current_models
.
swap
(
next_models
);
}
// Only update the infos when all validation is completed
for
(
const
auto
&
model_name
:
added
)
{
auto
nitr
=
new_infos
.
find
(
model_name
);
infos_
.
emplace
(
model_name
,
std
::
move
(
nitr
->
second
));
}
for
(
const
auto
&
model_name
:
modified
)
{
auto
nitr
=
new_infos
.
find
(
model_name
);
auto
itr
=
infos_
.
find
(
model_name
);
itr
->
second
=
std
::
move
(
nitr
->
second
);
}
}
}
std
::
set
<
std
::
string
>
deleted_dependents
;
// Update dependency graph and load
UpdateDependencyGraph
(
added
,
deleted
,
modified
,
unload_dependents
?
&
deleted_dependents
:
nullptr
);
// The models are in 'deleted' either when they are asked to be unloaded or
// when they are not found / are duplicated across all model repositories.
// In all cases, they should be unloaded and removed from 'infos_' explicitly.
for
(
const
auto
&
name
:
(
unload_dependents
?
deleted_dependents
:
deleted
))
{
infos_
.
erase
(
name
);
model_life_cycle_
->
AsyncUnload
(
name
);
}
// load / unload the models affected, and check the load status of
// the requested models
const
auto
&
load_status
=
LoadModelByDependency
();
if
(
status
.
IsOk
()
&&
(
type
==
ActionType
::
LOAD
))
{
std
::
string
load_error_message
=
""
;
for
(
const
auto
&
model
:
models
)
{
auto
it
=
load_status
.
find
(
model
.
first
);
// If 'model.first' is not in the load status, the (re-)load is not
// necessary because there is no change in the model's directory
if
((
it
!=
load_status
.
end
())
&&
!
it
->
second
.
IsOk
())
{
load_error_message
+=
(
"load failed for model '"
+
model
.
first
+
"': "
+
it
->
second
.
Message
()
+
"
\n
"
);
}
}
if
(
!
load_error_message
.
empty
())
{
status
=
Status
(
Status
::
Code
::
INVALID_ARG
,
load_error_message
);
}
}
return
status
;
}
Status
ModelRepositoryManager
::
UnloadAllModels
()
{
Status
status
;
for
(
const
auto
&
name_info
:
infos_
)
{
Status
unload_status
=
model_life_cycle_
->
AsyncUnload
(
name_info
.
first
);
if
(
!
unload_status
.
IsOk
())
{
status
=
Status
(
unload_status
.
ErrorCode
(),
"Failed to gracefully unload models: "
+
unload_status
.
Message
());
}
}
return
Status
::
Success
;
}
Status
ModelRepositoryManager
::
StopAllModels
()
{
return
model_life_cycle_
->
StopAllModels
();
}
const
std
::
set
<
std
::
tuple
<
std
::
string
,
int64_t
,
size_t
>>
ModelRepositoryManager
::
InflightStatus
()
{
return
model_life_cycle_
->
InflightStatus
();
}
const
ModelStateMap
ModelRepositoryManager
::
LiveModelStates
(
bool
strict_readiness
)
{
return
model_life_cycle_
->
LiveModelStates
(
strict_readiness
);
}
const
ModelStateMap
ModelRepositoryManager
::
ModelStates
()
{
return
model_life_cycle_
->
ModelStates
();
}
const
VersionStateMap
ModelRepositoryManager
::
VersionStates
(
const
std
::
string
&
model_name
)
{
return
model_life_cycle_
->
VersionStates
(
model_name
);
}
Status
ModelRepositoryManager
::
ModelState
(
const
std
::
string
&
model_name
,
const
int64_t
model_version
,
ModelReadyState
*
state
)
{
return
model_life_cycle_
->
ModelState
(
model_name
,
model_version
,
state
);
}
Status
ModelRepositoryManager
::
RepositoryIndex
(
const
bool
ready_only
,
std
::
vector
<
ModelIndex
>*
index
)
{
std
::
set
<
std
::
string
>
seen_models
;
std
::
set
<
std
::
string
>
duplicate_models
;
for
(
const
auto
&
repository_path
:
repository_paths_
)
{
// For any mapped models in this repository, save the mapping
// from their subdirectory name to model name.
std
::
map
<
std
::
string
,
std
::
string
>
models_in_repo
;
for
(
const
auto
&
mapping_it
:
model_mappings_
)
{
if
(
mapping_it
.
second
.
first
==
repository_path
)
{
models_in_repo
.
emplace
(
BaseName
(
mapping_it
.
second
.
second
),
mapping_it
.
first
);
}
}
std
::
set
<
std
::
string
>
subdirs
;
RETURN_IF_ERROR
(
GetDirectorySubdirs
(
repository_path
,
&
subdirs
));
for
(
const
auto
&
subdir
:
subdirs
)
{
auto
model
=
subdir
;
auto
model_it
=
models_in_repo
.
find
(
subdir
);
if
(
model_it
!=
models_in_repo
.
end
())
{
model
=
model_it
->
second
;
}
if
(
seen_models
.
find
(
model
)
!=
seen_models
.
end
())
{
duplicate_models
.
insert
(
model
);
}
seen_models
.
insert
(
model
);
}
}
ModelStateMap
states
=
ModelStates
();
for
(
const
auto
&
model
:
seen_models
)
{
// If the same model appears in multiple repositories then show it
// as unavailable since duplicate models are not allowed to load.
if
(
duplicate_models
.
find
(
model
)
!=
duplicate_models
.
end
())
{
index
->
emplace_back
(
model
,
-
1
/* version */
,
ModelReadyState
::
UNAVAILABLE
,
MODEL_READY_REASON_DUPLICATE
);
continue
;
}
// If there is any version/state/reason associated with the model
// then include that in the index.
auto
sitr
=
states
.
find
(
model
);
if
(
sitr
==
states
.
end
())
{
if
(
!
ready_only
)
{
index
->
emplace_back
(
model
);
}
}
else
{
for
(
const
auto
&
pr
:
sitr
->
second
)
{
if
(
!
ready_only
||
(
pr
.
second
.
first
==
ModelReadyState
::
READY
))
{
index
->
emplace_back
(
model
,
pr
.
first
,
pr
.
second
.
first
,
pr
.
second
.
second
);
}
}
}
}
return
Status
::
Success
;
}
Status
ModelRepositoryManager
::
GetModel
(
const
std
::
string
&
model_name
,
const
int64_t
model_version
,
std
::
shared_ptr
<
Model
>*
model
)
{
Status
status
=
model_life_cycle_
->
GetModel
(
model_name
,
model_version
,
model
);
if
(
!
status
.
IsOk
())
{
model
->
reset
();
status
=
Status
(
status
.
ErrorCode
(),
"Request for unknown model: "
+
status
.
Message
());
}
return
status
;
}
Status
ModelRepositoryManager
::
Poll
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
const
InferenceParameter
*>>&
models
,
std
::
set
<
std
::
string
>*
added
,
std
::
set
<
std
::
string
>*
deleted
,
std
::
set
<
std
::
string
>*
modified
,
std
::
set
<
std
::
string
>*
unmodified
,
ModelInfoMap
*
updated_infos
,
bool
*
all_models_polled
)
{
*
all_models_polled
=
true
;
// empty path is the special case to indicate the model should be loaded
// from override file content in 'models'.
std
::
map
<
std
::
string
,
std
::
string
>
model_to_path
;
// If no model is specified, poll all models in all model repositories.
// Otherwise, only poll the specified models
if
(
models
.
empty
())
{
std
::
set
<
std
::
string
>
duplicated_models
;
for
(
const
auto
&
repository_path
:
repository_paths_
)
{
std
::
set
<
std
::
string
>
subdirs
;
Status
status
=
GetDirectorySubdirs
(
repository_path
,
&
subdirs
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"failed to poll model repository '"
<<
repository_path
<<
"': "
<<
status
.
Message
();
*
all_models_polled
=
false
;
}
else
{
for
(
const
auto
&
subdir
:
subdirs
)
{
if
(
!
model_to_path
.
emplace
(
subdir
,
JoinPath
({
repository_path
,
subdir
}))
.
second
)
{
duplicated_models
.
insert
(
subdir
);
*
all_models_polled
=
false
;
}
}
}
}
// If the model is not unique, mark as deleted to unload it
for
(
const
auto
&
model
:
duplicated_models
)
{
model_to_path
.
erase
(
model
);
deleted
->
insert
(
model
);
LOG_ERROR
<<
"failed to poll model '"
<<
model
<<
"': not unique across all model repositories"
;
}
}
// If models are specified, this is explicit model control mode.
else
{
for
(
const
auto
&
model
:
models
)
{
// Skip repository polling if override model files
if
(
ModelDirectoryOverride
(
model
.
second
))
{
model_to_path
.
emplace
(
model
.
first
,
""
);
continue
;
}
// Check model mapping first to see if matching model to load.
bool
exists
=
false
;
auto
model_it
=
model_mappings_
.
find
(
model
.
first
);
if
(
model_it
!=
model_mappings_
.
end
())
{
bool
exists_in_this_repo
=
false
;
auto
full_path
=
model_it
->
second
.
second
;
Status
status
=
FileExists
(
full_path
,
&
exists_in_this_repo
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"failed to poll mapped path '"
<<
full_path
<<
"' for model '"
<<
model
.
first
<<
"': "
<<
status
.
Message
();
*
all_models_polled
=
false
;
}
if
(
exists_in_this_repo
)
{
model_to_path
.
emplace
(
model
.
first
,
model_it
->
second
.
second
);
exists
=
true
;
}
else
{
LOG_ERROR
<<
"mapped path '"
<<
full_path
<<
"' does not exist for model '"
<<
model
.
first
<<
"'"
;
exists
=
false
;
}
}
else
{
for
(
const
auto
repository_path
:
repository_paths_
)
{
bool
exists_in_this_repo
=
false
;
const
auto
full_path
=
JoinPath
({
repository_path
,
model
.
first
});
Status
status
=
FileExists
(
full_path
,
&
exists_in_this_repo
);
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"failed to poll model repository '"
<<
repository_path
<<
"' for model '"
<<
model
.
first
<<
"': "
<<
status
.
Message
();
*
all_models_polled
=
false
;
}
else
if
(
exists_in_this_repo
)
{
// Check to make sure this directory is not mapped.
// If mapped, continue to next repository path.
bool
mapped
=
false
;
for
(
auto
const
&
mapping
:
model_mappings_
)
{
if
(
mapping
.
second
.
second
==
full_path
)
{
mapped
=
true
;
break
;
}
}
if
(
mapped
)
{
continue
;
}
auto
res
=
model_to_path
.
emplace
(
model
.
first
,
JoinPath
({
repository_path
,
model
.
first
}));
if
(
res
.
second
)
{
exists
=
true
;
}
else
{
exists
=
false
;
model_to_path
.
erase
(
res
.
first
);
LOG_ERROR
<<
"failed to poll model '"
<<
model
.
first
<<
"': not unique across all model repositories"
;
break
;
}
}
}
}
// For an explicitly specified model that doesn't exist, we don't mark it
// as deleted, we simply mark that we couldn't poll all models.
if
(
!
exists
)
{
*
all_models_polled
=
false
;
}
}
}
// Poll each of the models. If an error happens while polling a model,
// its state will fall back to the state before the polling.
for
(
const
auto
&
pair
:
model_to_path
)
{
std
::
unique_ptr
<
ModelInfo
>
model_info
;
const
auto
&
mit
=
models
.
find
(
pair
.
first
);
static
std
::
vector
<
const
InferenceParameter
*>
empty_params
;
auto
status
=
InitializeModelInfo
(
pair
.
first
,
pair
.
second
,
((
mit
==
models
.
end
())
?
empty_params
:
mit
->
second
),
&
model_info
);
const
auto
&
iitr
=
infos_
.
find
(
pair
.
first
);
const
bool
invalid_add
=
(
!
status
.
IsOk
())
&&
(
iitr
==
infos_
.
end
());
if
(
!
invalid_add
)
{
const
auto
&
ret
=
updated_infos
->
emplace
(
pair
.
first
,
nullptr
);
if
(
!
ret
.
second
)
{
return
Status
(
Status
::
Code
::
ALREADY_EXISTS
,
"unexpected model info for model '"
+
pair
.
first
+
"'"
);
}
// Classify load state and set updated info
if
(
model_info
==
nullptr
)
{
ret
.
first
->
second
.
reset
(
new
ModelInfo
(
*
iitr
->
second
));
unmodified
->
insert
(
pair
.
first
);
}
else
{
ret
.
first
->
second
=
std
::
move
(
model_info
);
if
(
iitr
!=
infos_
.
end
())
{
modified
->
insert
(
pair
.
first
);
}
else
{
added
->
insert
(
pair
.
first
);
}
}
}
if
(
!
status
.
IsOk
())
{
LOG_ERROR
<<
"Poll failed for model directory '"
<<
pair
.
first
<<
"': "
<<
status
.
Message
();
*
all_models_polled
=
false
;
}
}
return
Status
::
Success
;
}
bool
ModelRepositoryManager
::
ModelDirectoryOverride
(
const
std
::
vector
<
const
InferenceParameter
*>&
model_params
)
{
for
(
const
auto
&
param
:
model_params
)
{
if
(
param
->
Name
().
rfind
(
file_prefix
,
0
)
==
0
)
{
// param name starts with prefix if user provides override file
return
true
;
}
}
return
false
;
}
Status
ModelRepositoryManager
::
InitializeModelInfo
(
const
std
::
string
&
name
,
const
std
::
string
&
path
,
const
std
::
vector
<
const
InferenceParameter
*>&
params
,
std
::
unique_ptr
<
ModelInfo
>*
info
)
{
std
::
unique_ptr
<
ModelInfo
>
linfo
(
new
ModelInfo
());
linfo
->
model_path_
=
path
;
bool
unmodified
=
false
;
const
auto
iitr
=
infos_
.
find
(
name
);
// Set 'prev_mtime_ns_' if there is existing ModelInfo
if
(
iitr
!=
infos_
.
end
())
{
linfo
->
prev_mtime_ns_
=
iitr
->
second
->
mtime_nsec_
;
}
else
{
linfo
->
prev_mtime_ns_
=
0
;
}
// Set 'mtime_nsec_' and override 'model_path_' if current path is empty
// (file override is specified)
if
(
linfo
->
model_path_
.
empty
())
{
// Need to localize the override files, use repo agent to manage
// the lifecycle of the localized files
std
::
shared_ptr
<
TritonRepoAgent
>
localize_agent
(
new
LocalizeRepoAgent
());
std
::
unique_ptr
<
TritonRepoAgentModel
>
localize_agent_model
;
RETURN_IF_ERROR
(
TritonRepoAgentModel
::
Create
(
TRITONREPOAGENT_ARTIFACT_FILESYSTEM
,
""
,
inference
::
ModelConfig
(),
localize_agent
,
{},
&
localize_agent_model
));
// Set agent model state so the repo agent can access the encoded files
// Using const_cast here but we are safe as the RepoAgent will not
// modify the state
localize_agent_model
->
SetState
(
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
&
params
)));
RETURN_IF_ERROR
(
localize_agent_model
->
InvokeAgent
(
TRITONREPOAGENT_ACTION_LOAD
));
const
char
*
location
;
TRITONREPOAGENT_ArtifactType
type
;
RETURN_IF_ERROR
(
localize_agent_model
->
Location
(
&
type
,
&
location
));
// For file override, set 'mtime_nsec_' to minimum value so that
// the next load without override will trigger re-load to undo
// the override while the local files may still be unchanged.
linfo
->
mtime_nsec_
=
0
;
linfo
->
model_path_
=
location
;
linfo
->
agent_model_list_
.
reset
(
new
TritonRepoAgentModelList
());
linfo
->
agent_model_list_
->
AddAgentModel
(
std
::
move
(
localize_agent_model
));
}
else
{
if
(
iitr
==
infos_
.
end
())
{
linfo
->
mtime_nsec_
=
GetModifiedTime
(
std
::
string
(
linfo
->
model_path_
));
}
else
{
// Check the current timestamps to determine if model actually has been
// modified
linfo
->
mtime_nsec_
=
linfo
->
prev_mtime_ns_
;
unmodified
=
!
IsModified
(
std
::
string
(
linfo
->
model_path_
),
&
linfo
->
mtime_nsec_
);
}
}
// Set 'model_config_'
bool
parsed_config
=
false
;
// Check if there is config override
for
(
const
auto
&
override_parameter
:
params
)
{
if
((
override_parameter
->
Name
()
==
"config"
)
&&
(
override_parameter
->
Type
()
==
TRITONSERVER_PARAMETER_STRING
))
{
// When override happens, set 'mtime_nsec_' to minimum value so that
// the next load without override will trigger re-load to undo
// the override while the local files may still be unchanged.
linfo
->
mtime_nsec_
=
0
;
unmodified
=
false
;
const
std
::
string
&
override_config
=
override_parameter
->
ValueString
();
auto
err
=
JsonToModelConfig
(
override_config
,
1
/* config_version */
,
&
linfo
->
model_config_
);
if
(
!
err
.
IsOk
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Invalid config override: "
+
std
::
string
(
err
.
Message
()));
}
parsed_config
=
true
;
break
;
}
else
if
(
override_parameter
->
Name
().
rfind
(
file_prefix
,
0
)
!=
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Unrecognized load parameter '"
+
override_parameter
->
Name
()
+
"' with type '"
+
TRITONSERVER_ParameterTypeString
(
override_parameter
->
Type
())
+
"'"
);
}
}
// A polled model is considered unmodified by this point and can be returned
// with info == nullptr
if
(
unmodified
)
{
return
Status
::
Success
;
}
// Create the associated repo agent models when a model is to be loaded,
// this must be done before normalizing model config as agents might
// redirect to use the model config at a different location
if
(
!
parsed_config
)
{
const
auto
config_path
=
JoinPath
({
linfo
->
model_path_
,
kModelConfigPbTxt
});
bool
model_config_exists
=
false
;
RETURN_IF_ERROR
(
FileExists
(
config_path
,
&
model_config_exists
));
// model config can be missing if auto fill is set
if
(
autofill_
&&
!
model_config_exists
)
{
linfo
->
model_config_
.
Clear
();
}
else
{
RETURN_IF_ERROR
(
ReadTextProto
(
config_path
,
&
linfo
->
model_config_
));
parsed_config
=
true
;
}
}
if
(
parsed_config
)
{
RETURN_IF_ERROR
(
CreateAgentModelListWithLoadAction
(
linfo
->
model_config_
,
linfo
->
model_path_
,
&
linfo
->
agent_model_list_
));
if
(
linfo
->
agent_model_list_
!=
nullptr
)
{
// Get the latest repository path
const
char
*
location
;
TRITONREPOAGENT_ArtifactType
artifact_type
;
RETURN_IF_ERROR
(
linfo
->
agent_model_list_
->
Back
()
->
Location
(
&
artifact_type
,
&
location
));
auto
latest_path
=
std
::
string
(
location
);
linfo
->
model_path_
=
latest_path
;
}
}
linfo
->
is_config_provided_
=
parsed_config
;
// Try to automatically generate missing parts of the model
// configuration (autofill) that don't require model detail
RETURN_IF_ERROR
(
GetNormalizedModelConfig
(
name
,
linfo
->
model_path_
,
min_compute_capability_
,
&
linfo
->
model_config_
));
// Note that the model inputs and outputs are not validated until the
// model is initialized, as they may not be auto-completed until the
// model is initialized.
RETURN_IF_ERROR
(
ValidateModelConfig
(
linfo
->
model_config_
,
min_compute_capability_
));
if
(
!
autofill_
)
{
RETURN_IF_ERROR
(
ValidateModelIOConfig
(
linfo
->
model_config_
));
}
// If the model is mapped, update its config name based on the
// mapping.
if
(
model_mappings_
.
find
(
name
)
!=
model_mappings_
.
end
())
{
linfo
->
model_config_
.
set_name
(
name
);
}
else
{
// If there is no model mapping, make sure the name of the model
// matches the name of the directory. This is a somewhat arbitrary
// requirement but seems like good practice to require it of the user.
// It also acts as a check to make sure we don't have two different
// models with the same name.
if
(
linfo
->
model_config_
.
name
()
!=
name
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"unexpected directory name '"
+
name
+
"' for model '"
+
linfo
->
model_config_
.
name
()
+
"', directory name must equal model name"
);
}
}
*
info
=
std
::
move
(
linfo
);
return
Status
::
Success
;
}
Status
ModelRepositoryManager
::
UpdateDependencyGraph
(
const
std
::
set
<
std
::
string
>&
added
,
const
std
::
set
<
std
::
string
>&
deleted
,
const
std
::
set
<
std
::
string
>&
modified
,
std
::
set
<
std
::
string
>*
deleted_dependents
)
{
// Update the dependency graph; if the state of a node is changed, all its
// downstreams will be affected.
// deleted: drop from dependency_graph_, add to missing_nodes_ if downstreams
// is not empty. affected_nodes are all ensembles, as only ensembles depend
// on other models.
std
::
set
<
DependencyNode
*>
affected_nodes
;
std
::
set
<
DependencyNode
*>
updated_nodes
;
std
::
set
<
std
::
string
>
current_deleted
=
deleted
;
while
(
!
current_deleted
.
empty
())
{
std
::
set
<
std
::
string
>
next_deleted
;
for
(
const
auto
&
model_name
:
current_deleted
)
{
auto
it
=
dependency_graph_
.
find
(
model_name
);
if
(
it
!=
dependency_graph_
.
end
())
{
// remove this node from its upstreams
for
(
auto
&
upstream
:
it
->
second
->
upstreams_
)
{
upstream
.
first
->
downstreams_
.
erase
(
it
->
second
.
get
());
// Check if the upstream should be removed as well
if
((
deleted_dependents
!=
nullptr
)
&&
(
upstream
.
first
->
downstreams_
.
empty
())
&&
(
!
upstream
.
first
->
explicitly_load_
))
{
next_deleted
.
emplace
(
upstream
.
first
->
model_name_
);
}
}
it
->
second
->
upstreams_
.
clear
();
if
(
!
it
->
second
->
downstreams_
.
empty
())
{
UncheckDownstream
(
&
it
->
second
->
downstreams_
,
&
affected_nodes
);
// mark this node as missing upstream in its downstreams
for
(
auto
&
downstream
:
it
->
second
->
downstreams_
)
{
downstream
->
missing_upstreams_
.
emplace
(
it
->
second
.
get
());
}
missing_nodes_
.
emplace
(
std
::
make_pair
(
model_name
,
std
::
move
(
it
->
second
)));
}
// Make sure deleted node will not be in affected nodes
affected_nodes
.
erase
(
it
->
second
.
get
());
dependency_graph_
.
erase
(
it
);
}
if
(
deleted_dependents
!=
nullptr
)
{
deleted_dependents
->
emplace
(
model_name
);
}
}
current_deleted
.
swap
(
next_deleted
);
}
// modified: invalidate (uncheck) all downstreams
for
(
const
auto
&
model_name
:
modified
)
{
auto
it
=
dependency_graph_
.
find
(
model_name
);
if
(
it
!=
dependency_graph_
.
end
())
{
UncheckDownstream
(
&
it
->
second
->
downstreams_
,
&
affected_nodes
);
ModelInfo
*
info
=
nullptr
;
GetModelInfo
(
model_name
,
&
info
);
it
->
second
->
model_config_
=
info
->
model_config_
;
it
->
second
->
explicitly_load_
=
info
->
explicitly_load_
;
// remove this node from its upstream node
for
(
auto
&
upstream
:
it
->
second
->
upstreams_
)
{
upstream
.
first
->
downstreams_
.
erase
(
it
->
second
.
get
());
}
it
->
second
->
upstreams_
.
clear
();
it
->
second
->
checked_
=
false
;
it
->
second
->
status_
=
Status
::
Success
;
updated_nodes
.
emplace
(
it
->
second
.
get
());
}
}
// added: add to dependency_graph_; if in missing_nodes_, invalidate (uncheck)
// and associate all downstreams, then remove from missing_nodes_
for
(
const
auto
&
model_name
:
added
)
{
std
::
unique_ptr
<
DependencyNode
>
added_node
;
auto
it
=
missing_nodes_
.
find
(
model_name
);
if
(
it
!=
missing_nodes_
.
end
())
{
UncheckDownstream
(
&
it
->
second
->
downstreams_
,
&
affected_nodes
);
// remove this node from missing upstream node in its downstream nodes
for
(
auto
&
downstream
:
it
->
second
->
downstreams_
)
{
downstream
->
missing_upstreams_
.
erase
(
it
->
second
.
get
());
}
it
->
second
->
checked_
=
false
;
added_node
=
std
::
move
(
it
->
second
);
missing_nodes_
.
erase
(
it
);
}
else
{
// Right now, nothing is going to be filled until validation
added_node
.
reset
(
new
DependencyNode
(
model_name
));
}
ModelInfo
*
info
=
nullptr
;
GetModelInfo
(
model_name
,
&
info
);
added_node
->
model_config_
=
info
->
model_config_
;
added_node
->
explicitly_load_
=
info
->
explicitly_load_
;
updated_nodes
.
emplace
(
added_node
.
get
());
dependency_graph_
.
emplace
(
std
::
make_pair
(
model_name
,
std
::
move
(
added_node
)));
}
auto
&
affected_ensembles
=
affected_nodes
;
for
(
auto
&
updated_node
:
updated_nodes
)
{
bool
is_ensemble
=
ConnectDependencyGraph
(
updated_node
);
if
(
is_ensemble
)
{
affected_ensembles
.
emplace
(
updated_node
);
}
}
#ifdef TRITON_ENABLE_ENSEMBLE
// After the dependency graph is updated, check ensemble dependencies
for
(
auto
&
ensemble
:
affected_ensembles
)
{
if
(
ensemble
->
status_
.
IsOk
())
{
if
(
!
ensemble
->
missing_upstreams_
.
empty
())
{
std
::
string
name_list
;
for
(
auto
it
=
ensemble
->
missing_upstreams_
.
begin
();
it
!=
ensemble
->
missing_upstreams_
.
end
();
it
++
)
{
if
(
it
!=
ensemble
->
missing_upstreams_
.
begin
())
{
name_list
+=
", "
;
}
name_list
+=
(
*
it
)
->
model_name_
;
}
ensemble
->
status_
=
Status
(
Status
::
Code
::
INVALID_ARG
,
"ensemble "
+
ensemble
->
model_name_
+
" contains models that are not available: "
+
name_list
);
}
else
{
ensemble
->
status_
=
CircularcyCheck
(
ensemble
,
ensemble
);
}
}
}
#endif // TRITON_ENABLE_ENSEMBLE
return
Status
::
Success
;
}
Status
ModelRepositoryManager
::
RegisterModelRepository
(
const
std
::
string
&
repository
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
model_mapping
)
{
if
(
!
model_control_enabled_
)
{
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"repository registration is not allowed if model control mode is not "
"EXPLICIT"
);
}
bool
is_directory
=
false
;
auto
status
=
IsDirectory
(
repository
,
&
is_directory
);
if
(
!
status
.
IsOk
()
||
!
is_directory
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
(
std
::
string
(
"failed to register '"
)
+
repository
+
"', repository not found"
)
.
c_str
());
}
{
// Serialize all operations that change model state
std
::
lock_guard
<
std
::
mutex
>
lock
(
poll_mu_
);
// Check repository and mapped models do not yet exist.
if
(
repository_paths_
.
find
(
repository
)
!=
repository_paths_
.
end
())
{
return
Status
(
Status
::
Code
::
ALREADY_EXISTS
,
"model repository '"
+
repository
+
"' has already been registered"
);
}
for
(
const
auto
&
mapping
:
model_mapping
)
{
if
(
model_mappings_
.
find
(
mapping
.
first
)
!=
model_mappings_
.
end
())
{
return
Status
(
Status
::
Code
::
ALREADY_EXISTS
,
(
std
::
string
(
"failed to register '"
)
+
mapping
.
first
+
"', there is a conflicting mapping for '"
+
std
::
string
(
mapping
.
first
)
+
"'"
)
.
c_str
());
}
}
repository_paths_
.
emplace
(
repository
);
for
(
const
auto
&
mapping
:
model_mapping
)
{
model_mappings_
.
emplace
(
mapping
.
first
,
std
::
make_pair
(
repository
,
JoinPath
({
repository
,
mapping
.
second
})));
}
}
LOG_INFO
<<
"Model repository registered: "
<<
repository
;
return
Status
::
Success
;
}
Status
ModelRepositoryManager
::
UnregisterModelRepository
(
const
std
::
string
&
repository
)
{
if
(
!
model_control_enabled_
)
{
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"repository unregistration is not allowed if model control mode is not "
"EXPLICIT"
);
}
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
poll_mu_
);
if
(
repository_paths_
.
erase
(
repository
)
!=
1
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"failed to unregister '"
+
repository
+
"', repository not found"
);
}
std
::
set
<
std
::
string
>
models_to_delete
;
for
(
auto
const
&
mapping
:
model_mappings_
)
{
if
(
mapping
.
second
.
first
==
repository
)
{
models_to_delete
.
insert
(
mapping
.
first
);
}
}
for
(
auto
const
&
model
:
models_to_delete
)
{
model_mappings_
.
erase
(
model
);
}
}
LOG_INFO
<<
"Model repository unregistered: "
<<
repository
;
return
Status
::
Success
;
}
Status
ModelRepositoryManager
::
CircularcyCheck
(
DependencyNode
*
current_node
,
const
DependencyNode
*
start_node
)
{
for
(
auto
&
downstream
:
current_node
->
downstreams_
)
{
if
(
downstream
->
model_name_
==
start_node
->
model_name_
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"circular dependency between ensembles: "
+
start_node
->
model_name_
+
" -> ... -> "
+
current_node
->
model_name_
+
" -> "
+
start_node
->
model_name_
);
}
else
{
const
auto
status
=
CircularcyCheck
(
downstream
,
start_node
);
if
(
!
status
.
IsOk
()
&&
current_node
->
status_
.
IsOk
())
{
current_node
->
status_
=
status
;
return
status
;
}
}
}
return
Status
::
Success
;
}
void
ModelRepositoryManager
::
UncheckDownstream
(
NodeSet
*
downstreams
,
NodeSet
*
updated_nodes
)
{
// Mark downstream nodes as unchecked recursively
for
(
auto
&
node
:
*
downstreams
)
{
if
(
node
->
checked_
)
{
node
->
checked_
=
false
;
node
->
status_
=
Status
::
Success
;
UncheckDownstream
(
&
node
->
downstreams_
,
updated_nodes
);
updated_nodes
->
emplace
(
node
);
}
}
}
bool
ModelRepositoryManager
::
ConnectDependencyGraph
(
DependencyNode
*
updated_node
)
{
// Check the node's model config to determine if it depends on other models
// and if those models are present
updated_node
->
upstreams_
.
clear
();
updated_node
->
missing_upstreams_
.
clear
();
if
(
updated_node
->
model_config_
.
has_ensemble_scheduling
())
{
for
(
const
auto
&
step
:
updated_node
->
model_config_
.
ensemble_scheduling
().
step
())
{
DependencyNode
*
upstream_node
=
nullptr
;
const
auto
&
model_name
=
step
.
model_name
();
auto
dit
=
dependency_graph_
.
find
(
model_name
);
if
(
dit
==
dependency_graph_
.
end
())
{
auto
mit
=
missing_nodes_
.
find
(
model_name
);
if
(
mit
==
missing_nodes_
.
end
())
{
std
::
unique_ptr
<
DependencyNode
>
node
(
new
DependencyNode
(
model_name
));
updated_node
->
missing_upstreams_
.
emplace
(
node
.
get
());
mit
=
missing_nodes_
.
emplace
(
model_name
,
std
::
move
(
node
)).
first
;
}
// Add the node to missing node's downstream so that when the missing
// node is added, the downstreams can be found easily.
mit
->
second
->
downstreams_
.
emplace
(
updated_node
);
upstream_node
=
mit
->
second
.
get
();
}
else
{
dit
->
second
->
downstreams_
.
emplace
(
updated_node
);
upstream_node
=
dit
->
second
.
get
();
}
auto
res
=
updated_node
->
upstreams_
.
emplace
(
upstream_node
,
std
::
set
<
int64_t
>
({
step
.
model_version
()}));
// If map insertion doesn't happen, the same model is required in a
// different step; insert the version into the existing required version set.
if
(
!
res
.
second
)
{
res
.
first
->
second
.
insert
(
step
.
model_version
());
}
}
return
true
;
}
return
false
;
}
Status
ModelRepositoryManager
::
GetModelInfo
(
const
std
::
string
&
name
,
ModelInfo
**
model_info
)
{
const
auto
itr
=
infos_
.
find
(
name
);
if
(
itr
==
infos_
.
end
())
{
return
Status
(
Status
::
Code
::
NOT_FOUND
,
"no configuration for model '"
+
name
+
"'"
);
}
*
model_info
=
itr
->
second
.
get
();
return
Status
::
Success
;
}
std
::
pair
<
ModelRepositoryManager
::
NodeSet
,
ModelRepositoryManager
::
NodeSet
>
ModelRepositoryManager
::
ModelsToLoadUnload
(
const
NodeSet
&
loaded_models
)
{
// <valid model set, invalid model set>
std
::
pair
<
NodeSet
,
NodeSet
>
res
;
// first call to this function
if
(
loaded_models
.
empty
())
{
for
(
auto
&
pair
:
dependency_graph_
)
{
auto
node
=
pair
.
second
.
get
();
// only care about nodes that are affected by the update
if
(
!
node
->
checked_
)
{
if
(
CheckNode
(
node
))
{
if
(
node
->
status_
.
IsOk
())
{
res
.
first
.
emplace
(
node
);
}
else
{
res
.
second
.
emplace
(
node
);
}
}
}
}
}
else
{
for
(
const
auto
&
model
:
loaded_models
)
{
for
(
auto
node
:
model
->
downstreams_
)
{
// only care about nodes that are affected by the update
if
(
!
node
->
checked_
)
{
if
(
CheckNode
(
node
))
{
if
(
node
->
status_
.
IsOk
())
{
res
.
first
.
emplace
(
node
);
}
else
{
res
.
second
.
emplace
(
node
);
}
}
}
}
}
}
for
(
auto
&
node
:
res
.
first
)
{
node
->
checked_
=
true
;
}
for
(
auto
&
node
:
res
.
second
)
{
node
->
checked_
=
true
;
}
return
res
;
}
bool
ModelRepositoryManager
::
CheckNode
(
DependencyNode
*
node
)
{
bool
node_ready
=
true
;
// If the node is in an invalid state, mark it as ready since we know
// it should not be loaded
if
(
node
->
status_
.
IsOk
())
{
for
(
auto
&
upstream
:
node
->
upstreams_
)
{
if
(
!
upstream
.
first
->
checked_
)
{
node_ready
=
false
;
break
;
}
if
(
!
upstream
.
first
->
status_
.
IsOk
())
{
node
->
status_
=
Status
(
Status
::
Code
::
INVALID_ARG
,
"ensemble '"
+
node
->
model_name_
+
"' depends on '"
+
upstream
.
first
->
model_name_
+
"' which is not valid"
);
}
else
if
(
upstream
.
first
->
loaded_versions_
.
empty
())
{
node
->
status_
=
Status
(
Status
::
Code
::
INVALID_ARG
,
"ensemble '"
+
node
->
model_name_
+
"' depends on '"
+
upstream
.
first
->
model_name_
+
"' which has no loaded version"
);
}
else
{
for
(
const
auto
&
required_version
:
upstream
.
second
)
{
if
(
required_version
==
-
1
)
{
continue
;
}
auto
it
=
upstream
.
first
->
loaded_versions_
.
find
(
required_version
);
if
(
it
==
upstream
.
first
->
loaded_versions_
.
end
())
{
node
->
status_
=
Status
(
Status
::
Code
::
INVALID_ARG
,
"ensemble '"
+
node
->
model_name_
+
"' depends on '"
+
upstream
.
first
->
model_name_
+
"' whose required version "
+
std
::
to_string
(
required_version
)
+
" is not loaded"
);
}
}
}
if
(
!
node
->
status_
.
IsOk
())
{
break
;
}
}
#ifdef TRITON_ENABLE_ENSEMBLE
// Validate the ensemble config if the node is ready. By this point, the
// models it depends on are loaded and their configs are completed
if
(
node_ready
&&
node
->
status_
.
IsOk
())
{
node
->
status_
=
ValidateEnsembleConfig
(
this
,
node
);
}
#endif // TRITON_ENABLE_ENSEMBLE
}
return
node_ready
;
}
}}
// namespace triton::core
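For orientation, the load-ordering loop in LoadModelByDependency() above repeatedly asks ModelsToLoadUnload() for the next batch of nodes whose upstreams have been resolved, which behaves like a level-by-level traversal of the dependency graph. The following standalone sketch is illustrative only; the function name, the simplified adjacency-map graph type, and the absence of per-node status are assumptions and are not from the file above.

// Illustrative sketch of level-by-level loading over a dependency graph.
// 'deps[m]' lists the models that 'm' depends on (its upstreams).
#include <map>
#include <set>
#include <string>
#include <vector>

std::vector<std::set<std::string>>
LoadLevels(const std::map<std::string, std::set<std::string>>& deps)
{
  std::vector<std::set<std::string>> levels;
  std::set<std::string> loaded;
  while (loaded.size() < deps.size()) {
    std::set<std::string> ready;
    for (const auto& pr : deps) {
      if (loaded.count(pr.first)) {
        continue;  // already handled in an earlier level
      }
      bool upstreams_done = true;
      for (const auto& up : pr.second) {
        if (!loaded.count(up)) {
          upstreams_done = false;
          break;
        }
      }
      if (upstreams_done) {
        ready.insert(pr.first);
      }
    }
    if (ready.empty()) {
      break;  // remaining models form a cycle or miss an upstream
    }
    loaded.insert(ready.begin(), ready.end());
    levels.push_back(std::move(ready));
  }
  return levels;
}

In the real manager the same idea is expressed through DependencyNode::checked_ and loaded_versions_, and invalid nodes (for example, ensembles whose required versions are not loaded) are routed to AsyncUnload instead of AsyncLoad.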