OpenDAS / Lmdeploy · Commits · b30f3cdb

Commit b30f3cdb, authored Nov 14, 2023 by xiabo
Commit message: Add the downloaded code (original: 添加下载的代码)
Parent: e38ee081
Changes: 157 files. Showing 20 changed files with 9042 additions and 0 deletions (+9042 / −0).
Changed files (additions / deletions):
  3rdparty/core-r22.12/src/ensemble_model.cc       +67    −0
  3rdparty/core-r22.12/src/ensemble_model.h        +60    −0
  3rdparty/core-r22.12/src/ensemble_scheduler.cc   +1390  −0
  3rdparty/core-r22.12/src/ensemble_scheduler.h    +123   −0
  3rdparty/core-r22.12/src/ensemble_utils.cc       +370   −0
  3rdparty/core-r22.12/src/ensemble_utils.h        +50    −0
  3rdparty/core-r22.12/src/filesystem.cc           +2662  −0
  3rdparty/core-r22.12/src/filesystem.h            +224   −0
  3rdparty/core-r22.12/src/infer_parameter.cc      +61    −0
  3rdparty/core-r22.12/src/infer_parameter.h       +102   −0
  3rdparty/core-r22.12/src/infer_request.cc        +1498  −0
  3rdparty/core-r22.12/src/infer_request.h         +800   −0
  3rdparty/core-r22.12/src/infer_response.cc       +431   −0
  3rdparty/core-r22.12/src/infer_response.h        +351   −0
  3rdparty/core-r22.12/src/infer_stats.cc          +241   −0
  3rdparty/core-r22.12/src/infer_stats.h           +190   −0
  3rdparty/core-r22.12/src/infer_trace.cc          +61    −0
  3rdparty/core-r22.12/src/infer_trace.h           +205   −0
  3rdparty/core-r22.12/src/instance_queue.cc       +99    −0
  3rdparty/core-r22.12/src/instance_queue.h        +57    −0
Too many changes to show: to preserve performance, only 157 of 157+ files are displayed.
3rdparty/core-r22.12/src/ensemble_model.cc (new file, mode 0 → 100644)
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ensemble_model.h"
#include <stdint.h>
#include "constants.h"
#include "ensemble_scheduler.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
namespace triton { namespace core {

Status
EnsembleModel::Create(
    InferenceServer* server, const std::string& path, const int64_t version,
    const inference::ModelConfig& model_config, const bool is_config_provided,
    const double min_compute_capability, std::unique_ptr<Model>* model)
{
  // Create the ensemble model.
  std::unique_ptr<EnsembleModel> local_model(
      new EnsembleModel(min_compute_capability, path, version, model_config));

  RETURN_IF_ERROR(local_model->Init(is_config_provided));

  std::unique_ptr<Scheduler> scheduler;
  RETURN_IF_ERROR(EnsembleScheduler::Create(
      local_model->MutableStatsAggregator(), server, model_config,
      &scheduler));
  RETURN_IF_ERROR(local_model->SetScheduler(std::move(scheduler)));

  LOG_VERBOSE(1) << "ensemble model for " << local_model->Name() << std::endl;

  *model = std::move(local_model);
  return Status::Success;
}

std::ostream&
operator<<(std::ostream& out, const EnsembleModel& pb)
{
  out << "name=" << pb.Name() << std::endl;
  return out;
}

}}  // namespace triton::core
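For orientation, here is a minimal sketch of how the factory above could be invoked by surrounding model-loading code. The helper name LoadEnsemble, the compute-capability value, and the assumption that a parsed config is already at hand are illustrative only and are not part of this commit; only the EnsembleModel::Create signature is taken from the file above.

// Hypothetical usage sketch (not part of this commit).
Status
LoadEnsemble(
    InferenceServer* server, const std::string& model_path,
    const int64_t version, const inference::ModelConfig& model_config,
    std::unique_ptr<Model>* model)
{
  // Assumed default: ensembles run on CPU, so no minimum CUDA compute
  // capability is required.
  constexpr double kMinComputeCapability = 0.0;
  // Create() initializes the model from the provided config and attaches an
  // EnsembleScheduler (see EnsembleScheduler::Create later in this commit).
  RETURN_IF_ERROR(EnsembleModel::Create(
      server, model_path, version, model_config,
      true /* is_config_provided */, kMinComputeCapability, model));
  return Status::Success;
}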
3rdparty/core-r22.12/src/ensemble_model.h (new file, mode 0 → 100644)
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "model.h"
#include "model_config.pb.h"
#include "scheduler.h"
#include "status.h"
namespace triton { namespace core {

class InferenceServer;

class EnsembleModel : public Model {
 public:
  EnsembleModel(EnsembleModel&&) = default;

  static Status Create(
      InferenceServer* server, const std::string& path, const int64_t version,
      const inference::ModelConfig& model_config,
      const bool is_config_provided, const double min_compute_capability,
      std::unique_ptr<Model>* model);

 private:
  DISALLOW_COPY_AND_ASSIGN(EnsembleModel);

  explicit EnsembleModel(
      const double min_compute_capability, const std::string& model_dir,
      const int64_t version, const inference::ModelConfig& config)
      : Model(min_compute_capability, model_dir, version, config)
  {
  }

  friend std::ostream& operator<<(std::ostream&, const EnsembleModel&);
};

std::ostream& operator<<(std::ostream& out, const EnsembleModel& pb);

}}  // namespace triton::core
3rdparty/core-r22.12/src/ensemble_scheduler.cc (new file, mode 0 → 100644)
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_scheduler.h"
#include <mutex>
#include "cuda_utils.h"
#include "metrics.h"
#include "model.h"
#include "model_config_utils.h"
#include "server.h"
#include "triton/common/logging.h"
namespace triton { namespace core {

namespace {

class EnsembleContext;

using IterationCount = size_t;

// Request tracker is passed as 'userp' in RequestRelease function and used
// to manage the lifecycle of the ensemble request
class RequestTracker {
 public:
  explicit RequestTracker(
      std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
      MetricModelReporter* metric_reporter,
      InferenceStatsAggregator* stats_aggregator)
      : inflight_request_counter_(1), request_(std::move(request)),
        compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
        stats_aggregator_(stats_aggregator), status_(Status::Success)
  {
  }

  std::unique_ptr<InferenceRequest>& Request() { return request_; }

  InferenceStatsAggregator& ContextStatsAggregator()
  {
    return context_stats_aggregator_;
  }

  void IncrementCounter()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    inflight_request_counter_++;
  }

  bool DecrementCounter()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    inflight_request_counter_--;
    if (inflight_request_counter_ == 0) {
#ifdef TRITON_ENABLE_STATS
      const auto& infer_stats =
          context_stats_aggregator_.ImmutableInferStats();
      request_->ReportStatisticsWithDuration(
          metric_reporter_, status_.IsOk(), compute_start_ns_,
          infer_stats.compute_input_duration_ns_,
          infer_stats.compute_infer_duration_ns_,
          infer_stats.compute_output_duration_ns_);
      if (status_.IsOk()) {
        stats_aggregator_->UpdateInferBatchStatsWithDuration(
            metric_reporter_, std::max(1U, request_->BatchSize()),
            infer_stats.compute_input_duration_ns_,
            infer_stats.compute_infer_duration_ns_,
            infer_stats.compute_output_duration_ns_);
      }
#endif
      InferenceRequest::Release(
          std::move(request_), TRITONSERVER_REQUEST_RELEASE_ALL);
    }
    return (inflight_request_counter_ == 0);
  }

  void SetStatus(const Status& status)
  {
    std::lock_guard<std::mutex> lk(mtx_);
    status_ = status;
  }

 private:
  std::mutex mtx_;
  uint32_t inflight_request_counter_;
  std::unique_ptr<InferenceRequest> request_;
  uint64_t compute_start_ns_;
  MetricModelReporter* metric_reporter_;
  InferenceStatsAggregator* stats_aggregator_;
  InferenceStatsAggregator context_stats_aggregator_;
  Status status_;
};
// Step is used as 'userp' and keeps ensemble context alive
// until no more internal requests are inflight.
// Step contains metadata, and status for the
// internal infer request
struct Step {
  Step(
      size_t step_idx, const InferenceRequest::SequenceId& correlation_id,
      uint32_t flags)
      : correlation_id_(correlation_id), flags_(flags), response_flags_(0),
        infer_status_(nullptr), step_idx_(step_idx)
  {
  }

  std::shared_ptr<EnsembleContext> ctx_;
  std::unique_ptr<InferenceRequest> request_;
  InferenceRequest::SequenceId correlation_id_;
  uint32_t flags_;

  std::mutex output_mtx_;
  // Different output map to avoid address conflict from different memory types
  std::unordered_map<uintptr_t, std::shared_ptr<AllocatedMemory>>
      cpu_output_map_;
  std::unordered_map<
      int64_t, std::unordered_map<uintptr_t, std::shared_ptr<AllocatedMemory>>>
      gpu_output_map_;
  std::set<std::pair<std::string, IterationCount>> updated_tensors_;
  uint32_t response_flags_;
  TRITONSERVER_Error* infer_status_;

  size_t step_idx_;
};
struct TensorData {
  struct Metadata {
    Metadata() = default;
    Metadata(
        std::unique_ptr<InferenceRequest::Input>&& data,
        size_t reference_count)
        : data_(std::move(data)), remaining_reference_count_(reference_count),
          parameter_override_(false)
    {
    }
    Metadata(
        std::unique_ptr<InferenceRequest::Input>&& data,
        size_t reference_count,
        const InferenceRequest::SequenceId& correlation_id, uint32_t flags)
        : data_(std::move(data)), remaining_reference_count_(reference_count),
          parameter_override_(true), correlation_id_(correlation_id),
          flags_(flags)
    {
    }

    std::unique_ptr<InferenceRequest::Input> data_;
    size_t remaining_reference_count_;
    bool parameter_override_;
    InferenceRequest::SequenceId correlation_id_;
    uint32_t flags_;
  };

  TensorData() = default;
  TensorData(const size_t outgoing_steps_count)
      : current_iteration_(0), outgoing_steps_count_(outgoing_steps_count),
        batch_size_(0)
  {
  }

  IterationCount AddTensor(std::unique_ptr<InferenceRequest::Input>&& tensor)
  {
    tensor_.emplace(
        current_iteration_,
        Metadata(std::move(tensor), outgoing_steps_count_));
    return current_iteration_++;
  }

  IterationCount AddTensor(
      std::unique_ptr<InferenceRequest::Input>&& tensor,
      const InferenceRequest::SequenceId& correlation_id, uint32_t flags)
  {
    tensor_.emplace(
        current_iteration_,
        Metadata(
            std::move(tensor), outgoing_steps_count_, correlation_id, flags));
    return current_iteration_++;
  }

  // Tensors associated with the particular ensemble tensor.
  // A container is used to handle the decoupled case
  // where variable number of tensors will be produced.
  // map 'iteration count' to pair of <tensor, remaining outgoing count>
  std::unordered_map<IterationCount, Metadata> tensor_;
  size_t current_iteration_;
  size_t outgoing_steps_count_;
  // Ensemble may be configured to passing tensor between batching model and
  // non-batching model as long as the full shapes match and storing the batch
  // size of the generated tensor explicitly for checking and setting proper
  // shape for the downstream model request.
  size_t batch_size_;
};
// EnsembleContext maintains the state of the ensemble request
//
// Using static functions to take advantage of shared_ptr, a copy of the
// shared_ptr will be made when a step is scheduled and it will go out of
// scope after the step's callback is finished. The step's callback will
// schedule new steps if available and the last step will finish the ensemble
// request.
// So we don't have to maintain the context in scheduler as the shared_ptr
// will destroy the context for us if there are no "in-flight" steps.
class
EnsembleContext
{
public:
EnsembleContext
(
MetricModelReporter
*
metric_reporter
,
InferenceStatsAggregator
*
stats_aggregator
,
InferenceServer
*
is
,
EnsembleInfo
*
info
,
std
::
unique_ptr
<
InferenceRequest
>&
request
,
cudaStream_t
stream
);
// Perform transition on 'context' state given the information of
// 'completed_step'
static
void
Proceed
(
const
std
::
shared_ptr
<
EnsembleContext
>&
context
,
const
std
::
unique_ptr
<
Step
>&
completed_step
=
nullptr
);
private:
static
TRITONSERVER_Error
*
ResponseAlloc
(
TRITONSERVER_ResponseAllocator
*
allocator
,
const
char
*
tensor_name
,
size_t
byte_size
,
TRITONSERVER_MemoryType
preferred_memory_type
,
int64_t
preferred_memory_type_id
,
void
*
userp
,
void
**
buffer
,
void
**
buffer_userp
,
TRITONSERVER_MemoryType
*
allocated_memory_type
,
int64_t
*
allocated_memory_type_id
);
static
TRITONSERVER_Error
*
ResponseRelease
(
TRITONSERVER_ResponseAllocator
*
allocator
,
void
*
buffer
,
void
*
buffer_userp
,
size_t
byte_size
,
TRITONSERVER_MemoryType
memory_type
,
int64_t
memory_type_id
);
static
TRITONSERVER_Error
*
OutputBufferQuery
(
TRITONSERVER_ResponseAllocator
*
allocator
,
void
*
userp
,
const
char
*
tensor_name
,
size_t
*
byte_size
,
TRITONSERVER_MemoryType
*
memory_type
,
int64_t
*
memory_type_id
);
static
void
RequestComplete
(
TRITONSERVER_InferenceRequest
*
request
,
const
uint32_t
flags
,
void
*
userp
);
static
void
ResponseComplete
(
TRITONSERVER_InferenceResponse
*
response
,
const
uint32_t
flags
,
void
*
userp
);
using
StepList
=
std
::
vector
<
std
::
unique_ptr
<
Step
>>
;
using
VersionMap
=
std
::
unordered_map
<
int64_t
,
std
::
shared_ptr
<
Model
>>
;
// Helper function to reshape the given tensor according to the
// config shape and batching info and its actual shape and batching info.
// Note that 'dims' will be in full shape as opposed to 'config_dims'.
// Return the dims after reshape.
std
::
vector
<
int64_t
>
ReshapeTensorDims
(
const
triton
::
common
::
DimsList
&
config_dims
,
const
bool
config_allow_batching
,
const
size_t
tensor_batch_size
,
const
std
::
vector
<
int64_t
>&
dims
);
// Return the list of step that becomes ready due to tensor update
// from 'completed_step'
Status
PrepareSteps
(
const
std
::
unique_ptr
<
Step
>&
completed_step
,
StepList
*
steps
);
// Prepare infer stats and call the inference server's function to process
// the infer requests specified in 'steps'
static
void
ScheduleSteps
(
const
std
::
shared_ptr
<
EnsembleContext
>&
context
,
StepList
&&
steps
);
// Helper function that updates ensemble state given 'completed_step' and
// returns the list of updated tensors in 'updated_tensors'
Status
UpdateEnsembleState
(
const
std
::
unique_ptr
<
Step
>&
completed_step
,
std
::
set
<
std
::
pair
<
std
::
string
,
IterationCount
>>*
updated_tensors
);
// Helper function that returns a list of 'steps' that should be run under
// current ensemble state. 'updated_tensors' is used so that we don't need to
// iterate all the tensors to determine which step can be run.
Status
GetNextSteps
(
const
std
::
set
<
std
::
pair
<
std
::
string
,
IterationCount
>>&
updated_tensors
,
StepList
*
steps
);
// Helper function that completes the response of the ensemble request
Status
FinishEnsemble
(
std
::
unique_ptr
<
InferenceResponse
>&&
response
=
nullptr
);
// Helper function that initialize the 'step' given the info at 'step_idx'.
// The 'step' will have proper request / response provider for the model
Status
InitStep
(
const
size_t
step_idx
,
const
IterationCount
iteration_count
,
std
::
unique_ptr
<
Step
>*
step
);
// Helper function that set the output of the ensemble request if it is ready
// and valid.
Status
CheckAndSetEnsembleOutput
(
const
std
::
set
<
std
::
pair
<
std
::
string
,
IterationCount
>>&
updated_tensors
,
std
::
unique_ptr
<
InferenceResponse
>*
response
);
InferenceServer
*
is_
;
EnsembleInfo
*
info_
;
// All EnsembleContext will use the same CUDA stream managed by
// the ensemble scheduler
cudaStream_t
stream_
;
// Mutex to avoid concurrent call on 'PrepareSteps' where ensemble state
// are being modified
std
::
mutex
mutex_
;
size_t
inflight_step_counter_
;
// pointer that either points to 'pruned_tensor_to_step_' or to
// 'info_->tensor_to_step_' if all ensemble outputs are requested
std
::
unordered_map
<
std
::
string
,
std
::
set
<
size_t
>>*
tensor_to_step_
;
std
::
unordered_map
<
std
::
string
,
std
::
set
<
size_t
>>
pruned_tensor_to_step_
;
std
::
unordered_map
<
std
::
string
,
TensorData
>
tensor_data_
;
// Handle to all models that may be used in the ensemble
std
::
unordered_map
<
std
::
string
,
VersionMap
>
handles_
;
// Request specific information that obtained from ensemble request and
// should be applied to all internal requests
uint32_t
flags_
;
std
::
string
request_id_
;
InferenceRequest
::
SequenceId
correlation_id_
;
uint32_t
priority_
;
uint64_t
timeout_
;
// Objects related to the ensemble infer request
Status
ensemble_status_
;
RequestTracker
*
request_tracker_
;
// The allocator that will be used to allocate buffers for the
// inference result tensors.
std
::
unique_ptr
<
TRITONSERVER_ResponseAllocator
,
decltype
(
&
TRITONSERVER_ResponseAllocatorDelete
)
>
allocator_
;
};
EnsembleContext
::
EnsembleContext
(
MetricModelReporter
*
metric_reporter
,
InferenceStatsAggregator
*
stats_aggregator
,
InferenceServer
*
is
,
EnsembleInfo
*
info
,
std
::
unique_ptr
<
InferenceRequest
>&
request
,
cudaStream_t
stream
)
:
is_
(
is
),
info_
(
info
),
stream_
(
stream
),
inflight_step_counter_
(
0
),
allocator_
(
nullptr
,
TRITONSERVER_ResponseAllocatorDelete
)
{
uint64_t
compute_start_ns
=
0
;
INFER_STATS_SET_TIMESTAMP
(
compute_start_ns
);
request_tracker_
=
new
RequestTracker
(
std
::
move
(
request
),
compute_start_ns
,
metric_reporter
,
stats_aggregator
);
auto
&
lrequest
=
request_tracker_
->
Request
();
// Obtain model handles of all models in ensemble request such that
// they have the same lifetime as the ensemble request to avoid unloading
// while the ensemble is executing.
for
(
const
auto
&
step_info
:
info_
->
steps_
)
{
auto
it
=
handles_
.
find
(
step_info
.
model_name_
);
if
(
it
==
handles_
.
end
())
{
it
=
handles_
.
emplace
(
std
::
make_pair
(
step_info
.
model_name_
,
VersionMap
()))
.
first
;
}
auto
ver_it
=
it
->
second
.
find
(
step_info
.
model_version_
);
if
(
ver_it
==
it
->
second
.
end
())
{
std
::
shared_ptr
<
Model
>
model
=
nullptr
;
ensemble_status_
=
is_
->
GetModel
(
step_info
.
model_name_
,
step_info
.
model_version_
,
&
model
);
if
(
!
ensemble_status_
.
IsOk
())
{
break
;
}
it
->
second
.
emplace
(
std
::
make_pair
(
step_info
.
model_version_
,
model
));
}
}
// Prune ensemble first if not all outputs are requested
std
::
set
<
std
::
string
>
ignored_tensor
;
for
(
const
auto
&
ensemble_output
:
info_
->
ensemble_output_shape_
)
{
ignored_tensor
.
insert
(
ensemble_output
.
first
);
}
for
(
const
auto
&
requested_output
:
lrequest
->
ImmutableRequestedOutputs
())
{
ignored_tensor
.
erase
(
requested_output
);
}
if
(
ignored_tensor
.
empty
())
{
tensor_to_step_
=
&
(
info_
->
tensor_to_step_
);
}
else
{
pruned_tensor_to_step_
=
info_
->
tensor_to_step_
;
tensor_to_step_
=
&
pruned_tensor_to_step_
;
// Backward traversal
std
::
unordered_map
<
size_t
,
size_t
>
step_requested_output_count
;
while
(
!
ignored_tensor
.
empty
())
{
std
::
set
<
std
::
string
>
new_ignored_tensor
;
for
(
const
auto
&
output
:
ignored_tensor
)
{
auto
step_idx
=
info_
->
tensor_to_prev_step_
[
output
];
auto
&
step
=
info_
->
steps_
[
step_idx
];
auto
it
=
step_requested_output_count
.
find
(
step_idx
);
if
(
it
==
step_requested_output_count
.
end
())
{
auto
output_count
=
step
.
output_to_tensor_
.
size
();
it
=
step_requested_output_count
.
emplace
(
step_idx
,
output_count
).
first
;
}
// If none of the outputs of the step is requested,
// then the step can be pruned
if
(
--
it
->
second
==
0
)
{
for
(
const
auto
&
input
:
step
.
input_to_tensor_
)
{
auto
&
step_set
=
pruned_tensor_to_step_
[
input
.
second
];
step_set
.
erase
(
step_idx
);
// If all steps depend on a tensor are pruned,
// then the tensor can be ignored.
if
(
step_set
.
empty
())
{
new_ignored_tensor
.
insert
(
input
.
second
);
}
}
}
}
ignored_tensor
.
swap
(
new_ignored_tensor
);
}
}
for
(
const
auto
&
pair
:
*
tensor_to_step_
)
{
const
auto
&
requested_outputs
=
lrequest
->
ImmutableRequestedOutputs
();
// For requested outputs, add 1 to outgoing count as the ensemble itself
// isn't counted as step.
if
(
requested_outputs
.
find
(
pair
.
first
)
!=
requested_outputs
.
end
())
{
tensor_data_
.
emplace
(
pair
.
first
,
TensorData
(
pair
.
second
.
size
()
+
1
));
}
else
{
tensor_data_
.
emplace
(
pair
.
first
,
TensorData
(
pair
.
second
.
size
()));
}
}
if
(
ensemble_status_
.
IsOk
())
{
request_id_
=
lrequest
->
Id
();
correlation_id_
=
lrequest
->
CorrelationId
();
flags_
=
lrequest
->
Flags
();
priority_
=
lrequest
->
Priority
();
timeout_
=
lrequest
->
TimeoutMicroseconds
();
for
(
const
auto
&
pr
:
lrequest
->
ImmutableInputs
())
{
const
InferenceRequest
::
Input
*
input
=
pr
.
second
;
auto
it
=
tensor_data_
.
find
(
input
->
Name
());
if
(
it
!=
tensor_data_
.
end
())
{
auto
&
tensor_data
=
it
->
second
;
// Shape() represents reshaped value without batch dimension,
// thus need to fill it if necessary.
std
::
unique_ptr
<
InferenceRequest
::
Input
>
tensor
;
if
(
lrequest
->
BatchSize
()
!=
0
)
{
std
::
vector
<
int64_t
>
shape
{
lrequest
->
BatchSize
()};
shape
.
insert
(
shape
.
end
(),
input
->
Shape
().
begin
(),
input
->
Shape
().
end
());
tensor
.
reset
(
new
InferenceRequest
::
Input
(
input
->
Name
(),
input
->
DType
(),
shape
));
}
else
{
tensor
.
reset
(
new
InferenceRequest
::
Input
(
input
->
Name
(),
input
->
DType
(),
input
->
Shape
()));
}
tensor
->
SetData
(
input
->
Data
());
for
(
const
auto
&
host_policy_data
:
input
->
HostPolicyData
())
{
tensor
->
SetData
(
host_policy_data
.
first
,
host_policy_data
.
second
);
}
tensor_data
.
AddTensor
(
std
::
move
(
tensor
));
tensor_data
.
batch_size_
=
lrequest
->
BatchSize
();
}
else
{
ensemble_status_
=
Status
(
Status
::
Code
::
INVALID_ARG
,
lrequest
->
LogRequest
()
+
"unexpected input '"
+
input
->
Name
()
+
"' in request header that does not map to any ensemble inputs"
);
}
}
// Iterate the ensemble optional inputs and add empty tensor data entry
// if the input is not provided
for
(
const
auto
&
name
:
info_
->
optional_inputs_
)
{
auto
it
=
tensor_data_
.
find
(
name
);
if
((
it
!=
tensor_data_
.
end
())
&&
it
->
second
.
tensor_
.
empty
())
{
it
->
second
.
AddTensor
(
nullptr
);
it
->
second
.
batch_size_
=
lrequest
->
BatchSize
();
}
}
}
TRITONSERVER_ResponseAllocator
*
allocator
;
TRITONSERVER_Error
*
err
=
TRITONSERVER_ResponseAllocatorNew
(
&
allocator
,
ResponseAlloc
,
ResponseRelease
,
nullptr
/* start_fn */
);
if
(
err
==
nullptr
)
{
err
=
TRITONSERVER_ResponseAllocatorSetQueryFunction
(
allocator
,
OutputBufferQuery
);
}
if
(
err
!=
nullptr
)
{
ensemble_status_
=
Status
(
TritonCodeToStatusCode
(
TRITONSERVER_ErrorCode
(
err
)),
TRITONSERVER_ErrorMessage
(
err
));
TRITONSERVER_ErrorDelete
(
err
);
}
else
{
allocator_
.
reset
(
allocator
);
}
}
TRITONSERVER_Error
*
EnsembleContext
::
ResponseAlloc
(
TRITONSERVER_ResponseAllocator
*
allocator
,
const
char
*
tensor_name
,
size_t
byte_size
,
TRITONSERVER_MemoryType
preferred_memory_type
,
int64_t
preferred_memory_type_id
,
void
*
userp
,
void
**
buffer
,
void
**
buffer_userp
,
TRITONSERVER_MemoryType
*
allocated_memory_type
,
int64_t
*
allocated_memory_type_id
)
{
*
buffer
=
nullptr
;
*
buffer_userp
=
nullptr
;
auto
allocated_buffer
=
std
::
make_shared
<
AllocatedMemory
>
(
byte_size
,
preferred_memory_type
,
preferred_memory_type_id
);
auto
mutable_buffer
=
allocated_buffer
->
MutableBuffer
(
allocated_memory_type
,
allocated_memory_type_id
);
if
((
mutable_buffer
!=
nullptr
)
||
(
byte_size
==
0
))
{
if
(
byte_size
!=
0
)
{
*
buffer
=
static_cast
<
void
*>
(
mutable_buffer
);
auto
step
=
reinterpret_cast
<
Step
*>
(
userp
);
std
::
lock_guard
<
std
::
mutex
>
lk
(
step
->
output_mtx_
);
if
(
*
allocated_memory_type
==
TRITONSERVER_MEMORY_GPU
)
{
step
->
gpu_output_map_
[
*
allocated_memory_type_id
].
emplace
(
reinterpret_cast
<
uintptr_t
>
(
*
buffer
),
std
::
move
(
allocated_buffer
));
}
else
{
step
->
cpu_output_map_
.
emplace
(
reinterpret_cast
<
uintptr_t
>
(
*
buffer
),
std
::
move
(
allocated_buffer
));
}
}
LOG_VERBOSE
(
1
)
<<
"Internal response allocation: "
<<
tensor_name
<<
", size "
<<
byte_size
<<
", addr "
<<
*
buffer
<<
", memory type "
<<
*
allocated_memory_type
<<
", type id "
<<
*
allocated_memory_type_id
;
}
return
nullptr
;
// Success
}
TRITONSERVER_Error
*
EnsembleContext
::
ResponseRelease
(
TRITONSERVER_ResponseAllocator
*
allocator
,
void
*
buffer
,
void
*
buffer_userp
,
size_t
byte_size
,
TRITONSERVER_MemoryType
memory_type
,
int64_t
memory_type_id
)
{
LOG_VERBOSE
(
1
)
<<
"Internal response release: "
<<
"size "
<<
byte_size
<<
", addr "
<<
buffer
;
// Don't do anything when releasing a buffer since ResponseAlloc
// passes the ownership of the data to ensemble context.
return
nullptr
;
// Success
}
TRITONSERVER_Error
*
EnsembleContext
::
OutputBufferQuery
(
TRITONSERVER_ResponseAllocator
*
allocator
,
void
*
userp
,
const
char
*
tensor_name
,
size_t
*
byte_size
,
TRITONSERVER_MemoryType
*
memory_type
,
int64_t
*
memory_type_id
)
{
// Ensemble will always attempt to satisfy any output buffer request
return
nullptr
;
// Success
}
void
EnsembleContext
::
RequestComplete
(
TRITONSERVER_InferenceRequest
*
request
,
const
uint32_t
flags
,
void
*
userp
)
{
if
((
flags
&
TRITONSERVER_REQUEST_RELEASE_ALL
)
!=
0
)
{
LOG_TRITONSERVER_ERROR
(
TRITONSERVER_InferenceRequestDelete
(
request
),
"deleting ensemble inference request"
);
auto
request_tracker
=
reinterpret_cast
<
RequestTracker
*>
(
userp
);
if
(
request_tracker
->
DecrementCounter
())
{
delete
request_tracker
;
}
}
}
void
EnsembleContext
::
ResponseComplete
(
TRITONSERVER_InferenceResponse
*
response
,
const
uint32_t
flags
,
void
*
userp
)
{
auto
step_ptr
=
std
::
unique_ptr
<
Step
>
(
reinterpret_cast
<
Step
*>
(
userp
));
step_ptr
->
response_flags_
=
flags
;
if
(
response
!=
nullptr
)
{
auto
err
=
TRITONSERVER_InferenceResponseError
(
response
);
uint32_t
count
;
bool
parameter_override
=
false
;
InferenceRequest
::
SequenceId
correlation_id
{
0
};
uint32_t
flags
=
0
;
if
(
err
==
nullptr
)
{
err
=
TRITONSERVER_InferenceResponseParameterCount
(
response
,
&
count
);
if
(
err
==
nullptr
)
{
for
(
uint32_t
idx
=
0
;
idx
<
count
;
idx
++
)
{
const
char
*
name
;
TRITONSERVER_ParameterType
type
;
const
void
*
vvalue
;
err
=
TRITONSERVER_InferenceResponseParameter
(
response
,
idx
,
&
name
,
&
type
,
&
vvalue
);
if
(
err
==
nullptr
)
{
if
(
!
strcmp
(
name
,
"sequence_id"
))
{
switch
(
type
)
{
case
TRITONSERVER_PARAMETER_INT
:
correlation_id
=
InferenceRequest
::
SequenceId
(
*
reinterpret_cast
<
const
uint64_t
*>
(
vvalue
));
parameter_override
=
true
;
break
;
case
TRITONSERVER_PARAMETER_STRING
:
correlation_id
=
InferenceRequest
::
SequenceId
(
std
::
string
(
*
reinterpret_cast
<
const
char
*
const
*>
(
vvalue
)));
parameter_override
=
true
;
break
;
default:
err
=
TRITONSERVER_ErrorNew
(
TRITONSERVER_ERROR_INVALID_ARG
,
"expected parameter 'sequence_id' to be "
"TRITONSERVER_PARAMETER_INT or "
"TRITONSERVER_PARAMETER_STRING"
);
}
}
else
if
(
!
strcmp
(
name
,
"sequence_start"
))
{
if
(
type
!=
TRITONSERVER_PARAMETER_BOOL
)
{
err
=
TRITONSERVER_ErrorNew
(
TRITONSERVER_ERROR_INVALID_ARG
,
"expect paremeter 'sequence_start' to be "
"TRITONSERVER_PARAMETER_BOOL"
);
}
else
{
if
(
*
reinterpret_cast
<
const
bool
*>
(
vvalue
))
{
flags
|=
TRITONSERVER_REQUEST_FLAG_SEQUENCE_START
;
}
parameter_override
=
true
;
}
}
else
if
(
!
strcmp
(
name
,
"sequence_end"
))
{
if
(
type
!=
TRITONSERVER_PARAMETER_BOOL
)
{
err
=
TRITONSERVER_ErrorNew
(
TRITONSERVER_ERROR_INVALID_ARG
,
"expect paremeter 'sequence_end' to be "
"TRITONSERVER_PARAMETER_BOOL"
);
}
else
{
if
(
*
reinterpret_cast
<
const
bool
*>
(
vvalue
))
{
flags
|=
TRITONSERVER_REQUEST_FLAG_SEQUENCE_END
;
}
parameter_override
=
true
;
}
}
}
}
}
}
if
(
err
==
nullptr
)
{
err
=
TRITONSERVER_InferenceResponseOutputCount
(
response
,
&
count
);
if
(
err
==
nullptr
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
step_ptr
->
ctx_
->
mutex_
);
auto
&
output_to_tensor
=
step_ptr
->
ctx_
->
info_
->
steps_
[
step_ptr
->
step_idx_
]
.
output_to_tensor_
;
for
(
uint32_t
idx
=
0
;
idx
<
count
;
idx
++
)
{
const
char
*
name
;
TRITONSERVER_DataType
datatype
;
const
int64_t
*
shape
;
uint64_t
dim_count
;
const
void
*
base
;
size_t
byte_size
;
TRITONSERVER_MemoryType
memory_type
;
int64_t
memory_type_id
;
void
*
userp
;
err
=
TRITONSERVER_InferenceResponseOutput
(
response
,
idx
,
&
name
,
&
datatype
,
&
shape
,
&
dim_count
,
&
base
,
&
byte_size
,
&
memory_type
,
&
memory_type_id
,
&
userp
);
if
(
err
==
nullptr
)
{
auto
it
=
output_to_tensor
.
find
(
name
);
if
(
it
!=
output_to_tensor
.
end
())
{
std
::
unique_ptr
<
InferenceRequest
::
Input
>
tensor
(
new
InferenceRequest
::
Input
(
it
->
second
,
TritonToDataType
(
datatype
),
shape
,
dim_count
));
if
(
byte_size
!=
0
)
{
std
::
lock_guard
<
std
::
mutex
>
output_lk
(
step_ptr
->
output_mtx_
);
if
(
memory_type
==
TRITONSERVER_MEMORY_GPU
)
{
auto
&
gpu_output_map
=
step_ptr
->
gpu_output_map_
[
memory_type_id
];
auto
it
=
gpu_output_map
.
find
(
reinterpret_cast
<
uintptr_t
>
(
base
));
tensor
->
SetData
(
std
::
move
(
it
->
second
));
gpu_output_map
.
erase
(
it
);
}
else
{
auto
it
=
step_ptr
->
cpu_output_map_
.
find
(
reinterpret_cast
<
uintptr_t
>
(
base
));
tensor
->
SetData
(
std
::
move
(
it
->
second
));
step_ptr
->
cpu_output_map_
.
erase
(
it
);
}
}
auto
&
tensor_data
=
step_ptr
->
ctx_
->
tensor_data_
[
it
->
second
];
if
(
parameter_override
)
{
step_ptr
->
updated_tensors_
.
emplace
(
it
->
second
,
tensor_data
.
AddTensor
(
std
::
move
(
tensor
),
correlation_id
,
flags
));
}
else
{
step_ptr
->
updated_tensors_
.
emplace
(
it
->
second
,
tensor_data
.
AddTensor
(
std
::
move
(
tensor
),
step_ptr
->
correlation_id_
,
step_ptr
->
flags_
));
}
}
else
{
LOG_VERBOSE
(
1
)
<<
"in ensemble, an internal response header specified "
"output '"
<<
name
<<
"' that does not map to any ensemble tensors"
;
}
}
if
(
err
!=
nullptr
)
{
break
;
}
}
}
}
if
(
err
!=
nullptr
)
{
step_ptr
->
infer_status_
=
err
;
}
LOG_TRITONSERVER_ERROR
(
TRITONSERVER_InferenceResponseDelete
(
response
),
"deleting inference response"
);
}
EnsembleContext
::
Proceed
(
step_ptr
->
ctx_
,
step_ptr
);
// Expecting more responses
if
((
flags
&
TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
==
0
)
{
step_ptr
.
release
();
}
}
void
EnsembleContext
::
Proceed
(
const
std
::
shared_ptr
<
EnsembleContext
>&
context
,
const
std
::
unique_ptr
<
Step
>&
completed_step
)
{
StepList
ready_steps
;
Status
status
=
context
->
PrepareSteps
(
completed_step
,
&
ready_steps
);
if
(
status
.
IsOk
())
{
ScheduleSteps
(
context
,
std
::
move
(
ready_steps
));
}
}
Status
EnsembleContext
::
PrepareSteps
(
const
std
::
unique_ptr
<
Step
>&
completed_step
,
StepList
*
ready_steps
)
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
// Initialization error, ensemble status will be not ok since the beginning
if
(
completed_step
==
nullptr
&&
!
ensemble_status_
.
IsOk
())
{
ensemble_status_
=
FinishEnsemble
();
}
if
(
ensemble_status_
.
IsOk
())
{
StepList
res
;
std
::
set
<
std
::
pair
<
std
::
string
,
IterationCount
>>
updated_tensors
;
ensemble_status_
=
UpdateEnsembleState
(
completed_step
,
&
updated_tensors
);
if
(
ensemble_status_
.
IsOk
())
{
ensemble_status_
=
GetNextSteps
(
updated_tensors
,
ready_steps
);
}
// Check and send ensemble response
if
((
!
ensemble_status_
.
IsOk
())
||
(
inflight_step_counter_
==
0
)
||
info_
->
is_decoupled_
)
{
std
::
unique_ptr
<
InferenceResponse
>
response
;
if
(
ensemble_status_
.
IsOk
())
{
ensemble_status_
=
CheckAndSetEnsembleOutput
(
updated_tensors
,
&
response
);
}
ensemble_status_
=
FinishEnsemble
(
std
::
move
(
response
));
}
}
return
ensemble_status_
;
}
}
Status
EnsembleContext
::
UpdateEnsembleState
(
const
std
::
unique_ptr
<
Step
>&
completed_step
,
std
::
set
<
std
::
pair
<
std
::
string
,
IterationCount
>>*
updated_tensors
)
{
updated_tensors
->
clear
();
if
(
completed_step
==
nullptr
)
{
for
(
const
auto
&
tensor_data
:
tensor_data_
)
{
if
(
!
tensor_data
.
second
.
tensor_
.
empty
())
{
updated_tensors
->
emplace
(
tensor_data
.
first
,
0
);
}
}
}
else
{
if
(
completed_step
->
response_flags_
&
TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
{
inflight_step_counter_
--
;
}
RETURN_IF_TRITONSERVER_ERROR
(
completed_step
->
infer_status_
);
updated_tensors
->
swap
(
completed_step
->
updated_tensors_
);
}
return
Status
::
Success
;
}
Status
EnsembleContext
::
GetNextSteps
(
const
std
::
set
<
std
::
pair
<
std
::
string
,
IterationCount
>>&
updated_tensors
,
StepList
*
steps
)
{
steps
->
clear
();
std
::
set
<
std
::
pair
<
size_t
,
IterationCount
>>
next_step_idx
;
// Get steps whose tensors used for input are set
for
(
const
auto
updated_tensor
:
updated_tensors
)
{
const
auto
&
step_idx
=
(
*
tensor_to_step_
)[
updated_tensor
.
first
];
for
(
const
auto
&
idx
:
step_idx
)
{
bool
ready
=
true
;
for
(
const
auto
&
input_pair
:
info_
->
steps_
[
idx
].
input_to_tensor_
)
{
auto
&
tensor
=
tensor_data_
[
input_pair
.
second
].
tensor_
;
if
(
tensor
.
empty
())
{
ready
=
false
;
break
;
}
else
{
// Check if other inputs have tensor with corresponding iteration
// count
if
(
tensor
.
find
(
updated_tensor
.
second
)
==
tensor
.
end
())
{
ready
=
false
;
break
;
}
}
}
if
(
ready
)
{
next_step_idx
.
emplace
(
idx
,
updated_tensor
.
second
);
}
}
}
for
(
const
auto
&
idx
:
next_step_idx
)
{
steps
->
emplace_back
();
RETURN_IF_ERROR
(
InitStep
(
idx
.
first
,
idx
.
second
,
&
(
steps
->
back
())));
}
inflight_step_counter_
+=
steps
->
size
();
return
Status
::
Success
;
}
Status
EnsembleContext
::
InitStep
(
const
size_t
step_idx
,
const
IterationCount
iteration_count
,
std
::
unique_ptr
<
Step
>*
step
)
{
const
auto
&
istep
=
info_
->
steps_
[
step_idx
];
auto
&
version_map
=
handles_
[
istep
.
model_name_
];
auto
&
model
=
version_map
[
istep
.
model_version_
];
const
bool
allow_batching
=
(
model
->
Config
().
max_batch_size
()
>
0
);
auto
irequest
=
std
::
unique_ptr
<
InferenceRequest
>
(
new
InferenceRequest
(
model
,
istep
.
model_version_
));
// Store the pointers to tensors used so that we can prune them afterward.
// Can't prune the tensor in the input loop below as it may be used by
// multiple inputs in the same step.
std
::
map
<
TensorData
*
,
size_t
*>
releasing_tensors
;
// Set inputs in request, prepare input map,
// and set overridden parameter if any.
auto
correlation_id
=
correlation_id_
;
auto
flags
=
flags_
;
bool
parameter_set
=
false
;
for
(
const
auto
&
pair
:
istep
.
input_to_tensor_
)
{
auto
&
tensor_data
=
tensor_data_
[
pair
.
second
];
auto
&
tensor
=
tensor_data
.
tensor_
[
iteration_count
];
// nullptr if and only if the tensor is optional ensemble input and
// not provided in the ensemble request. In such case, we don't add
// the input and expect the ensemble pipeline is configured correctly
// (the input to the inner model is also optional)
if
(
tensor
.
data_
!=
nullptr
)
{
// If the actual shape and config shape agree with each other without
// considering batch size, non-batch / batch conversion are not required.
const
inference
::
ModelInput
*
input_config
;
model
->
GetInput
(
pair
.
first
,
&
input_config
);
auto
shape
=
ReshapeTensorDims
(
input_config
->
dims
(),
allow_batching
,
tensor_data
.
batch_size_
,
tensor
.
data_
->
OriginalShape
());
InferenceRequest
::
Input
*
input
;
RETURN_IF_ERROR
(
irequest
->
AddOriginalInput
(
pair
.
first
,
tensor
.
data_
->
DType
(),
shape
,
&
input
));
RETURN_IF_ERROR
(
input
->
SetData
(
tensor
.
data_
->
Data
()));
for
(
const
auto
&
host_policy_data
:
tensor
.
data_
->
HostPolicyData
())
{
RETURN_IF_ERROR
(
input
->
SetData
(
host_policy_data
.
first
,
host_policy_data
.
second
));
}
}
releasing_tensors
.
emplace
(
&
tensor_data
,
&
tensor
.
remaining_reference_count_
);
if
(
tensor
.
parameter_override_
)
{
if
(
parameter_set
&&
((
correlation_id
!=
tensor
.
correlation_id_
)
||
(
flags
!=
tensor
.
flags_
)))
{
LOG_ERROR
<<
irequest
->
LogRequest
()
<<
"Different set of response parameters are set for '"
<<
istep
.
model_name_
<<
"'. Parameter correlation ID "
<<
correlation_id
<<
", flags "
<<
flags
<<
" is used."
;
continue
;
}
correlation_id
=
tensor
.
correlation_id_
;
flags
=
tensor
.
flags_
;
parameter_set
=
true
;
}
}
// Prune the tensor if it is not needed by other steps
for
(
auto
&
releasing_pair
:
releasing_tensors
)
{
if
((
--
(
*
releasing_pair
.
second
))
==
0
)
{
releasing_pair
.
first
->
tensor_
.
erase
(
iteration_count
);
}
}
// Set requested outputs in request header
for
(
const
auto
&
pair
:
istep
.
output_to_tensor_
)
{
irequest
->
AddOriginalRequestedOutput
(
pair
.
first
);
}
step
->
reset
(
new
Step
(
step_idx
,
correlation_id
,
flags
));
irequest
->
SetId
(
request_id_
);
irequest
->
SetCorrelationId
(
correlation_id
);
irequest
->
SetFlags
(
flags
);
irequest
->
SetPriority
(
priority_
);
irequest
->
SetTimeoutMicroseconds
(
timeout_
);
#ifdef TRITON_ENABLE_STATS
irequest
->
SetSecondaryStatsAggregator
(
&
request_tracker_
->
ContextStatsAggregator
());
#endif
irequest
->
SetResponseCallback
(
reinterpret_cast
<
ResponseAllocator
*>
(
allocator_
.
get
()),
step
->
get
(),
ResponseComplete
,
step
->
get
());
irequest
->
SetReleaseCallback
(
RequestComplete
,
request_tracker_
);
RETURN_IF_ERROR
(
irequest
->
PrepareForInference
());
#ifdef TRITON_ENABLE_TRACING
auto
&
parent_trace
=
request_tracker_
->
Request
()
->
Trace
();
if
(
parent_trace
!=
nullptr
)
{
irequest
->
SetTrace
(
parent_trace
->
SpawnChildTrace
());
irequest
->
Trace
()
->
SetModelName
(
irequest
->
ModelName
());
irequest
->
Trace
()
->
SetModelVersion
(
irequest
->
ActualModelVersion
());
}
#endif
// Record the batch size of output in advance as
// there is no other way to access it later on.
for
(
const
auto
&
pair
:
istep
.
output_to_tensor_
)
{
auto
&
output_data_
=
tensor_data_
[
pair
.
second
];
output_data_
.
batch_size_
=
irequest
->
BatchSize
();
}
(
*
step
)
->
request_
=
std
::
move
(
irequest
);
return
Status
::
Success
;
}
std
::
vector
<
int64_t
>
EnsembleContext
::
ReshapeTensorDims
(
const
triton
::
common
::
DimsList
&
config_dims
,
const
bool
config_allow_batching
,
const
size_t
tensor_batch_size
,
const
std
::
vector
<
int64_t
>&
dims
)
{
bool
reshaped
=
false
;
std
::
vector
<
int64_t
>
res
;
// Only attempt to reshape if one setting is batchable while the other is not,
// the case of two mismatched batchable shapes is not considered.
// If the actual shape and config shape agree with each other without
// considering batch size, non-batch / batch conversion are not required.
if
(
config_allow_batching
!=
(
tensor_batch_size
!=
0
))
{
// expect batching but the tensor is generated from nobatching model
if
(
config_allow_batching
)
{
if
(
triton
::
common
::
CompareDimsWithWildcard
(
config_dims
,
dims
))
{
// If 'dims' already matches 'config_dims', prepend with batch size 1
res
.
push_back
(
1
);
res
.
insert
(
res
.
end
(),
dims
.
begin
(),
dims
.
end
());
reshaped
=
true
;
}
// Otherwise, assuming the tensor is already in the batch expected
// by the model and do nothing
}
else
{
// Check if the batched tensor can be sent to the non-batching
// model as one tensor. If not, strip the batch dimension if
// it is batch size 1
if
(
!
triton
::
common
::
CompareDimsWithWildcard
(
config_dims
,
dims
)
&&
(
tensor_batch_size
==
1
))
{
res
.
assign
(
dims
.
begin
()
+
1
,
dims
.
end
());
reshaped
=
true
;
}
}
}
if
(
!
reshaped
)
{
res
=
dims
;
}
return
res
;
}
Status
EnsembleContext
::
FinishEnsemble
(
std
::
unique_ptr
<
InferenceResponse
>&&
response
)
{
// Do nothing if the ensemble is finished
if
(
request_tracker_
==
nullptr
)
{
return
ensemble_status_
;
}
// Add ensemble name to make error message more trackable
if
(
!
ensemble_status_
.
IsOk
())
{
ensemble_status_
=
Status
(
ensemble_status_
.
StatusCode
(),
"in ensemble '"
+
info_
->
ensemble_name_
+
"', "
+
ensemble_status_
.
Message
());
}
if
(
ensemble_status_
.
IsOk
())
{
if
(
info_
->
is_decoupled_
)
{
if
(
response
!=
nullptr
)
{
InferenceResponse
::
Send
(
std
::
move
(
response
),
0
/* flags */
);
}
if
(
inflight_step_counter_
!=
0
)
{
return
ensemble_status_
;
}
request_tracker_
->
Request
()
->
ResponseFactory
()
->
SendFlags
(
TRITONSERVER_RESPONSE_COMPLETE_FINAL
);
}
else
{
InferenceResponse
::
Send
(
std
::
move
(
response
),
TRITONSERVER_RESPONSE_COMPLETE_FINAL
);
}
}
else
{
if
(
response
!=
nullptr
)
{
InferenceResponse
::
SendWithStatus
(
std
::
move
(
response
),
TRITONSERVER_RESPONSE_COMPLETE_FINAL
,
ensemble_status_
);
}
else
{
InferenceRequest
::
RespondIfError
(
request_tracker_
->
Request
(),
ensemble_status_
);
}
}
// Reach here when the ensemble execution comes to the end, 'ensemble_status_'
// at this point is representative.
request_tracker_
->
SetStatus
(
ensemble_status_
);
if
(
request_tracker_
->
DecrementCounter
())
{
delete
request_tracker_
;
}
request_tracker_
=
nullptr
;
return
ensemble_status_
;
}
Status
EnsembleContext
::
CheckAndSetEnsembleOutput
(
const
std
::
set
<
std
::
pair
<
std
::
string
,
IterationCount
>>&
updated_tensors
,
std
::
unique_ptr
<
InferenceResponse
>*
response
)
{
IterationCount
iteration_count
=
0
;
// Check if updated tensor is one of the ensemble output and if all outputs
// have tensor of the same iteration count
bool
ready
=
false
;
auto
&
lrequest
=
request_tracker_
->
Request
();
const
auto
&
requested_outputs
=
lrequest
->
ImmutableRequestedOutputs
();
for
(
const
auto
updated_tensor
:
updated_tensors
)
{
if
(
requested_outputs
.
find
(
updated_tensor
.
first
)
==
requested_outputs
.
end
())
{
continue
;
}
ready
=
true
;
iteration_count
=
updated_tensor
.
second
;
for
(
const
auto
&
output
:
requested_outputs
)
{
auto
&
tensor
=
tensor_data_
[
output
].
tensor_
;
if
(
tensor
.
empty
())
{
ready
=
false
;
break
;
}
else
{
// Check if other outputs have tensor with corresponding iteration count
if
(
tensor
.
find
(
iteration_count
)
==
tensor
.
end
())
{
ready
=
false
;
break
;
}
}
}
}
if
(
!
ready
)
{
if
(
info_
->
is_decoupled_
)
{
return
Status
::
Success
;
}
return
Status
(
Status
::
Code
::
INVALID_ARG
,
lrequest
->
LogRequest
()
+
"unexpected deadlock, at least one output is not set while no more "
"ensemble steps can be made"
);
}
RETURN_IF_ERROR
(
lrequest
->
ResponseFactory
()
->
CreateResponse
(
response
));
bool
cuda_async_copy
=
false
;
std
::
map
<
TensorData
*
,
size_t
*>
releasing_tensors
;
for
(
const
auto
&
output_pair
:
info_
->
ensemble_output_shape_
)
{
if
(
requested_outputs
.
find
(
output_pair
.
first
)
==
requested_outputs
.
end
())
{
continue
;
}
// Check if output is ready
auto
&
tensor_data
=
tensor_data_
[
output_pair
.
first
];
auto
&
tensor
=
tensor_data
.
tensor_
[
iteration_count
];
auto
shape
=
ReshapeTensorDims
(
output_pair
.
second
,
(
lrequest
->
BatchSize
()
!=
0
),
tensor_data
.
batch_size_
,
tensor
.
data_
->
OriginalShape
());
InferenceResponse
::
Output
*
output
;
RETURN_IF_ERROR
((
*
response
)
->
AddOutput
(
output_pair
.
first
,
tensor
.
data_
->
DType
(),
shape
,
&
output
));
// Use the memory type of the memory block as preferred memory type
TRITONSERVER_MemoryType
dst_memory_type
;
int64_t
dst_memory_type_id
;
size_t
content_size
;
tensor
.
data_
->
Data
()
->
BufferAt
(
0
,
&
content_size
,
&
dst_memory_type
,
&
dst_memory_type_id
);
void
*
buffer
;
RETURN_IF_ERROR
(
output
->
AllocateDataBuffer
(
&
buffer
,
content_size
,
&
dst_memory_type
,
&
dst_memory_type_id
));
// Done with this output if 'expected_byte_size' is 0
if
(
content_size
==
0
)
{
continue
;
}
else
if
(
buffer
==
nullptr
)
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to allocate buffer for output '"
+
output_pair
.
first
+
"'"
);
}
size_t
content_offset
=
0
;
size_t
content_idx
=
0
;
TRITONSERVER_MemoryType
src_memory_type
;
int64_t
src_memory_type_id
;
const
char
*
content
=
tensor
.
data_
->
Data
()
->
BufferAt
(
content_idx
,
&
content_size
,
&
src_memory_type
,
&
src_memory_type_id
);
bool
cuda_used
=
false
;
while
(
content
!=
nullptr
)
{
RETURN_IF_ERROR
(
CopyBuffer
(
output_pair
.
first
,
src_memory_type
,
src_memory_type_id
,
dst_memory_type
,
dst_memory_type_id
,
content_size
,
content
,
((
char
*
)
buffer
)
+
content_offset
,
stream_
,
&
cuda_used
));
cuda_async_copy
|=
cuda_used
;
content_offset
+=
content_size
;
content_idx
++
;
content
=
tensor
.
data_
->
Data
()
->
BufferAt
(
content_idx
,
&
content_size
,
&
src_memory_type
,
&
src_memory_type_id
);
}
releasing_tensors
.
emplace
(
&
tensor_data
,
&
tensor
.
remaining_reference_count_
);
if
(
tensor
.
parameter_override_
)
{
switch
(
lrequest
->
CorrelationId
().
Type
())
{
case
InferenceRequest
::
SequenceId
::
DataType
::
STRING
:
(
*
response
)
->
AddParameter
(
"sequence_id"
,
tensor
.
correlation_id_
.
StringValue
().
c_str
());
break
;
case
InferenceRequest
::
SequenceId
::
DataType
::
UINT64
:
(
*
response
)
->
AddParameter
(
"sequence_id"
,
(
int64_t
)
tensor
.
correlation_id_
.
UnsignedIntValue
());
break
;
default:
(
*
response
)
->
AddParameter
(
"sequence_id"
,
(
int64_t
)
tensor
.
correlation_id_
.
UnsignedIntValue
());
break
;
}
(
*
response
)
->
AddParameter
(
"sequence_start"
,
(
tensor
.
flags_
&
TRITONSERVER_REQUEST_FLAG_SEQUENCE_START
)
!=
0
);
(
*
response
)
->
AddParameter
(
"sequence_end"
,
(
tensor
.
flags_
&
TRITONSERVER_REQUEST_FLAG_SEQUENCE_END
)
!=
0
);
}
}
if
(
cuda_async_copy
)
{
#ifdef TRITON_ENABLE_GPU
cudaStreamSynchronize
(
stream_
);
#else
return
Status
(
Status
::
Code
::
INTERNAL
,
"unexpected CUDA copy flag set while GPU is not supported"
);
#endif // TRITON_ENABLE_GPU
}
// Prune the tensor if it is not needed by other steps
for
(
auto
&
releasing_pair
:
releasing_tensors
)
{
if
((
--
(
*
releasing_pair
.
second
))
==
0
)
{
releasing_pair
.
first
->
tensor_
.
erase
(
iteration_count
);
}
}
return
Status
::
Success
;
}
void
EnsembleContext
::
ScheduleSteps
(
const
std
::
shared_ptr
<
EnsembleContext
>&
context
,
StepList
&&
steps
)
{
for
(
auto
&
step
:
steps
)
{
step
->
ctx_
=
context
;
bool
should_schedule
=
false
;
// Must release lock before InferAsync to avoid deadlock, as the same thread
// will be calling request/response callbacks on cache hits, which will
// attempt to acquire the lock already held
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
context
->
mutex_
);
// Need to check the ensemble_status_ to ensure the FinishEnsemble()
// is called only once.
if
(
context
->
ensemble_status_
.
IsOk
())
{
context
->
request_tracker_
->
IncrementCounter
();
should_schedule
=
true
;
}
}
if
(
should_schedule
)
{
// On a successful call to InferAsync(), the step will be released by
// the response callback. When the response callback is invoked, the
// step must not own (and release) the request as the request should be
// transferred and managed by Triton core. In the case of cache hit, the
// request hasn't been transferred and can cause double-free, so moving
// the request ownership out of step here to avoid that
std
::
unique_ptr
<
InferenceRequest
>
request
=
std
::
move
(
step
->
request_
);
auto
step_status
=
context
->
is_
->
InferAsync
(
request
);
if
(
!
step_status
.
IsOk
())
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
context
->
mutex_
);
context
->
ensemble_status_
=
step_status
;
// The request is not sent to server properly, shouldn't expect its
// release function get called.
context
->
request_tracker_
->
DecrementCounter
();
context
->
ensemble_status_
=
context
->
FinishEnsemble
();
break
;
}
}
step
.
release
();
}
}
}
// namespace
Status
EnsembleScheduler::Create(
    InferenceStatsAggregator* const stats_aggregator,
    InferenceServer* const server, const inference::ModelConfig& config,
    std::unique_ptr<Scheduler>* scheduler)
{
  scheduler->reset(new EnsembleScheduler(stats_aggregator, server, config));
  return Status::Success;
}
Status
EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
{
  // Queue timer starts at the beginning of the queueing and
  // scheduling process
  request->CaptureQueueStartNs();
  INFER_TRACE_ACTIVITY(
      request->Trace(), TRITONSERVER_TRACE_QUEUE_START,
      request->QueueStartNs());
#ifdef TRITON_ENABLE_TRACING
  request->TraceInputTensors(
      TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT, "EnsembleScheduler Enqueue");
#endif  // TRITON_ENABLE_TRACING

  // Add additional callback to keep track of in-flight count
  ++inflight_count_;
  request->AddInternalReleaseCallback([this]() { --inflight_count_; });
  std::shared_ptr<EnsembleContext> context(new EnsembleContext(
      metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
      stream_));
  EnsembleContext::Proceed(context);
  return Status::Success;
}
EnsembleScheduler
::
EnsembleScheduler
(
InferenceStatsAggregator
*
const
stats_aggregator
,
InferenceServer
*
const
server
,
const
inference
::
ModelConfig
&
config
)
:
stats_aggregator_
(
stats_aggregator
),
is_
(
server
),
stream_
(
nullptr
),
inflight_count_
(
0
)
{
#ifdef TRITON_ENABLE_GPU
// create CUDA stream
auto
cuerr
=
cudaStreamCreate
(
&
stream_
);
if
(
cuerr
!=
cudaSuccess
)
{
stream_
=
nullptr
;
LOG_ERROR
<<
"unable to create stream for "
<<
config
.
name
()
<<
": "
<<
cudaGetErrorString
(
cuerr
);
}
#endif // TRITON_ENABLE_GPU
#ifdef TRITON_ENABLE_METRICS
if
(
Metrics
::
Enabled
())
{
MetricModelReporter
::
Create
(
config
.
name
(),
1
,
METRIC_REPORTER_ID_CPU
,
config
.
metric_tags
(),
&
metric_reporter_
);
}
#endif // TRITON_ENABLE_METRICS
// Set 'info_' based on 'config'
info_
.
reset
(
new
EnsembleInfo
());
info_
->
ensemble_name_
=
config
.
name
();
// This config field is filled internally for ensemble models
info_
->
is_decoupled_
=
config
.
model_transaction_policy
().
decoupled
();
for
(
const
auto
&
input
:
config
.
input
())
{
info_
->
tensor_to_step_
.
emplace
(
input
.
name
(),
std
::
set
<
size_t
>
());
if
(
input
.
optional
())
{
info_
->
optional_inputs_
.
emplace
(
input
.
name
());
}
}
for
(
const
auto
&
output
:
config
.
output
())
{
info_
->
tensor_to_step_
.
emplace
(
output
.
name
(),
std
::
set
<
size_t
>
());
if
(
output
.
has_reshape
())
{
info_
->
ensemble_output_shape_
[
output
.
name
()]
=
output
.
reshape
().
shape
();
}
else
{
info_
->
ensemble_output_shape_
[
output
.
name
()]
=
output
.
dims
();
}
}
for
(
const
auto
&
element
:
config
.
ensemble_scheduling
().
step
())
{
size_t
step_idx
=
info_
->
steps_
.
size
();
info_
->
steps_
.
emplace_back
(
element
.
model_name
(),
element
.
model_version
());
for
(
const
auto
&
pair
:
element
.
input_map
())
{
auto
it
=
info_
->
tensor_to_step_
.
find
(
pair
.
second
);
if
(
it
==
info_
->
tensor_to_step_
.
end
())
{
it
=
info_
->
tensor_to_step_
.
emplace
(
pair
.
second
,
std
::
set
<
size_t
>
())
.
first
;
}
it
->
second
.
insert
(
step_idx
);
info_
->
steps_
[
step_idx
].
input_to_tensor_
.
emplace
(
std
::
make_pair
(
pair
.
first
,
pair
.
second
));
}
for
(
const
auto
&
pair
:
element
.
output_map
())
{
auto
it
=
info_
->
tensor_to_step_
.
find
(
pair
.
second
);
if
(
it
==
info_
->
tensor_to_step_
.
end
())
{
it
=
info_
->
tensor_to_step_
.
emplace
(
pair
.
second
,
std
::
set
<
size_t
>
())
.
first
;
}
info_
->
steps_
[
step_idx
].
output_to_tensor_
.
emplace
(
std
::
make_pair
(
pair
.
first
,
pair
.
second
));
info_
->
tensor_to_prev_step_
.
emplace
(
pair
.
second
,
step_idx
);
}
}
}
EnsembleScheduler::~EnsembleScheduler()
{
#ifdef TRITON_ENABLE_GPU
  if (stream_ != nullptr) {
    cudaError_t err = cudaStreamDestroy(stream_);
    if (err != cudaSuccess) {
      LOG_ERROR << "Failed to destroy cuda stream: "
                << cudaGetErrorString(err);
    }
  }
#endif  // TRITON_ENABLE_GPU
}

}}  // namespace triton::core

#endif  // TRITON_ENABLE_ENSEMBLE
3rdparty/core-r22.12/src/ensemble_scheduler.h (new file, mode 0 → 100644)
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#ifdef TRITON_ENABLE_ENSEMBLE
#include <memory>
#include "metric_model_reporter.h"
#include "model_config.pb.h"
#include "model_config_utils.h"
#include "scheduler.h"
#include "status.h"
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#endif // TRITON_ENABLE_GPU
namespace triton { namespace core {

#ifndef TRITON_ENABLE_GPU
using cudaStream_t = void*;
#endif  // TRITON_ENABLE_GPU

class InferenceServer;

struct EnsembleInfo {
  struct StepInfo {
    StepInfo(const std::string& model_name, const int64_t model_version)
        : model_name_(model_name), model_version_(model_version)
    {
    }

    std::string model_name_;
    int64_t model_version_;
    std::unordered_map<std::string, std::string> input_to_tensor_;
    std::unordered_map<std::string, std::string> output_to_tensor_;
  };

  std::string ensemble_name_;

  bool is_decoupled_;

  // the ensemble output (re)shape expected by the ensemble
  std::unordered_map<std::string, triton::common::DimsList>
      ensemble_output_shape_;

  // Inputs that are marked optional for the ensemble
  std::set<std::string> optional_inputs_;

  std::vector<StepInfo> steps_;

  // Only include a step if the ensemble tensor is used as input in that step
  std::unordered_map<std::string, std::set<size_t>> tensor_to_step_;

  // backward path, ensemble tensor to the step that provides its data
  std::unordered_map<std::string, size_t> tensor_to_prev_step_;
};

// Scheduler that implements ensemble scheduling.
class EnsembleScheduler : public Scheduler {
 public:
  // Create a scheduler to process ensemble requests and
  // to dispatch requests to models in ensemble internally.
  static Status Create(
      InferenceStatsAggregator* const stats_aggregator,
      InferenceServer* const server, const inference::ModelConfig& config,
      std::unique_ptr<Scheduler>* scheduler);

  ~EnsembleScheduler();

  // \see Scheduler::Enqueue()
  Status Enqueue(std::unique_ptr<InferenceRequest>& request) override;

  // \see Scheduler::InflightInferenceCount()
  size_t InflightInferenceCount() override { return inflight_count_; }

  // \see Scheduler::Stop()
  void Stop() override {}

 private:
  EnsembleScheduler(
      InferenceStatsAggregator* const stats_aggregator,
      InferenceServer* const server, const inference::ModelConfig& config);

  std::shared_ptr<MetricModelReporter> metric_reporter_;
  InferenceStatsAggregator* const stats_aggregator_;
  InferenceServer* const is_;

  // Ensemble information that is built from model config
  std::unique_ptr<EnsembleInfo> info_;

  // The stream used for data transfer.
  cudaStream_t stream_;

  std::atomic<size_t> inflight_count_;
};

}}  // namespace triton::core

#endif  // TRITON_ENABLE_ENSEMBLE
3rdparty/core-r22.12/src/ensemble_utils.cc
0 → 100644
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_utils.h"
#include <set>
#include "constants.h"
#include "model.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
namespace triton { namespace core {

namespace {

/// A basic unit in ensemble graph that records the data type and shape
/// of the ensemble tensor and which model they are inferred from.
struct TensorNode {
  TensorNode(
      const std::string& model_name, const bool batching,
      const inference::DataType& type, const triton::common::DimsList& dims)
      : model_name_(model_name), type_(type), dims_(dims),
        is_decoupled_(false), decouple_label_(0), visited_(false)
  {
    // Expand dims to full shape, which includes batch dimension if exist
    if (batching) {
      full_dims_.Add(-1);
    }
    full_dims_.MergeFrom(dims_);
  }

  // Constructor for symbolic nodes
  TensorNode(const std::string& model_name)
      : model_name_(model_name), is_decoupled_(false), decouple_label_(0),
        visited_(false)
  {
  }

  std::string model_name_;
  inference::DataType type_;
  triton::common::DimsList dims_;
  triton::common::DimsList full_dims_;
  bool is_decoupled_;
  size_t decouple_label_;
  bool visited_;
  std::vector<TensorNode*> prev_nodes_;
  std::vector<TensorNode*> next_nodes_;
  // A symbolic node to keep track of the decouple label of nodes that
  // are outputs of the same step.
  std::shared_ptr<TensorNode> sibling_node_;
};

/// Validate if the data type and the shape of two TensorNode object are
/// consistent.
/// \param lhs One of the TensorNode object to be validated.
/// \param rhs Another TensorNode object to be validated.
/// \param message Extra message included in the front of error message
/// if error status is non-OK.
/// \return The error status. A non-OK status indicates the TensorNode objects
/// are not consistent.
Status
ValidateTensorConsistency(
    const TensorNode& lhs, const TensorNode& rhs, const std::string& message)
{
  if (lhs.type_ != rhs.type_) {
    return Status(
        Status::Code::INVALID_ARG,
        message + "inconsistent data type: " +
            inference::DataType_Name(lhs.type_) + " is inferred from model " +
            lhs.model_name_ + " while " +
            inference::DataType_Name(rhs.type_) + " is inferred from model " +
            rhs.model_name_);
  }

  // Shapes must match or either one uses variable size shape, if one uses
  // variable size shape, shape consistency will be checked at runtime.
  // If dims mismatch, compare again with full dims in case the tensor is
  // used for both non-batching model and batching model. In that case, it
  // is acceptable if non-batching model shape is [-1, d_0, d_1, ..., d_n]
  // while the batching model shape is [d_0, d_1, ..., d_n].
  if (!triton::common::CompareDimsWithWildcard(lhs.dims_, rhs.dims_) &&
      !triton::common::CompareDimsWithWildcard(
          lhs.full_dims_, rhs.full_dims_)) {
    return Status(
        Status::Code::INVALID_ARG,
        message + "inconsistent shape: " +
            triton::common::DimsListToString(lhs.full_dims_) +
            " is inferred from model " + lhs.model_name_ + " while " +
            triton::common::DimsListToString(rhs.full_dims_) +
            " is inferred from model " + rhs.model_name_);
  }

  return Status::Success;
}

Status
ValidateTensorMapping(
    const std::string& ensemble, const inference::ModelEnsembling::Step& step,
    const inference::ModelConfig& model_config,
    std::unordered_map<std::string, TensorNode>* ensemble_tensors)
{
  const bool batching = (model_config.max_batch_size() > 0);

  // Check all inputs are mapped and no mapping to invalid inputs
  std::set<std::string> input_names;
  for (const auto& model_input : model_config.input()) {
    input_names.insert(model_input.name());
  }
  for (const auto& input_map : step.input_map()) {
    if (input_names.find(input_map.first) == input_names.end()) {
      return Status(
          Status::Code::INVALID_ARG,
          "in ensemble " + ensemble + ", ensemble tensor " + input_map.second +
              " is mapping to non-existing input " + input_map.first +
              " in model " + step.model_name());
    }
  }
  for (const auto& model_input : model_config.input()) {
    size_t mapped_cnt = 0;
    for (const auto& input_map : step.input_map()) {
      if (model_input.name() == input_map.first) {
        TensorNode model_tensor(
            step.model_name(), batching, model_input.data_type(),
            model_input.dims());
        auto it = ensemble_tensors->find(input_map.second);
        if (it != ensemble_tensors->end()) {
          RETURN_IF_ERROR(ValidateTensorConsistency(
              it->second, model_tensor,
              "in ensemble " + ensemble + ", ensemble tensor " +
                  input_map.second + ": "));
        } else {
          ensemble_tensors->emplace(
              std::make_pair(input_map.second, model_tensor));
        }
        mapped_cnt++;
      }
    }
    if (mapped_cnt == 0) {
      // Allow the input to be excluded from ensemble if it is optional
      if (model_input.optional()) {
        continue;
      }
      return Status(
          Status::Code::INVALID_ARG,
          "in ensemble " + ensemble + ", input " + model_input.name() +
              " in model " + model_config.name() +
              " is not mapped to any ensemble tensors");
    } else if (mapped_cnt > 1) {
      return Status(
          Status::Code::INVALID_ARG,
          "in ensemble " + ensemble + ", input " + model_input.name() +
              " in model " + model_config.name() +
              " is mapped to multiple ensemble tensors");
    }
  }

  // Check no multiple mappings to same ensemble tensor
  // and no mapping from invalid outputs
  std::set<std::string> output_names;
  for (const auto& model_output : model_config.output()) {
    output_names.insert(model_output.name());
  }
  for (const auto& output_map : step.output_map()) {
    if (output_names.find(output_map.first) == output_names.end()) {
      return Status(
          Status::Code::INVALID_ARG,
          "in ensemble " + ensemble + ", ensemble tensor " +
              output_map.second + " is mapped from non-existing output " +
              output_map.first + " in model " + step.model_name());
    }
  }

  std::shared_ptr<TensorNode> sibling_node(new TensorNode(step.model_name()));
  for (const auto& output_map : step.output_map()) {
    size_t mapped_cnt = 0;
    for (const auto& model_output : model_config.output()) {
      if (model_output.name() == output_map.first) {
        TensorNode model_tensor(
            step.model_name(), batching, model_output.data_type(),
            model_output.dims());
        auto it = ensemble_tensors->find(output_map.second);
        if (it != ensemble_tensors->end()) {
          RETURN_IF_ERROR(ValidateTensorConsistency(
              it->second, model_tensor,
              "in ensemble " + ensemble + ", ensemble tensor " +
                  output_map.second + ": "));
        } else {
          it = ensemble_tensors
                   ->emplace(std::make_pair(output_map.second, model_tensor))
                   .first;
        }
        it->second.sibling_node_ = sibling_node;
        mapped_cnt++;
      }
    }
    if (mapped_cnt > 1) {
      return Status(
          Status::Code::INVALID_ARG,
          "in ensemble " + ensemble + ", multiple outputs in model " +
              model_config.name() +
              " are mapped to the same ensemble tensor " + output_map.second);
    }
  }

  // link ensemble tensors
  bool is_decoupled = model_config.model_transaction_policy().decoupled();
  for (const auto& output_map : step.output_map()) {
    auto& node = ensemble_tensors->find(output_map.second)->second;
    node.is_decoupled_ = is_decoupled;
    for (const auto& input_map : step.input_map()) {
      auto& prev_node = ensemble_tensors->find(input_map.second)->second;
      node.prev_nodes_.push_back(&prev_node);
      prev_node.next_nodes_.push_back(&node);
    }
  }
  return Status::Success;
}

}  // namespace

Status
ValidateEnsembleConfig(
    ModelRepositoryManager* model_repository_manager,
    ModelRepositoryManager::DependencyNode* ensemble)
{
  const auto& ensemble_config = ensemble->model_config_;
  if (!ensemble_config.has_ensemble_scheduling()) {
    return Status::Success;
  }

  const auto& ensemble_name = ensemble->model_name_;
  const bool batching = (ensemble_config.max_batch_size() > 0);
  std::unordered_map<std::string, TensorNode> ensemble_tensors;
  for (const auto& input : ensemble_config.input()) {
    const auto& dims =
        input.has_reshape() ? input.reshape().shape() : input.dims();
    TensorNode input_node(ensemble_name, batching, input.data_type(), dims);
    ensemble_tensors.emplace(std::make_pair(input.name(), input_node));
  }

  TensorNode sink_node(ensemble_name);
  for (const auto& output : ensemble_config.output()) {
    const auto& dims =
        output.has_reshape() ? output.reshape().shape() : output.dims();
    TensorNode output_node(ensemble_name, batching, output.data_type(), dims);
    auto it =
        ensemble_tensors.emplace(std::make_pair(output.name(), output_node))
            .first;
    sink_node.prev_nodes_.emplace_back(&(it->second));
    it->second.next_nodes_.emplace_back(&sink_node);
  }

  for (const auto& step : ensemble_config.ensemble_scheduling().step()) {
    const auto& model_name = step.model_name();
    inference::ModelConfig model_config;
    for (auto& node : ensemble->upstreams_) {
      if (model_name == node.first->model_name_) {
        // Obtain completed config from model instance
        std::shared_ptr<Model> model;
        RETURN_IF_ERROR(
            model_repository_manager->GetModel(model_name, -1, &model));
        model_config = model->Config();
        break;
      }
    }

    // batchable ensemble can include non-batchable models as long as
    // the expanded shapes are consistent
    if ((model_config.max_batch_size() != 0) &&
        (model_config.max_batch_size() < ensemble_config.max_batch_size())) {
      return Status(
          Status::Code::INVALID_ARG,
          "ensemble " + ensemble_name + " allows maximum batch size " +
              std::to_string(ensemble_config.max_batch_size()) +
              ", but it contains model " + model_name +
              " which only allows maximum batch size to be " +
              std::to_string(model_config.max_batch_size()));
    }

    RETURN_IF_ERROR(ValidateTensorMapping(
        ensemble_name, step, model_config, &ensemble_tensors));
  }

  // Visit nodes and validate decoupled workflow if any
  // check data flow
  size_t decouple_label = 0;
  std::deque<TensorNode*> current_iterators;
  for (const auto& input : ensemble_config.input()) {
    auto it = ensemble_tensors.find(input.name());
    it->second.visited_ = true;
    current_iterators.push_back(&(it->second));
  }
  while (!current_iterators.empty()) {
    auto& current_node = current_iterators.front();
    for (auto& next_node : current_node->next_nodes_) {
      if (next_node->visited_) {
        continue;
      }
      bool next_node_ready = true;
      for (auto& prev_node : next_node->prev_nodes_) {
        if (!prev_node->visited_) {
          next_node_ready = false;
          break;
        }
      }
      if (next_node_ready) {
        size_t prev_decouple_label = next_node->prev_nodes_[0]->decouple_label_;
        for (auto& prev_node : next_node->prev_nodes_) {
          if (prev_node->decouple_label_ != prev_decouple_label) {
            return Status(
                Status::Code::INVALID_ARG,
                "in ensemble " + ensemble_name + ", step of model '" +
                    next_node->model_name_ +
                    "' receives inputs originated from different decoupled "
                    "models");
          }
        }
        if (next_node->sibling_node_ != nullptr) {
          if (next_node->sibling_node_->visited_) {
            next_node->decouple_label_ =
                next_node->sibling_node_->decouple_label_;
          } else {
            next_node->decouple_label_ = next_node->is_decoupled_
                                             ? ++decouple_label
                                             : prev_decouple_label;
            next_node->sibling_node_->decouple_label_ =
                next_node->decouple_label_;
            next_node->sibling_node_->visited_ = true;
          }
        } else {
          next_node->decouple_label_ = next_node->is_decoupled_
                                           ? ++decouple_label
                                           : prev_decouple_label;
        }
        next_node->visited_ = true;
        current_iterators.push_back(next_node);
      }
    }
    current_iterators.pop_front();
  }
  ensemble->model_config_.mutable_model_transaction_policy()->set_decoupled(
      decouple_label != 0);

  return Status::Success;
}

}}  // namespace triton::core

#endif  // TRITON_ENABLE_ENSEMBLE
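
The data-flow check at the end of ValidateEnsembleConfig is a breadth-first walk over the tensor graph: a node is labeled only after all of its producers have been labeled, and the label is bumped whenever the producing step is decoupled. A minimal sketch of that traversal on a toy three-node graph, using plain structs rather than the TensorNode type above (illustrative only, not part of the Triton sources):

// Illustrative sketch (editor's note): decouple-label propagation on a toy graph.
#include <deque>
#include <iostream>
#include <vector>

struct Node {
  bool is_decoupled = false;
  bool visited = false;
  size_t label = 0;
  std::vector<Node*> prev;
  std::vector<Node*> next;
};

int main()
{
  Node in, mid, out;        // in -> mid -> out
  mid.is_decoupled = true;  // pretend 'mid' is produced by a decoupled model
  in.next = {&mid};
  mid.prev = {&in};
  mid.next = {&out};
  out.prev = {&mid};

  size_t decouple_label = 0;
  std::deque<Node*> frontier{&in};
  in.visited = true;

  while (!frontier.empty()) {
    Node* cur = frontier.front();
    frontier.pop_front();
    for (Node* nxt : cur->next) {
      if (nxt->visited) continue;
      bool ready = true;
      for (Node* prv : nxt->prev) ready &= prv->visited;  // all producers done?
      if (!ready) continue;
      size_t prev_label = nxt->prev[0]->label;
      nxt->label = nxt->is_decoupled ? ++decouple_label : prev_label;
      nxt->visited = true;
      frontier.push_back(nxt);
    }
  }

  // A non-zero final label means the ensemble as a whole must be marked
  // decoupled, mirroring set_decoupled(decouple_label != 0) above.
  std::cout << "ensemble decoupled: " << (decouple_label != 0) << "\n";
}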
3rdparty/core-r22.12/src/ensemble_utils.h
0 → 100644
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#ifdef TRITON_ENABLE_ENSEMBLE
#include <deque>
#include <unordered_map>
#include "model_config.pb.h"
#include "model_repository_manager.h"
#include "status.h"
#include "triton/common/model_config.h"
namespace triton { namespace core {

/// Validate that the ensemble is specified correctly. Assuming that the
/// inputs and outputs specified in depending model configurations are accurate.
/// \param model_repository_manager The model manager to acquire model config.
/// \param ensemble The ensemble to be validated.
/// \return The error status.
Status ValidateEnsembleConfig(
    ModelRepositoryManager* model_repository_manager,
    ModelRepositoryManager::DependencyNode* ensemble);

}}  // namespace triton::core

#endif  // TRITON_ENABLE_ENSEMBLE
3rdparty/core-r22.12/src/filesystem.cc
0 → 100644
// Copyright 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "filesystem.h"
#ifdef _WIN32
// suppress the min and max definitions in Windef.h.
#define NOMINMAX
#include <Windows.h>
// _CRT_INTERNAL_NONSTDC_NAMES 1 before including Microsoft provided C Runtime
// library to expose declarations without "_" prefix to match POSIX style.
#define _CRT_INTERNAL_NONSTDC_NAMES 1
#include <direct.h>
#include <io.h>
#else
#include <dirent.h>
#include <unistd.h>
#endif
#ifdef TRITON_ENABLE_GCS
#include <google/cloud/storage/client.h>
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
#include <aws/core/Aws.h>
#include <aws/core/auth/AWSCredentialsProvider.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadBucketRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
#include <aws/s3/model/ListObjectsRequest.h>
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
#include <blob/blob_client.h>
#include <storage_account.h>
#include <storage_credential.h>
#undef LOG_INFO
#undef LOG_WARNING
#endif // TRITON_ENABLE_AZURE_STORAGE
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/text_format.h>
#include <re2/re2.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <cerrno>
#include <fstream>
#include <mutex>
#include "constants.h"
#include "status.h"
#include "triton/common/logging.h"
#define TRITONJSON_STATUSTYPE triton::core::Status
#define TRITONJSON_STATUSRETURN(M) \
return triton::core::Status(triton::core::Status::Code::INTERNAL, (M))
#define TRITONJSON_STATUSSUCCESS triton::core::Status::Success
#include "triton/common/triton_json.h"
#ifdef _WIN32
// <sys/stat.h> in Windows doesn't define S_ISDIR macro
#if !defined(S_ISDIR) && defined(S_IFMT) && defined(S_IFDIR)
#define S_ISDIR(m) (((m)&S_IFMT) == S_IFDIR)
#endif
#define F_OK 0
#endif
namespace triton { namespace core {

namespace {

// Check if a local path is a directory. We need to use this in LocalFileSystem
// and LocalizedPath so have this common function.
Status
IsPathDirectory(const std::string& path, bool* is_dir)
{
  *is_dir = false;

  struct stat st;
  if (stat(path.c_str(), &st) != 0) {
    return Status(Status::Code::INTERNAL, "failed to stat file " + path);
  }

  *is_dir = S_ISDIR(st.st_mode);
  return Status::Success;
}

}  // namespace

LocalizedPath::~LocalizedPath()
{
  if (!local_path_.empty()) {
    bool is_dir = true;
    IsDirectory(local_path_, &is_dir);
    LOG_STATUS_ERROR(
        DeletePath(is_dir ? local_path_ : DirName(local_path_)),
        "failed to delete localized path");
  }
}
namespace {

class FileSystem {
 public:
  virtual Status FileExists(const std::string& path, bool* exists) = 0;
  virtual Status IsDirectory(const std::string& path, bool* is_dir) = 0;
  virtual Status FileModificationTime(
      const std::string& path, int64_t* mtime_ns) = 0;
  virtual Status GetDirectoryContents(
      const std::string& path, std::set<std::string>* contents) = 0;
  virtual Status GetDirectorySubdirs(
      const std::string& path, std::set<std::string>* subdirs) = 0;
  virtual Status GetDirectoryFiles(
      const std::string& path, std::set<std::string>* files) = 0;
  virtual Status ReadTextFile(
      const std::string& path, std::string* contents) = 0;
  virtual Status LocalizePath(
      const std::string& path, std::shared_ptr<LocalizedPath>* localized) = 0;
  virtual Status WriteTextFile(
      const std::string& path, const std::string& contents) = 0;
  virtual Status WriteBinaryFile(
      const std::string& path, const char* contents,
      const size_t content_len) = 0;
  virtual Status MakeDirectory(
      const std::string& dir, const bool recursive) = 0;
  virtual Status MakeTemporaryDirectory(std::string* temp_dir) = 0;
  virtual Status DeletePath(const std::string& path) = 0;
};

class LocalFileSystem : public FileSystem {
 public:
  Status FileExists(const std::string& path, bool* exists) override;
  Status IsDirectory(const std::string& path, bool* is_dir) override;
  Status FileModificationTime(
      const std::string& path, int64_t* mtime_ns) override;
  Status GetDirectoryContents(
      const std::string& path, std::set<std::string>* contents) override;
  Status GetDirectorySubdirs(
      const std::string& path, std::set<std::string>* subdirs) override;
  Status GetDirectoryFiles(
      const std::string& path, std::set<std::string>* files) override;
  Status ReadTextFile(const std::string& path, std::string* contents) override;
  Status LocalizePath(
      const std::string& path,
      std::shared_ptr<LocalizedPath>* localized) override;
  Status WriteTextFile(
      const std::string& path, const std::string& contents) override;
  Status WriteBinaryFile(
      const std::string& path, const char* contents,
      const size_t content_len) override;
  Status MakeDirectory(const std::string& dir, const bool recursive) override;
  Status MakeTemporaryDirectory(std::string* temp_dir) override;
  Status DeletePath(const std::string& path) override;
};
Status
LocalFileSystem::FileExists(const std::string& path, bool* exists)
{
  *exists = (access(path.c_str(), F_OK) == 0);
  return Status::Success;
}

Status
LocalFileSystem::IsDirectory(const std::string& path, bool* is_dir)
{
  return IsPathDirectory(path, is_dir);
}

Status
LocalFileSystem::FileModificationTime(
    const std::string& path, int64_t* mtime_ns)
{
  struct stat st;
  if (stat(path.c_str(), &st) != 0) {
    return Status(Status::Code::INTERNAL, "failed to stat file " + path);
  }

#ifdef _WIN32
  // In Windows, st_mtime is in time_t
  *mtime_ns = std::max(st.st_mtime, st.st_ctime);
#else
  *mtime_ns =
      std::max(TIMESPEC_TO_NANOS(st.st_mtim), TIMESPEC_TO_NANOS(st.st_ctim));
#endif
  return Status::Success;
}

Status
LocalFileSystem::GetDirectoryContents(
    const std::string& path, std::set<std::string>* contents)
{
#ifdef _WIN32
  WIN32_FIND_DATA entry;
  // Append "*" to obtain all files under 'path'
  HANDLE dir = FindFirstFile(JoinPath({path, "*"}).c_str(), &entry);
  if (dir == INVALID_HANDLE_VALUE) {
    return Status(Status::Code::INTERNAL, "failed to open directory " + path);
  }
  if ((strcmp(entry.cFileName, ".") != 0) &&
      (strcmp(entry.cFileName, "..") != 0)) {
    contents->insert(entry.cFileName);
  }
  while (FindNextFile(dir, &entry)) {
    if ((strcmp(entry.cFileName, ".") != 0) &&
        (strcmp(entry.cFileName, "..") != 0)) {
      contents->insert(entry.cFileName);
    }
  }

  FindClose(dir);
#else
  DIR* dir = opendir(path.c_str());
  if (dir == nullptr) {
    return Status(Status::Code::INTERNAL, "failed to open directory " + path);
  }

  struct dirent* entry;
  while ((entry = readdir(dir)) != nullptr) {
    std::string entryname = entry->d_name;
    if ((entryname != ".") && (entryname != "..")) {
      contents->insert(entryname);
    }
  }

  closedir(dir);
#endif
  return Status::Success;
}

Status
LocalFileSystem::GetDirectorySubdirs(
    const std::string& path, std::set<std::string>* subdirs)
{
  RETURN_IF_ERROR(GetDirectoryContents(path, subdirs));

  // Erase non-directory entries...
  for (auto iter = subdirs->begin(); iter != subdirs->end();) {
    bool is_dir;
    RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir));
    if (!is_dir) {
      iter = subdirs->erase(iter);
    } else {
      ++iter;
    }
  }

  return Status::Success;
}

Status
LocalFileSystem::GetDirectoryFiles(
    const std::string& path, std::set<std::string>* files)
{
  RETURN_IF_ERROR(GetDirectoryContents(path, files));

  // Erase directory entries...
  for (auto iter = files->begin(); iter != files->end();) {
    bool is_dir;
    RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir));
    if (is_dir) {
      iter = files->erase(iter);
    } else {
      ++iter;
    }
  }

  return Status::Success;
}

Status
LocalFileSystem::ReadTextFile(const std::string& path, std::string* contents)
{
  std::ifstream in(path, std::ios::in | std::ios::binary);
  if (!in) {
    return Status(
        Status::Code::INTERNAL,
        "failed to open text file for read " + path + ": " + strerror(errno));
  }

  in.seekg(0, std::ios::end);
  contents->resize(in.tellg());
  in.seekg(0, std::ios::beg);
  in.read(&(*contents)[0], contents->size());
  in.close();

  return Status::Success;
}

Status
LocalFileSystem::LocalizePath(
    const std::string& path, std::shared_ptr<LocalizedPath>* localized)
{
  // For local file system we don't actually need to download the
  // directory or file. We use it in place.
  localized->reset(new LocalizedPath(path));
  return Status::Success;
}

Status
LocalFileSystem::WriteTextFile(
    const std::string& path, const std::string& contents)
{
  std::ofstream out(path, std::ios::out | std::ios::binary);
  if (!out) {
    return Status(
        Status::Code::INTERNAL,
        "failed to open text file for write " + path + ": " + strerror(errno));
  }

  out.write(&contents[0], contents.size());
  out.close();

  return Status::Success;
}

Status
LocalFileSystem::WriteBinaryFile(
    const std::string& path, const char* contents, const size_t content_len)
{
  std::ofstream out(path, std::ios::out | std::ios::binary);
  if (!out) {
    return Status(
        Status::Code::INTERNAL,
        "failed to open binary file for write " + path + ": " +
            strerror(errno));
  }

  out.write(contents, content_len);

  return Status::Success;
}

Status
LocalFileSystem::MakeDirectory(const std::string& dir, const bool recursive)
{
#ifdef _WIN32
  if (mkdir(dir.c_str()) == -1)
#else
  if (mkdir(dir.c_str(), S_IRWXU) == -1)
#endif
  {
    // Only allow the error due to parent directory does not exist
    // if 'recursive' is requested
    if ((errno == ENOENT) && (!dir.empty()) && recursive) {
      RETURN_IF_ERROR(MakeDirectory(DirName(dir), recursive));
      // Retry the creation
#ifdef _WIN32
      if (mkdir(dir.c_str()) == -1)
#else
      if (mkdir(dir.c_str(), S_IRWXU) == -1)
#endif
      {
        return Status(
            Status::Code::INTERNAL,
            "Failed to create directory '" + dir +
                "', errno:" + strerror(errno));
      }
    } else {
      return Status(
          Status::Code::INTERNAL,
          "Failed to create directory '" + dir +
              "', errno:" + strerror(errno));
    }
  }

  return Status::Success;
}

Status
LocalFileSystem::MakeTemporaryDirectory(std::string* temp_dir)
{
#ifdef _WIN32
  char temp_path[MAX_PATH + 1];
  size_t temp_path_length = GetTempPath(MAX_PATH + 1, temp_path);
  if (temp_path_length == 0) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to get local directory for temporary files");
  }
  // There is no single operation like 'mkdtemp' in Windows, thus generating
  // unique temporary directory is a process of getting temporary file name,
  // deleting the file (file creation is side effect of getting name), creating
  // corresponding directory, so mutex is used to avoid possible race condition.
  // However, it doesn't prevent other processes from creating temporary files
  // and thus the race condition may still happen. One possible solution is
  // to reserve a temporary directory for the process and generate temporary
  // model directories inside it.
  static std::mutex mtx;
  std::lock_guard<std::mutex> lk(mtx);
  // Construct a std::string as filled 'temp_path' is not C string,
  // and so that we can reuse 'temp_path' to hold the temp file name.
  std::string temp_path_str(temp_path, temp_path_length);
  if (GetTempFileName(temp_path_str.c_str(), "folder", 0, temp_path) == 0) {
    return Status(
        Status::Code::INTERNAL, "Failed to create local temp folder");
  }
  *temp_dir = temp_path;
  DeleteFile(temp_dir->c_str());
  if (CreateDirectory(temp_dir->c_str(), NULL) == 0) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to create local temp folder: " + *temp_dir);
  }
#else
  std::string folder_template = "/tmp/folderXXXXXX";
  char* res = mkdtemp(const_cast<char*>(folder_template.c_str()));
  if (res == nullptr) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to create local temp folder: " + folder_template +
            ", errno:" + strerror(errno));
  }
  *temp_dir = res;
#endif
  return Status::Success;
}

Status
LocalFileSystem::DeletePath(const std::string& path)
{
  bool is_dir = false;
  RETURN_IF_ERROR(IsDirectory(path, &is_dir));

  if (is_dir) {
    std::set<std::string> contents;
    RETURN_IF_ERROR(GetDirectoryContents(path, &contents));

    for (const auto& content : contents) {
      RETURN_IF_ERROR(DeletePath(JoinPath({path, content})));
    }
    rmdir(path.c_str());
  } else {
    remove(path.c_str());
  }
  return Status::Success;
}
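
// --- Illustrative sketch (editor's note, not part of filesystem.cc) ---
// The recursive MakeDirectory above retries after first creating the parent
// when mkdir() fails with ENOENT. The same idea in a self-contained form,
// assuming only POSIX mkdir(); on Windows _mkdir() without the mode argument
// would be used instead.
#include <cerrno>
#include <string>
#include <sys/stat.h>
#include <sys/types.h>

static bool MakeDirsSketch(const std::string& dir)
{
  if (mkdir(dir.c_str(), S_IRWXU) == 0 || errno == EEXIST) {
    return true;  // created, or already present
  }
  if (errno != ENOENT) {
    return false;  // some other failure
  }
  // Parent missing: create it first, then retry this directory once.
  const auto pos = dir.find_last_of('/');
  if (pos == std::string::npos || pos == 0) {
    return false;
  }
  return MakeDirsSketch(dir.substr(0, pos)) &&
         (mkdir(dir.c_str(), S_IRWXU) == 0 || errno == EEXIST);
}
// --- end illustrative sketch ---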
#if defined(TRITON_ENABLE_GCS) || defined(TRITON_ENABLE_S3) || \
    defined(TRITON_ENABLE_AZURE_STORAGE)
// Helper function to take care of lack of trailing slashes
std::string
AppendSlash(const std::string& name)
{
  if (name.empty() || (name.back() == '/')) {
    return name;
  }

  return (name + "/");
}
#endif  // TRITON_ENABLE_GCS || TRITON_ENABLE_S3 || TRITON_ENABLE_AZURE_STORAGE
#ifdef TRITON_ENABLE_GCS

namespace gcs = google::cloud::storage;

struct GCSCredential {
  std::string path_;

  GCSCredential();  // from env var
  GCSCredential(triton::common::TritonJson::Value& cred_json);
};

GCSCredential::GCSCredential()
{
  const char* path = std::getenv("GOOGLE_APPLICATION_CREDENTIALS");
  path_ = (path != nullptr ? std::string(path) : "");
}

GCSCredential::GCSCredential(triton::common::TritonJson::Value& cred_json)
{
  cred_json.AsString(&path_);
}

class GCSFileSystem : public FileSystem {
 public:
  GCSFileSystem(const GCSCredential& gs_cred);
  // unify with S3/azure interface
  GCSFileSystem(const std::string& path, const GCSCredential& gs_cred)
      : GCSFileSystem(gs_cred)
  {
  }
  Status CheckClient();
  // unify with S3 interface
  Status CheckClient(const std::string& path) { return CheckClient(); }

  Status FileExists(const std::string& path, bool* exists) override;
  Status IsDirectory(const std::string& path, bool* is_dir) override;
  Status FileModificationTime(
      const std::string& path, int64_t* mtime_ns) override;
  Status GetDirectoryContents(
      const std::string& path, std::set<std::string>* contents) override;
  Status GetDirectorySubdirs(
      const std::string& path, std::set<std::string>* subdirs) override;
  Status GetDirectoryFiles(
      const std::string& path, std::set<std::string>* files) override;
  Status ReadTextFile(const std::string& path, std::string* contents) override;
  Status LocalizePath(
      const std::string& path,
      std::shared_ptr<LocalizedPath>* localized) override;
  Status WriteTextFile(
      const std::string& path, const std::string& contents) override;
  Status WriteBinaryFile(
      const std::string& path, const char* contents,
      const size_t content_len) override;
  Status MakeDirectory(const std::string& dir, const bool recursive) override;
  Status MakeTemporaryDirectory(std::string* temp_dir) override;
  Status DeletePath(const std::string& path) override;

 private:
  Status ParsePath(
      const std::string& path, std::string* bucket, std::string* object);
  Status MetaDataExists(
      const std::string path, bool* exists,
      google::cloud::StatusOr<gcs::ObjectMetadata>* metadata);

  google::cloud::StatusOr<gcs::Client> client_;
};

GCSFileSystem::GCSFileSystem(const GCSCredential& gs_cred)
{
  auto creds = gcs::oauth2::CreateServiceAccountCredentialsFromJsonFilePath(
      gs_cred.path_);
  if (creds) {
    client_ = gcs::Client(gcs::ClientOptions(*creds));
  }
}

Status
GCSFileSystem::CheckClient()
{
  if (!client_) {
    return Status(
        Status::Code::INTERNAL,
        "Unable to create GCS client. Check account credentials.");
  }
  return Status::Success;
}

Status
GCSFileSystem::ParsePath(
    const std::string& path, std::string* bucket, std::string* object)
{
  // Get the bucket name and the object path. Return error if input is malformed
  int bucket_start = path.find("gs://") + strlen("gs://");
  int bucket_end = path.find("/", bucket_start);

  // If there isn't a second slash, the address has only the bucket
  if (bucket_end > bucket_start) {
    *bucket = path.substr(bucket_start, bucket_end - bucket_start);
    *object = path.substr(bucket_end + 1);
  } else {
    *bucket = path.substr(bucket_start);
    *object = "";
  }

  if (bucket->empty()) {
    return Status(
        Status::Code::INTERNAL, "No bucket name found in path: " + path);
  }

  return Status::Success;
}

Status
GCSFileSystem::FileExists(const std::string& path, bool* exists)
{
  *exists = false;

  std::string bucket, object;
  RETURN_IF_ERROR(ParsePath(path, &bucket, &object));

  // Make a request for metadata and check the response
  google::cloud::StatusOr<gcs::ObjectMetadata> object_metadata =
      client_->GetObjectMetadata(bucket, object);

  if (object_metadata) {
    *exists = true;
    return Status::Success;
  }

  // GCS doesn't make objects for directories, so it could still be a directory
  bool is_dir;
  RETURN_IF_ERROR(IsDirectory(path, &is_dir));
  *exists = is_dir;

  return Status::Success;
}

Status
GCSFileSystem::IsDirectory(const std::string& path, bool* is_dir)
{
  *is_dir = false;

  std::string bucket, object_path;
  RETURN_IF_ERROR(ParsePath(path, &bucket, &object_path));

  // Check if the bucket exists
  google::cloud::StatusOr<gcs::BucketMetadata> bucket_metadata =
      client_->GetBucketMetadata(bucket);

  if (!bucket_metadata) {
    return Status(
        Status::Code::INTERNAL,
        "Could not get MetaData for bucket with name " + bucket + " : " +
            bucket_metadata.status().message());
  }

  // Root case - bucket exists and object path is empty
  if (object_path.empty()) {
    *is_dir = true;
    return Status::Success;
  }

  // Check whether it has children. If at least one child, it is a directory
  for (auto&& object_metadata :
       client_->ListObjects(bucket, gcs::Prefix(AppendSlash(object_path)))) {
    if (object_metadata) {
      *is_dir = true;
      break;
    }
  }
  return Status::Success;
}

Status
GCSFileSystem::FileModificationTime(
    const std::string& path, int64_t* mtime_ns)
{
  // We don't need to worry about the case when this is a directory
  bool is_dir;
  RETURN_IF_ERROR(IsDirectory(path, &is_dir));
  if (is_dir) {
    *mtime_ns = 0;
    return Status::Success;
  }

  std::string bucket, object;
  RETURN_IF_ERROR(ParsePath(path, &bucket, &object));

  // Otherwise check the object metadata for update time
  google::cloud::StatusOr<gcs::ObjectMetadata> object_metadata =
      client_->GetObjectMetadata(bucket, object);

  if (!object_metadata) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to get metadata for " + object + " : " +
            object_metadata.status().message());
  }

  // Get duration from time point with respect to object clock
  auto update_time = std::chrono::time_point_cast<std::chrono::nanoseconds>(
                         object_metadata->updated())
                         .time_since_epoch()
                         .count();

  *mtime_ns = update_time;
  return Status::Success;
}

Status
GCSFileSystem::GetDirectoryContents(
    const std::string& path, std::set<std::string>* contents)
{
  std::string bucket, dir_path;
  RETURN_IF_ERROR(ParsePath(path, &bucket, &dir_path));
  // Append a slash to make it easier to list contents
  std::string full_dir = AppendSlash(dir_path);

  // Get objects with prefix equal to full directory path
  for (auto&& object_metadata :
       client_->ListObjects(bucket, gcs::Prefix(full_dir))) {
    if (!object_metadata) {
      return Status(
          Status::Code::INTERNAL,
          "Could not list contents of directory at " + path + " : " +
              object_metadata.status().message());
    }

    // In the case of empty directories, the directory itself will appear here
    if (object_metadata->name() == full_dir) {
      continue;
    }

    // We have to make sure that subdirectory contents do not appear here
    std::string name = object_metadata->name();
    int item_start = name.find(full_dir) + full_dir.size();
    // GCS response prepends parent directory name
    int item_end = name.find("/", item_start);

    // Let set take care of subdirectory contents
    std::string item = name.substr(item_start, item_end - item_start);
    contents->insert(item);
  }
  return Status::Success;
}

Status
GCSFileSystem::GetDirectorySubdirs(
    const std::string& path, std::set<std::string>* subdirs)
{
  RETURN_IF_ERROR(GetDirectoryContents(path, subdirs));

  // Erase non-directory entries...
  for (auto iter = subdirs->begin(); iter != subdirs->end();) {
    bool is_dir;
    RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir));
    if (!is_dir) {
      iter = subdirs->erase(iter);
    } else {
      ++iter;
    }
  }

  return Status::Success;
}

Status
GCSFileSystem::GetDirectoryFiles(
    const std::string& path, std::set<std::string>* files)
{
  RETURN_IF_ERROR(GetDirectoryContents(path, files));

  // Erase directory entries...
  for (auto iter = files->begin(); iter != files->end();) {
    bool is_dir;
    RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir));
    if (is_dir) {
      iter = files->erase(iter);
    } else {
      ++iter;
    }
  }

  return Status::Success;
}

Status
GCSFileSystem::ReadTextFile(const std::string& path, std::string* contents)
{
  bool exists;
  RETURN_IF_ERROR(FileExists(path, &exists));

  if (!exists) {
    return Status(Status::Code::INTERNAL, "File does not exist at " + path);
  }

  std::string bucket, object;
  ParsePath(path, &bucket, &object);

  gcs::ObjectReadStream stream = client_->ReadObject(bucket, object);

  if (!stream) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to open object read stream for " + path + " : " +
            stream.status().message());
  }

  std::string data = "";
  char c;
  while (stream.get(c)) {
    data += c;
  }

  *contents = data;

  return Status::Success;
}

Status
GCSFileSystem::LocalizePath(
    const std::string& path, std::shared_ptr<LocalizedPath>* localized)
{
  bool exists;
  RETURN_IF_ERROR(FileExists(path, &exists));
  if (!exists) {
    return Status(
        Status::Code::INTERNAL, "directory or file does not exist at " + path);
  }

  bool is_dir;
  RETURN_IF_ERROR(IsDirectory(path, &is_dir));
  if (!is_dir) {
    return Status(
        Status::Code::UNSUPPORTED,
        "GCS file localization not yet implemented " + path);
  }

  std::string tmp_folder;
  RETURN_IF_ERROR(triton::core::MakeTemporaryDirectory(
      FileSystemType::LOCAL, &tmp_folder));

  localized->reset(new LocalizedPath(path, tmp_folder));

  std::set<std::string> contents, filenames;
  RETURN_IF_ERROR(GetDirectoryContents(path, &filenames));
  for (auto itr = filenames.begin(); itr != filenames.end(); ++itr) {
    contents.insert(JoinPath({path, *itr}));
  }

  while (contents.size() != 0) {
    std::set<std::string> tmp_contents = contents;
    contents.clear();
    for (auto iter = tmp_contents.begin(); iter != tmp_contents.end(); ++iter) {
      bool is_subdir;
      std::string gcs_fpath = *iter;
      std::string gcs_removed_path = gcs_fpath.substr(path.size());
      std::string local_fpath =
          JoinPath({(*localized)->Path(), gcs_removed_path});
      RETURN_IF_ERROR(IsDirectory(gcs_fpath, &is_subdir));
      if (is_subdir) {
        // Create local mirror of sub-directories
#ifdef _WIN32
        int status = mkdir(const_cast<char*>(local_fpath.c_str()));
#else
        int status = mkdir(
            const_cast<char*>(local_fpath.c_str()),
            S_IRUSR | S_IWUSR | S_IXUSR);
#endif
        if (status == -1) {
          return Status(
              Status::Code::INTERNAL,
              "Failed to create local folder: " + local_fpath +
                  ", errno:" + strerror(errno));
        }

        // Add sub-directories and deeper files to contents
        std::set<std::string> subdir_contents;
        RETURN_IF_ERROR(GetDirectoryContents(gcs_fpath, &subdir_contents));
        for (auto itr = subdir_contents.begin(); itr != subdir_contents.end();
             ++itr) {
          contents.insert(JoinPath({gcs_fpath, *itr}));
        }
      } else {
        // Create local copy of file
        std::string file_bucket, file_object;
        RETURN_IF_ERROR(ParsePath(gcs_fpath, &file_bucket, &file_object));

        // Send a request to read the object
        gcs::ObjectReadStream filestream =
            client_->ReadObject(file_bucket, file_object);
        if (!filestream) {
          return Status(
              Status::Code::INTERNAL,
              "Failed to get object at " + *iter + " : " +
                  filestream.status().message());
        }

        std::string gcs_removed_path = (*iter).substr(path.size());
        std::string local_file_path =
            JoinPath({(*localized)->Path(), gcs_removed_path});
        std::ofstream output_file(local_file_path.c_str(), std::ios::binary);
        output_file << filestream.rdbuf();
        output_file.close();
      }
    }
  }

  return Status::Success;
}

Status
GCSFileSystem::WriteTextFile(
    const std::string& path, const std::string& contents)
{
  return Status(
      Status::Code::UNSUPPORTED,
      "Write text file operation not yet implemented " + path);
}

Status
GCSFileSystem::WriteBinaryFile(
    const std::string& path, const char* contents, const size_t content_len)
{
  return Status(
      Status::Code::UNSUPPORTED,
      "Write text file operation not yet implemented " + path);
}

Status
GCSFileSystem::MakeDirectory(const std::string& dir, const bool recursive)
{
  return Status(
      Status::Code::UNSUPPORTED,
      "Make temporary directory operation not yet implemented");
}

Status
GCSFileSystem::MakeTemporaryDirectory(std::string* temp_dir)
{
  return Status(
      Status::Code::UNSUPPORTED,
      "Make temporary directory operation not yet implemented");
}

Status
GCSFileSystem::DeletePath(const std::string& path)
{
  return Status(
      Status::Code::UNSUPPORTED, "Delete path operation not yet implemented");
}

#endif  // TRITON_ENABLE_GCS
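
// --- Illustrative sketch (editor's note, not part of filesystem.cc) ---
// GCSFileSystem::ParsePath above splits "gs://bucket/object/path" into a
// bucket and an object key by finding the slash after the scheme. The same
// split in a self-contained form using only std::string:
#include <iostream>
#include <string>

static void SplitGcsPath(
    const std::string& path, std::string* bucket, std::string* object)
{
  const std::string scheme = "gs://";
  const size_t bucket_start = path.find(scheme) + scheme.size();
  const size_t bucket_end = path.find('/', bucket_start);
  if (bucket_end != std::string::npos) {
    *bucket = path.substr(bucket_start, bucket_end - bucket_start);
    *object = path.substr(bucket_end + 1);
  } else {
    *bucket = path.substr(bucket_start);  // bucket only, no object key
    *object = "";
  }
}

// Example: SplitGcsPath("gs://models/resnet/1/model.plan", &b, &o) yields
// bucket "models" and object "resnet/1/model.plan".
// --- end illustrative sketch ---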
#ifdef TRITON_ENABLE_AZURE_STORAGE

namespace as = azure::storage_lite;
const std::string AS_URL_PATTERN = "as://([^/]+)/([^/?]+)(?:/([^?]*))?(\\?.*)?";

struct ASCredential {
  std::string account_str_;
  std::string account_key_;

  ASCredential();  // from env var
  ASCredential(triton::common::TritonJson::Value& cred_json);
};

ASCredential::ASCredential()
{
  const auto to_str = [](const char* s) -> std::string {
    return (s != nullptr ? std::string(s) : "");
  };
  const char* account_str = std::getenv("AZURE_STORAGE_ACCOUNT");
  const char* account_key = std::getenv("AZURE_STORAGE_KEY");
  account_str_ = to_str(account_str);
  account_key_ = to_str(account_key);
}

ASCredential::ASCredential(triton::common::TritonJson::Value& cred_json)
{
  triton::common::TritonJson::Value account_str_json, account_key_json;
  if (cred_json.Find("account_str", &account_str_json))
    account_str_json.AsString(&account_str_);
  if (cred_json.Find("account_key", &account_key_json))
    account_key_json.AsString(&account_key_);
}

class ASFileSystem : public FileSystem {
 public:
  ASFileSystem(const std::string& path, const ASCredential& as_cred);
  Status CheckClient();
  // unify with S3 interface
  Status CheckClient(const std::string& path) { return CheckClient(); }

  Status FileExists(const std::string& path, bool* exists) override;
  Status IsDirectory(const std::string& path, bool* is_dir) override;
  Status FileModificationTime(
      const std::string& path, int64_t* mtime_ns) override;
  Status GetDirectoryContents(
      const std::string& path, std::set<std::string>* contents) override;
  Status GetDirectorySubdirs(
      const std::string& path, std::set<std::string>* subdirs) override;
  Status GetDirectoryFiles(
      const std::string& path, std::set<std::string>* files) override;
  Status ReadTextFile(const std::string& path, std::string* contents) override;
  Status LocalizePath(
      const std::string& path,
      std::shared_ptr<LocalizedPath>* localized) override;
  Status WriteTextFile(
      const std::string& path, const std::string& contents) override;
  Status WriteBinaryFile(
      const std::string& path, const char* contents,
      const size_t content_len) override;
  Status MakeDirectory(const std::string& dir, const bool recursive) override;
  Status MakeTemporaryDirectory(std::string* temp_dir) override;
  Status DeletePath(const std::string& path) override;

 private:
  Status ParsePath(
      const std::string& path, std::string* bucket, std::string* object);
  std::shared_ptr<as::blob_client> client_;

  Status ListDirectory(
      const std::string& path, const std::string& dir_path,
      std::function<
          Status(const as::list_blobs_segmented_item&, const std::string&)>
          func);

  Status DownloadFolder(
      const std::string& container, const std::string& path,
      const std::string& dest);
  re2::RE2 as_regex_;
};

Status
ASFileSystem::ParsePath(
    const std::string& path, std::string* container, std::string* object)
{
  std::string host_name, query;
  if (!RE2::FullMatch(
          path, as_regex_, &host_name, container, object, &query)) {
    return Status(
        Status::Code::INTERNAL, "Invalid azure storage path: " + path);
  }
  return Status::Success;
}

ASFileSystem::ASFileSystem(
    const std::string& path, const ASCredential& as_cred)
    : as_regex_(AS_URL_PATTERN)
{
  std::shared_ptr<as::storage_account> account = nullptr;
  std::string host_name, container, blob_path, query;
  if (RE2::FullMatch(
          path, as_regex_, &host_name, &container, &blob_path, &query)) {
    size_t pos = host_name.rfind(".blob.core.windows.net");
    std::string account_name;
    if (as_cred.account_str_.empty()) {
      if (pos != std::string::npos) {
        account_name = host_name.substr(0, pos);
      } else {
        account_name = host_name;
      }
    } else {
      account_name = as_cred.account_str_;
    }

    std::shared_ptr<as::storage_credential> cred;
    if (!as_cred.account_key_.empty()) {
      // Shared Key
      cred = std::make_shared<as::shared_key_credential>(
          account_name, as_cred.account_key_);
    } else {
      cred = std::make_shared<as::anonymous_credential>();
    }
    account = std::make_shared<as::storage_account>(
        account_name, cred, /* use_https */ true);
    client_ =
        std::make_shared<as::blob_client>(account, /*max_concurrency*/ 16);
  }
}

Status
ASFileSystem::CheckClient()
{
  if (client_ == nullptr) {
    return Status(
        Status::Code::INTERNAL,
        "Unable to create Azure filesystem client. Check account credentials.");
  }
  return Status::Success;
}

Status
ASFileSystem::FileModificationTime(
    const std::string& path, int64_t* mtime_ns)
{
  as::blob_client_wrapper bc(client_);
  std::string container, object_path;
  RETURN_IF_ERROR(ParsePath(path, &container, &object_path));

  auto blobProperty = bc.get_blob_property(container, object_path);
  if (errno != 0) {
    return Status(
        Status::Code::INTERNAL,
        "Unable to get blob property for file at " + path +
            ", errno:" + strerror(errno));
  }

  auto time =
      std::chrono::system_clock::from_time_t(blobProperty.last_modified);
  auto update_time =
      std::chrono::time_point_cast<std::chrono::nanoseconds>(time)
          .time_since_epoch()
          .count();

  *mtime_ns = update_time;
  return Status::Success;
};

Status
ASFileSystem::ListDirectory(
    const std::string& container, const std::string& dir_path,
    std::function<
        Status(const as::list_blobs_segmented_item&, const std::string&)>
        func)
{
  as::blob_client_wrapper bc(client_);

  // Append a slash to make it easier to list contents
  std::string full_dir = AppendSlash(dir_path);
  auto blobs = bc.list_blobs_segmented(container, "/", "", full_dir);
  if (errno != 0) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to get contents of directory " + dir_path +
            ", errno:" + strerror(errno));
  }

  for (auto&& item : blobs.blobs) {
    std::string name = item.name;
    int item_start = name.find(full_dir) + full_dir.size();
    int item_end = name.find("/", item_start);
    // Let set take care of subdirectory contents
    std::string subfile = name.substr(item_start, item_end - item_start);
    auto status = func(item, subfile);
    if (!status.IsOk()) {
      return status;
    }
  }
  return Status::Success;
}

Status
ASFileSystem::GetDirectoryContents(
    const std::string& path, std::set<std::string>* contents)
{
  auto func = [&](const as::list_blobs_segmented_item& item,
                  const std::string& dir) {
    contents->insert(dir);
    return Status::Success;
  };

  std::string container, dir_path;
  RETURN_IF_ERROR(ParsePath(path, &container, &dir_path));
  return ListDirectory(container, dir_path, func);
}

Status
ASFileSystem::GetDirectorySubdirs(
    const std::string& path, std::set<std::string>* subdirs)
{
  auto func = [&](const as::list_blobs_segmented_item& item,
                  const std::string& dir) {
    if (item.is_directory) {
      subdirs->insert(dir);
    }
    return Status::Success;
  };

  std::string container, dir_path;
  RETURN_IF_ERROR(ParsePath(path, &container, &dir_path));
  return ListDirectory(container, dir_path, func);
}

Status
ASFileSystem::GetDirectoryFiles(
    const std::string& path, std::set<std::string>* files)
{
  auto func = [&](const as::list_blobs_segmented_item& item,
                  const std::string& file) {
    if (!item.is_directory) {
      files->insert(file);
    }
    return Status::Success;
  };

  std::string container, dir_path;
  RETURN_IF_ERROR(ParsePath(path, &container, &dir_path));
  return ListDirectory(container, dir_path, func);
}

Status
ASFileSystem::IsDirectory(const std::string& path, bool* is_dir)
{
  *is_dir = false;
  std::string container, object_path;
  RETURN_IF_ERROR(ParsePath(path, &container, &object_path));

  as::blob_client_wrapper bc(client_);
  auto blobs = bc.list_blobs_segmented(container, "/", "", object_path, 1);
  if (errno != 0) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to check if directory at " + path +
            ", errno:" + strerror(errno));
  }

  *is_dir = blobs.blobs.size() > 0;
  return Status::Success;
};

Status
ASFileSystem::ReadTextFile(const std::string& path, std::string* contents)
{
  as::blob_client_wrapper bc(client_);
  std::string container, object_path;
  RETURN_IF_ERROR(ParsePath(path, &container, &object_path));

  using namespace azure::storage_lite;
  std::ostringstream out_stream;
  bc.download_blob_to_stream(container, object_path, 0, 0, out_stream);
  if (errno != 0) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to fetch file stream at " + path +
            ", errno:" + strerror(errno));
  }

  *contents = out_stream.str();

  return Status::Success;
}

Status
ASFileSystem::FileExists(const std::string& path, bool* exists)
{
  *exists = false;

  std::string container, object;
  RETURN_IF_ERROR(ParsePath(path, &container, &object));

  as::blob_client_wrapper bc(client_);
  auto blobs = bc.list_blobs_segmented(container, "/", "", object, 1);
  if (errno != 0) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to check if file exists at " + path +
            ", errno:" + strerror(errno));
  }
  if (blobs.blobs.size() > 0) {
    *exists = true;
  }
  return Status::Success;
}

Status
ASFileSystem::DownloadFolder(
    const std::string& container, const std::string& path,
    const std::string& dest)
{
  as::blob_client_wrapper bc(client_);
  auto func = [&](const as::list_blobs_segmented_item& item,
                  const std::string& dir) {
    auto local_path = JoinPath({dest, dir});
    auto blob_path = JoinPath({path, dir});
    if (item.is_directory) {
      int status = mkdir(
          const_cast<char*>(local_path.c_str()),
          S_IRUSR | S_IWUSR | S_IXUSR);
      if (status == -1) {
        return Status(
            Status::Code::INTERNAL,
            "Failed to create local folder: " + local_path +
                ", errno:" + strerror(errno));
      }
      auto ret = DownloadFolder(container, blob_path, local_path);
      if (!ret.IsOk()) {
        return ret;
      }
    } else {
      time_t last_modified;
      bc.download_blob_to_file(
          container, blob_path, local_path, last_modified);
      if (errno != 0) {
        return Status(
            Status::Code::INTERNAL,
            "Failed to download file at " + blob_path +
                ", errno:" + strerror(errno));
      }
    }
    return Status::Success;
  };
  return ListDirectory(container, path, func);
}

Status
ASFileSystem::LocalizePath(
    const std::string& path, std::shared_ptr<LocalizedPath>* localized)
{
  bool exists;
  RETURN_IF_ERROR(FileExists(path, &exists));
  if (!exists) {
    return Status(
        Status::Code::INTERNAL, "directory or file does not exist at " + path);
  }

  bool is_dir;
  RETURN_IF_ERROR(IsDirectory(path, &is_dir));
  if (!is_dir) {
    return Status(
        Status::Code::UNSUPPORTED,
        "AS file localization not yet implemented " + path);
  }

  std::string folder_template = "/tmp/folderXXXXXX";
  char* tmp_folder = mkdtemp(const_cast<char*>(folder_template.c_str()));
  if (tmp_folder == nullptr) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to create local temp folder: " + folder_template +
            ", errno:" + strerror(errno));
  }
  localized->reset(new LocalizedPath(path, tmp_folder));

  std::string dest(folder_template);

  as::blob_client_wrapper bc(client_);

  std::string container, object;
  RETURN_IF_ERROR(ParsePath(path, &container, &object));
  return DownloadFolder(container, object, dest);
}

Status
ASFileSystem::WriteTextFile(
    const std::string& path, const std::string& contents)
{
  std::stringstream ss(contents);
  std::istream is(ss.rdbuf());
  std::string container, object;
  RETURN_IF_ERROR(ParsePath(path, &container, &object));
  std::vector<std::pair<std::string, std::string>> metadata;
  auto ret =
      client_->upload_block_blob_from_stream(container, object, is, metadata)
          .get();
  if (!ret.success()) {
    return Status(
        Status::Code::INTERNAL,
        "Failed to upload blob, Error: " + ret.error().code + ", " +
            ret.error().code_name);
  }
  return Status::Success;
}

Status
ASFileSystem::WriteBinaryFile(
    const std::string& path, const char* contents, const size_t content_len)
{
  return Status(
      Status::Code::UNSUPPORTED,
      "Write text file operation not yet implemented " + path);
}

Status
ASFileSystem::MakeDirectory(const std::string& dir, const bool recursive)
{
  return Status(
      Status::Code::UNSUPPORTED,
      "Make directory operation not yet implemented");
}

Status
ASFileSystem::MakeTemporaryDirectory(std::string* temp_dir)
{
  return Status(
      Status::Code::UNSUPPORTED,
      "Make temporary directory operation not yet implemented");
}

Status
ASFileSystem::DeletePath(const std::string& path)
{
  return Status(
      Status::Code::UNSUPPORTED, "Delete path operation not yet implemented");
}

#endif  // TRITON_ENABLE_AZURE_STORAGE
#ifdef TRITON_ENABLE_S3
namespace
s3
=
Aws
::
S3
;
struct
S3Credential
{
std
::
string
secret_key_
;
std
::
string
key_id_
;
std
::
string
region_
;
std
::
string
session_token_
;
std
::
string
profile_name_
;
S3Credential
();
// from env var
S3Credential
(
triton
::
common
::
TritonJson
::
Value
&
cred_json
);
};
S3Credential
::
S3Credential
()
{
const
auto
to_str
=
[](
const
char
*
s
)
->
std
::
string
{
return
(
s
!=
nullptr
?
std
::
string
(
s
)
:
""
);
};
const
char
*
secret_key
=
std
::
getenv
(
"AWS_SECRET_ACCESS_KEY"
);
const
char
*
key_id
=
std
::
getenv
(
"AWS_ACCESS_KEY_ID"
);
const
char
*
region
=
std
::
getenv
(
"AWS_DEFAULT_REGION"
);
const
char
*
session_token
=
std
::
getenv
(
"AWS_SESSION_TOKEN"
);
const
char
*
profile
=
std
::
getenv
(
"AWS_PROFILE"
);
secret_key_
=
to_str
(
secret_key
);
key_id_
=
to_str
(
key_id
);
region_
=
to_str
(
region
);
session_token_
=
to_str
(
session_token
);
profile_name_
=
to_str
(
profile
);
}
S3Credential
::
S3Credential
(
triton
::
common
::
TritonJson
::
Value
&
cred_json
)
{
triton
::
common
::
TritonJson
::
Value
secret_key_json
,
key_id_json
,
region_json
,
session_token_json
,
profile_json
;
if
(
cred_json
.
Find
(
"secret_key"
,
&
secret_key_json
))
secret_key_json
.
AsString
(
&
secret_key_
);
if
(
cred_json
.
Find
(
"key_id"
,
&
key_id_json
))
key_id_json
.
AsString
(
&
key_id_
);
if
(
cred_json
.
Find
(
"region"
,
&
region_json
))
region_json
.
AsString
(
&
region_
);
if
(
cred_json
.
Find
(
"session_token"
,
&
session_token_json
))
session_token_json
.
AsString
(
&
session_token_
);
if
(
cred_json
.
Find
(
"profile"
,
&
profile_json
))
profile_json
.
AsString
(
&
profile_name_
);
}
class
S3FileSystem
:
public
FileSystem
{
public:
S3FileSystem
(
const
std
::
string
&
s3_path
,
const
S3Credential
&
s3_cred
);
Status
CheckClient
(
const
std
::
string
&
s3_path
);
Status
FileExists
(
const
std
::
string
&
path
,
bool
*
exists
)
override
;
Status
IsDirectory
(
const
std
::
string
&
path
,
bool
*
is_dir
)
override
;
Status
FileModificationTime
(
const
std
::
string
&
path
,
int64_t
*
mtime_ns
)
override
;
Status
GetDirectoryContents
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
contents
)
override
;
Status
GetDirectorySubdirs
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
subdirs
)
override
;
Status
GetDirectoryFiles
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
files
)
override
;
Status
ReadTextFile
(
const
std
::
string
&
path
,
std
::
string
*
contents
)
override
;
Status
LocalizePath
(
const
std
::
string
&
path
,
std
::
shared_ptr
<
LocalizedPath
>*
localized
)
override
;
Status
WriteTextFile
(
const
std
::
string
&
path
,
const
std
::
string
&
contents
)
override
;
Status
WriteBinaryFile
(
const
std
::
string
&
path
,
const
char
*
contents
,
const
size_t
content_len
)
override
;
Status
MakeDirectory
(
const
std
::
string
&
dir
,
const
bool
recursive
)
override
;
Status
MakeTemporaryDirectory
(
std
::
string
*
temp_dir
)
override
;
Status
DeletePath
(
const
std
::
string
&
path
)
override
;
private:
Status
ParsePath
(
const
std
::
string
&
path
,
std
::
string
*
bucket
,
std
::
string
*
object
);
Status
CleanPath
(
const
std
::
string
&
s3_path
,
std
::
string
*
clean_path
);
std
::
unique_ptr
<
s3
::
S3Client
>
client_
;
// init after Aws::InitAPI is called
re2
::
RE2
s3_regex_
;
};
Status
S3FileSystem
::
ParsePath
(
const
std
::
string
&
path
,
std
::
string
*
bucket
,
std
::
string
*
object
)
{
// Cleanup extra slashes
std
::
string
clean_path
;
RETURN_IF_ERROR
(
CleanPath
(
path
,
&
clean_path
));
// Get the bucket name and the object path. Return error if path is malformed
std
::
string
protocol
,
host_name
,
host_port
;
if
(
!
RE2
::
FullMatch
(
clean_path
,
s3_regex_
,
&
protocol
,
&
host_name
,
&
host_port
,
bucket
,
object
))
{
int
bucket_start
=
clean_path
.
find
(
"s3://"
)
+
strlen
(
"s3://"
);
int
bucket_end
=
clean_path
.
find
(
"/"
,
bucket_start
);
// If there isn't a slash, the address has only the bucket
if
(
bucket_end
>
bucket_start
)
{
*
bucket
=
clean_path
.
substr
(
bucket_start
,
bucket_end
-
bucket_start
);
*
object
=
clean_path
.
substr
(
bucket_end
+
1
);
}
else
{
*
bucket
=
clean_path
.
substr
(
bucket_start
);
*
object
=
""
;
}
}
else
{
// Erase leading '/' that is left behind in object name
if
((
*
object
)[
0
]
==
'/'
)
{
object
->
erase
(
0
,
1
);
}
}
if
(
bucket
->
empty
())
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"No bucket name found in path: "
+
path
);
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
CleanPath
(
const
std
::
string
&
s3_path
,
std
::
string
*
clean_path
)
{
// Must handle paths with s3 prefix
size_t
start
=
s3_path
.
find
(
"s3://"
);
std
::
string
path
=
""
;
if
(
start
!=
std
::
string
::
npos
)
{
path
=
s3_path
.
substr
(
start
+
strlen
(
"s3://"
));
*
clean_path
=
"s3://"
;
}
else
{
path
=
s3_path
;
*
clean_path
=
""
;
}
// Must handle paths with https:// or http:// prefix
size_t
https_start
=
path
.
find
(
"https://"
);
if
(
https_start
!=
std
::
string
::
npos
)
{
path
=
path
.
substr
(
https_start
+
strlen
(
"https://"
));
*
clean_path
+=
"https://"
;
}
else
{
size_t
http_start
=
path
.
find
(
"http://"
);
if
(
http_start
!=
std
::
string
::
npos
)
{
path
=
path
.
substr
(
http_start
+
strlen
(
"http://"
));
*
clean_path
+=
"http://"
;
}
}
// Remove trailing slashes
size_t
rtrim_length
=
path
.
find_last_not_of
(
'/'
);
if
(
rtrim_length
==
std
::
string
::
npos
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Invalid bucket name: '"
+
path
+
"'"
);
}
// Remove leading slashes
size_t
ltrim_length
=
path
.
find_first_not_of
(
'/'
);
if
(
ltrim_length
==
std
::
string
::
npos
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Invalid bucket name: '"
+
path
+
"'"
);
}
// Remove extra internal slashes
std
::
string
true_path
=
path
.
substr
(
ltrim_length
,
rtrim_length
+
1
);
std
::
vector
<
int
>
slash_locations
;
bool
previous_slash
=
false
;
for
(
size_t
i
=
0
;
i
<
true_path
.
size
();
i
++
)
{
if
(
true_path
[
i
]
==
'/'
)
{
if
(
!
previous_slash
)
{
*
clean_path
+=
true_path
[
i
];
}
previous_slash
=
true
;
}
else
{
*
clean_path
+=
true_path
[
i
];
previous_slash
=
false
;
}
}
return
Status
::
Success
;
}
S3FileSystem
::
S3FileSystem
(
const
std
::
string
&
s3_path
,
const
S3Credential
&
s3_cred
)
:
s3_regex_
(
"s3://(http://|https://|)([0-9a-zA-Z
\\
-.]+):([0-9]+)/"
"([0-9a-z.
\\
-]+)(((/[0-9a-zA-Z.
\\
-_]+)*)?)"
)
{
// init aws api if not already
Aws
::
SDKOptions
options
;
static
std
::
once_flag
onceFlag
;
std
::
call_once
(
onceFlag
,
[
&
options
]
{
Aws
::
InitAPI
(
options
);
});
Aws
::
Client
::
ClientConfiguration
config
;
Aws
::
Auth
::
AWSCredentials
credentials
;
// check vars for S3 credentials -> aws profile -> default
if
(
!
s3_cred
.
secret_key_
.
empty
()
&&
!
s3_cred
.
key_id_
.
empty
())
{
credentials
.
SetAWSAccessKeyId
(
s3_cred
.
key_id_
.
c_str
());
credentials
.
SetAWSSecretKey
(
s3_cred
.
secret_key_
.
c_str
());
if
(
!
s3_cred
.
session_token_
.
empty
())
{
credentials
.
SetSessionToken
(
s3_cred
.
session_token_
.
c_str
());
}
config
=
Aws
::
Client
::
ClientConfiguration
();
if
(
!
s3_cred
.
region_
.
empty
())
{
config
.
region
=
s3_cred
.
region_
.
c_str
();
}
}
else
if
(
!
s3_cred
.
profile_name_
.
empty
())
{
config
=
Aws
::
Client
::
ClientConfiguration
(
s3_cred
.
profile_name_
.
c_str
());
}
else
{
config
=
Aws
::
Client
::
ClientConfiguration
(
"default"
);
}
// Cleanup extra slashes
std
::
string
clean_path
;
LOG_STATUS_ERROR
(
CleanPath
(
s3_path
,
&
clean_path
),
"failed to parse S3 path"
);
std
::
string
protocol
,
host_name
,
host_port
,
bucket
,
object
;
if
(
RE2
::
FullMatch
(
clean_path
,
s3_regex_
,
&
protocol
,
&
host_name
,
&
host_port
,
&
bucket
,
&
object
))
{
config
.
endpointOverride
=
Aws
::
String
(
host_name
+
":"
+
host_port
);
if
(
protocol
==
"https://"
)
{
config
.
scheme
=
Aws
::
Http
::
Scheme
::
HTTPS
;
}
else
{
config
.
scheme
=
Aws
::
Http
::
Scheme
::
HTTP
;
}
}
if
(
!
s3_cred
.
secret_key_
.
empty
()
&&
!
s3_cred
.
key_id_
.
empty
())
{
client_
=
std
::
make_unique
<
s3
::
S3Client
>
(
credentials
,
config
,
Aws
::
Client
::
AWSAuthV4Signer
::
PayloadSigningPolicy
::
Never
,
/*useVirtualAdressing*/
false
);
}
else
{
client_
=
std
::
make_unique
<
s3
::
S3Client
>
(
config
,
Aws
::
Client
::
AWSAuthV4Signer
::
PayloadSigningPolicy
::
Never
,
/*useVirtualAdressing*/
false
);
}
}
Status
S3FileSystem
::
CheckClient
(
const
std
::
string
&
s3_path
)
{
std
::
string
bucket
,
object_path
;
RETURN_IF_ERROR
(
ParsePath
(
s3_path
,
&
bucket
,
&
object_path
));
// check if can connect to the bucket
s3
::
Model
::
HeadBucketRequest
head_request
;
head_request
.
WithBucket
(
bucket
.
c_str
());
if
(
!
client_
->
HeadBucket
(
head_request
).
IsSuccess
())
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Unable to create S3 filesystem client. Check account credentials."
);
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
FileExists
(
const
std
::
string
&
path
,
bool
*
exists
)
{
*
exists
=
false
;
// S3 doesn't make objects for directories, so it could still be a directory
bool
is_dir
;
RETURN_IF_ERROR
(
IsDirectory
(
path
,
&
is_dir
));
if
(
is_dir
)
{
*
exists
=
is_dir
;
return
Status
::
Success
;
}
std
::
string
bucket
,
object
;
RETURN_IF_ERROR
(
ParsePath
(
path
,
&
bucket
,
&
object
));
// Construct request for object metadata
s3
::
Model
::
HeadObjectRequest
head_request
;
head_request
.
SetBucket
(
bucket
.
c_str
());
head_request
.
SetKey
(
object
.
c_str
());
auto
head_object_outcome
=
client_
->
HeadObject
(
head_request
);
if
(
!
head_object_outcome
.
IsSuccess
())
{
if
(
head_object_outcome
.
GetError
().
GetErrorType
()
!=
s3
::
S3Errors
::
RESOURCE_NOT_FOUND
)
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Could not get MetaData for object at "
+
path
+
" due to exception: "
+
head_object_outcome
.
GetError
().
GetExceptionName
()
+
", error message: "
+
head_object_outcome
.
GetError
().
GetMessage
());
}
}
else
{
*
exists
=
true
;
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
IsDirectory
(
const
std
::
string
&
path
,
bool
*
is_dir
)
{
*
is_dir
=
false
;
std
::
string
bucket
,
object_path
;
RETURN_IF_ERROR
(
ParsePath
(
path
,
&
bucket
,
&
object_path
));
// Check if the bucket exists
s3
::
Model
::
HeadBucketRequest
head_request
;
head_request
.
WithBucket
(
bucket
.
c_str
());
auto
head_bucket_outcome
=
client_
->
HeadBucket
(
head_request
);
if
(
!
head_bucket_outcome
.
IsSuccess
())
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Could not get MetaData for bucket with name "
+
bucket
+
" due to exception: "
+
head_bucket_outcome
.
GetError
().
GetExceptionName
()
+
", error message: "
+
head_bucket_outcome
.
GetError
().
GetMessage
());
}
// Root case - bucket exists and object path is empty
if
(
object_path
.
empty
())
{
*
is_dir
=
true
;
return
Status
::
Success
;
}
// List the objects in the bucket
s3
::
Model
::
ListObjectsRequest
list_objects_request
;
list_objects_request
.
SetBucket
(
bucket
.
c_str
());
list_objects_request
.
SetPrefix
(
AppendSlash
(
object_path
).
c_str
());
auto
list_objects_outcome
=
client_
->
ListObjects
(
list_objects_request
);
if
(
list_objects_outcome
.
IsSuccess
())
{
*
is_dir
=
!
list_objects_outcome
.
GetResult
().
GetContents
().
empty
();
}
else
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Failed to list objects with prefix "
+
path
+
" due to exception: "
+
list_objects_outcome
.
GetError
().
GetExceptionName
()
+
", error message: "
+
list_objects_outcome
.
GetError
().
GetMessage
());
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
FileModificationTime
(
const
std
::
string
&
path
,
int64_t
*
mtime_ns
)
{
// We don't need to worry about the case when this is a directory
bool
is_dir
;
RETURN_IF_ERROR
(
IsDirectory
(
path
,
&
is_dir
));
if
(
is_dir
)
{
*
mtime_ns
=
0
;
return
Status
::
Success
;
}
std
::
string
bucket
,
object
;
RETURN_IF_ERROR
(
ParsePath
(
path
,
&
bucket
,
&
object
));
// Send a request for the objects metadata
s3
::
Model
::
HeadObjectRequest
head_request
;
head_request
.
SetBucket
(
bucket
.
c_str
());
head_request
.
SetKey
(
object
.
c_str
());
// If request succeeds, copy over the modification time
auto
head_object_outcome
=
client_
->
HeadObject
(
head_request
);
if
(
head_object_outcome
.
IsSuccess
())
{
*
mtime_ns
=
head_object_outcome
.
GetResult
().
GetLastModified
().
Millis
()
*
NANOS_PER_MILLIS
;
}
else
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Failed to get modification time for object at "
+
path
+
" due to exception: "
+
head_object_outcome
.
GetError
().
GetExceptionName
()
+
", error message: "
+
head_object_outcome
.
GetError
().
GetMessage
());
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
GetDirectoryContents
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
contents
)
{
// Parse bucket and dir_path
std
::
string
bucket
,
dir_path
,
full_dir
;
RETURN_IF_ERROR
(
ParsePath
(
path
,
&
bucket
,
&
dir_path
));
std
::
string
true_path
=
"s3://"
+
bucket
+
'/'
+
dir_path
;
// Capture the full path to facilitate content listing
full_dir
=
AppendSlash
(
dir_path
);
// Issue request for objects with prefix
s3
::
Model
::
ListObjectsRequest
objects_request
;
objects_request
.
SetBucket
(
bucket
.
c_str
());
objects_request
.
SetPrefix
(
full_dir
.
c_str
());
auto
list_objects_outcome
=
client_
->
ListObjects
(
objects_request
);
if
(
list_objects_outcome
.
IsSuccess
())
{
Aws
::
Vector
<
Aws
::
S3
::
Model
::
Object
>
object_list
=
list_objects_outcome
.
GetResult
().
GetContents
();
for
(
auto
const
&
s3_object
:
object_list
)
{
// In the case of empty directories, the directory itself will appear here
if
(
s3_object
.
GetKey
().
c_str
()
==
full_dir
)
{
continue
;
}
// We have to make sure that subdirectory contents do not appear here
std
::
string
name
(
s3_object
.
GetKey
().
c_str
());
int
item_start
=
name
.
find
(
full_dir
)
+
full_dir
.
size
();
// S3 response prepends parent directory name
int
item_end
=
name
.
find
(
"/"
,
item_start
);
// Let set take care of subdirectory contents
std
::
string
item
=
name
.
substr
(
item_start
,
item_end
-
item_start
);
contents
->
insert
(
item
);
}
}
else
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Could not list contents of directory at "
+
true_path
+
" due to exception: "
+
list_objects_outcome
.
GetError
().
GetExceptionName
()
+
", error message: "
+
list_objects_outcome
.
GetError
().
GetMessage
());
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
GetDirectorySubdirs
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
subdirs
)
{
// Parse bucket and dir_path
std
::
string
bucket
,
dir_path
;
RETURN_IF_ERROR
(
ParsePath
(
path
,
&
bucket
,
&
dir_path
));
std
::
string
true_path
=
"s3://"
+
bucket
+
'/'
+
dir_path
;
RETURN_IF_ERROR
(
GetDirectoryContents
(
true_path
,
subdirs
));
// Erase non-directory entries...
for
(
auto
iter
=
subdirs
->
begin
();
iter
!=
subdirs
->
end
();)
{
bool
is_dir
;
RETURN_IF_ERROR
(
IsDirectory
(
JoinPath
({
true_path
,
*
iter
}),
&
is_dir
));
if
(
!
is_dir
)
{
iter
=
subdirs
->
erase
(
iter
);
}
else
{
++
iter
;
}
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
GetDirectoryFiles
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
files
)
{
// Parse bucket and dir_path
std
::
string
bucket
,
dir_path
;
RETURN_IF_ERROR
(
ParsePath
(
path
,
&
bucket
,
&
dir_path
));
std
::
string
true_path
=
"s3://"
+
bucket
+
'/'
+
dir_path
;
RETURN_IF_ERROR
(
GetDirectoryContents
(
true_path
,
files
));
// Erase directory entries...
for
(
auto
iter
=
files
->
begin
();
iter
!=
files
->
end
();)
{
bool
is_dir
;
RETURN_IF_ERROR
(
IsDirectory
(
JoinPath
({
true_path
,
*
iter
}),
&
is_dir
));
if
(
is_dir
)
{
iter
=
files
->
erase
(
iter
);
}
else
{
++
iter
;
}
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
ReadTextFile
(
const
std
::
string
&
path
,
std
::
string
*
contents
)
{
bool
exists
;
RETURN_IF_ERROR
(
FileExists
(
path
,
&
exists
));
if
(
!
exists
)
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"File does not exist at "
+
path
);
}
std
::
string
bucket
,
object
;
RETURN_IF_ERROR
(
ParsePath
(
path
,
&
bucket
,
&
object
));
// Send a request for the objects metadata
s3
::
Model
::
GetObjectRequest
object_request
;
object_request
.
SetBucket
(
bucket
.
c_str
());
object_request
.
SetKey
(
object
.
c_str
());
auto
get_object_outcome
=
client_
->
GetObject
(
object_request
);
if
(
get_object_outcome
.
IsSuccess
())
{
auto
&
object_result
=
get_object_outcome
.
GetResultWithOwnership
().
GetBody
();
std
::
string
data
=
""
;
char
c
;
while
(
object_result
.
get
(
c
))
{
data
+=
c
;
}
*
contents
=
data
;
}
else
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Failed to get object at "
+
path
+
" due to exception: "
+
get_object_outcome
.
GetError
().
GetExceptionName
()
+
", error message: "
+
get_object_outcome
.
GetError
().
GetMessage
());
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
LocalizePath
(
const
std
::
string
&
path
,
std
::
shared_ptr
<
LocalizedPath
>*
localized
)
{
// Check if the directory or file exists
bool
exists
;
RETURN_IF_ERROR
(
FileExists
(
path
,
&
exists
));
if
(
!
exists
)
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"directory or file does not exist at "
+
path
);
}
// Cleanup extra slashes
std
::
string
clean_path
;
RETURN_IF_ERROR
(
CleanPath
(
path
,
&
clean_path
));
// Remove protocol and host name and port
std
::
string
effective_path
,
protocol
,
host_name
,
host_port
,
bucket
,
object
;
if
(
RE2
::
FullMatch
(
clean_path
,
s3_regex_
,
&
protocol
,
&
host_name
,
&
host_port
,
&
bucket
,
&
object
))
{
effective_path
=
"s3://"
+
bucket
+
object
;
}
else
{
effective_path
=
path
;
}
// Create temporary directory
std
::
string
tmp_folder
;
RETURN_IF_ERROR
(
triton
::
core
::
MakeTemporaryDirectory
(
FileSystemType
::
LOCAL
,
&
tmp_folder
));
// Specify contents to be downloaded
std
::
set
<
std
::
string
>
contents
;
bool
is_dir
;
RETURN_IF_ERROR
(
IsDirectory
(
path
,
&
is_dir
));
if
(
is_dir
)
{
// Set localized path
localized
->
reset
(
new
LocalizedPath
(
effective_path
,
tmp_folder
));
// Specify the entire directory to be downloaded
std
::
set
<
std
::
string
>
filenames
;
RETURN_IF_ERROR
(
GetDirectoryContents
(
effective_path
,
&
filenames
));
for
(
auto
itr
=
filenames
.
begin
();
itr
!=
filenames
.
end
();
++
itr
)
{
contents
.
insert
(
JoinPath
({
effective_path
,
*
itr
}));
}
}
else
{
// Set localized path
std
::
string
filename
=
effective_path
.
substr
(
effective_path
.
find_last_of
(
'/'
)
+
1
);
localized
->
reset
(
new
LocalizedPath
(
effective_path
,
JoinPath
({
tmp_folder
,
filename
})));
// Specify only the file to be downloaded
contents
.
insert
(
effective_path
);
}
// Download all specified contents and nested contents
while
(
contents
.
size
()
!=
0
)
{
std
::
set
<
std
::
string
>
tmp_contents
=
contents
;
contents
.
clear
();
for
(
auto
iter
=
tmp_contents
.
begin
();
iter
!=
tmp_contents
.
end
();
++
iter
)
{
std
::
string
s3_fpath
=
*
iter
;
std
::
string
s3_removed_path
=
s3_fpath
.
substr
(
effective_path
.
size
());
std
::
string
local_fpath
=
s3_removed_path
.
empty
()
?
(
*
localized
)
->
Path
()
:
JoinPath
({(
*
localized
)
->
Path
(),
s3_removed_path
});
bool
is_subdir
;
RETURN_IF_ERROR
(
IsDirectory
(
s3_fpath
,
&
is_subdir
));
if
(
is_subdir
)
{
// Create local mirror of sub-directories
#ifdef _WIN32
int
status
=
mkdir
(
const_cast
<
char
*>
(
local_fpath
.
c_str
()));
#else
int
status
=
mkdir
(
const_cast
<
char
*>
(
local_fpath
.
c_str
()),
S_IRUSR
|
S_IWUSR
|
S_IXUSR
);
#endif
if
(
status
==
-
1
)
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Failed to create local folder: "
+
local_fpath
+
", errno:"
+
strerror
(
errno
));
}
// Add sub-directories and deeper files to contents
std
::
set
<
std
::
string
>
subdir_contents
;
RETURN_IF_ERROR
(
GetDirectoryContents
(
s3_fpath
,
&
subdir_contents
));
for
(
auto
itr
=
subdir_contents
.
begin
();
itr
!=
subdir_contents
.
end
();
++
itr
)
{
contents
.
insert
(
JoinPath
({
s3_fpath
,
*
itr
}));
}
}
else
{
// Create local copy of file
std
::
string
file_bucket
,
file_object
;
RETURN_IF_ERROR
(
ParsePath
(
s3_fpath
,
&
file_bucket
,
&
file_object
));
s3
::
Model
::
GetObjectRequest
object_request
;
object_request
.
SetBucket
(
file_bucket
.
c_str
());
object_request
.
SetKey
(
file_object
.
c_str
());
auto
get_object_outcome
=
client_
->
GetObject
(
object_request
);
if
(
get_object_outcome
.
IsSuccess
())
{
auto
&
retrieved_file
=
get_object_outcome
.
GetResultWithOwnership
().
GetBody
();
std
::
ofstream
output_file
(
local_fpath
.
c_str
(),
std
::
ios
::
binary
);
output_file
<<
retrieved_file
.
rdbuf
();
output_file
.
close
();
}
else
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Failed to get object at "
+
s3_fpath
+
" due to exception: "
+
get_object_outcome
.
GetError
().
GetExceptionName
()
+
", error message: "
+
get_object_outcome
.
GetError
().
GetMessage
());
}
}
}
}
return
Status
::
Success
;
}
Status
S3FileSystem
::
WriteTextFile
(
const
std
::
string
&
path
,
const
std
::
string
&
contents
)
{
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"Write text file operation not yet implemented "
+
path
);
}
Status
S3FileSystem
::
WriteBinaryFile
(
const
std
::
string
&
path
,
const
char
*
contents
,
const
size_t
content_len
)
{
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"Write text file operation not yet implemented "
+
path
);
}
Status
S3FileSystem
::
MakeDirectory
(
const
std
::
string
&
dir
,
const
bool
recursive
)
{
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"Make directory operation not yet implemented"
);
}
Status
S3FileSystem
::
MakeTemporaryDirectory
(
std
::
string
*
temp_dir
)
{
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"Make temporary directory operation not yet implemented"
);
}
Status
S3FileSystem
::
DeletePath
(
const
std
::
string
&
path
)
{
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"Delete path operation not yet implemented"
);
}
#endif // TRITON_ENABLE_S3
class
FileSystemManager
{
public:
Status
GetFileSystem
(
const
std
::
string
&
path
,
std
::
shared_ptr
<
FileSystem
>&
file_system
);
Status
GetFileSystem
(
FileSystemType
type
,
std
::
shared_ptr
<
FileSystem
>&
file_system
);
FileSystemManager
();
private:
template
<
class
CacheType
,
class
CredentialType
,
class
FileSystemType
>
Status
GetFileSystem
(
const
std
::
string
&
path
,
CacheType
&
cache
,
std
::
shared_ptr
<
FileSystem
>&
file_system
);
template
<
class
CacheType
,
class
CredentialType
,
class
FileSystemType
>
Status
ReturnErrorOrReload
(
const
Status
&
load_status
,
const
Status
&
error_status
,
const
std
::
string
&
path
,
CacheType
&
cache
,
std
::
shared_ptr
<
FileSystem
>&
file_system
);
Status
LoadCredentials
(
bool
flush_cache
=
false
);
template
<
class
CacheType
,
class
CredentialType
,
class
FileSystemType
>
static
void
LoadCredential
(
triton
::
common
::
TritonJson
::
Value
&
creds_json
,
const
char
*
fs_type
,
CacheType
&
cache
);
template
<
class
CredentialType
,
class
FileSystemType
>
static
void
SortCache
(
std
::
vector
<
std
::
tuple
<
std
::
string
,
CredentialType
,
std
::
shared_ptr
<
FileSystemType
>>>&
cache
);
template
<
class
CredentialType
,
class
FileSystemType
>
static
Status
GetLongestMatchingNameIndex
(
const
std
::
vector
<
std
::
tuple
<
std
::
string
,
CredentialType
,
std
::
shared_ptr
<
FileSystemType
>>>&
cache
,
const
std
::
string
&
path
,
size_t
&
idx
);
std
::
shared_ptr
<
LocalFileSystem
>
local_fs_
;
std
::
mutex
mu_
;
// protect concurrent access into variables
bool
is_cached_
;
// if name and credential is cached, lazy load file system
// cloud credential cache should be sorted in descending name length order
// [(name_long, credential, file_system), (name, ...)]
#ifdef TRITON_ENABLE_GCS
std
::
vector
<
std
::
tuple
<
std
::
string
,
GCSCredential
,
std
::
shared_ptr
<
GCSFileSystem
>>>
gs_cache_
;
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
std
::
vector
<
std
::
tuple
<
std
::
string
,
S3Credential
,
std
::
shared_ptr
<
S3FileSystem
>>>
s3_cache_
;
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
std
::
vector
<
std
::
tuple
<
std
::
string
,
ASCredential
,
std
::
shared_ptr
<
ASFileSystem
>>>
as_cache_
;
#endif // TRITON_ENABLE_AZURE_STORAGE
};
FileSystemManager
::
FileSystemManager
()
:
local_fs_
(
new
LocalFileSystem
()),
is_cached_
(
false
)
{
}
Status
FileSystemManager
::
GetFileSystem
(
const
std
::
string
&
path
,
std
::
shared_ptr
<
FileSystem
>&
file_system
)
{
// Check if this is a GCS path (gs://$BUCKET_NAME)
if
(
!
path
.
empty
()
&&
!
path
.
rfind
(
"gs://"
,
0
))
{
#ifndef TRITON_ENABLE_GCS
return
Status
(
Status
::
Code
::
INTERNAL
,
"gs:// file-system not supported. To enable, build with "
"-DTRITON_ENABLE_GCS=ON."
);
#else
return
GetFileSystem
<
std
::
vector
<
std
::
tuple
<
std
::
string
,
GCSCredential
,
std
::
shared_ptr
<
GCSFileSystem
>>>
,
GCSCredential
,
GCSFileSystem
>
(
path
,
gs_cache_
,
file_system
);
#endif // TRITON_ENABLE_GCS
}
// Check if this is an S3 path (s3://$BUCKET_NAME)
if
(
!
path
.
empty
()
&&
!
path
.
rfind
(
"s3://"
,
0
))
{
#ifndef TRITON_ENABLE_S3
return
Status
(
Status
::
Code
::
INTERNAL
,
"s3:// file-system not supported. To enable, build with "
"-DTRITON_ENABLE_S3=ON."
);
#else
return
GetFileSystem
<
std
::
vector
<
std
::
tuple
<
std
::
string
,
S3Credential
,
std
::
shared_ptr
<
S3FileSystem
>>>
,
S3Credential
,
S3FileSystem
>
(
path
,
s3_cache_
,
file_system
);
#endif // TRITON_ENABLE_S3
}
// Check if this is an Azure Storage path
if
(
!
path
.
empty
()
&&
!
path
.
rfind
(
"as://"
,
0
))
{
#ifndef TRITON_ENABLE_AZURE_STORAGE
return
Status
(
Status
::
Code
::
INTERNAL
,
"as:// file-system not supported. To enable, build with "
"-DTRITON_ENABLE_AZURE_STORAGE=ON."
);
#else
return
GetFileSystem
<
std
::
vector
<
std
::
tuple
<
std
::
string
,
ASCredential
,
std
::
shared_ptr
<
ASFileSystem
>>>
,
ASCredential
,
ASFileSystem
>
(
path
,
as_cache_
,
file_system
);
#endif // TRITON_ENABLE_AZURE_STORAGE
}
// Assume path is for local filesystem
file_system
=
local_fs_
;
return
Status
::
Success
;
}
Status
FileSystemManager
::
GetFileSystem
(
FileSystemType
type
,
std
::
shared_ptr
<
FileSystem
>&
file_system
)
{
// only LOCAL and GCS are not path-dependent and can be accessed by type
switch
(
type
)
{
case
FileSystemType
::
LOCAL
:
return
GetFileSystem
(
""
,
file_system
);
case
FileSystemType
::
GCS
:
return
GetFileSystem
(
"gs://"
,
file_system
);
case
FileSystemType
::
S3
:
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"S3 filesystem cannot be accessed by type"
);
case
FileSystemType
::
AS
:
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"AS filesystem cannot be accessed by type"
);
default:
return
Status
(
Status
::
Code
::
UNSUPPORTED
,
"Unsupported filesystem type"
);
}
}
template
<
class
CacheType
,
class
CredentialType
,
class
FileSystemType
>
Status
FileSystemManager
::
GetFileSystem
(
const
std
::
string
&
path
,
CacheType
&
cache
,
std
::
shared_ptr
<
FileSystem
>&
file_system
)
{
const
Status
&
cred_status
=
LoadCredentials
();
if
(
cred_status
.
IsOk
()
||
cred_status
.
StatusCode
()
==
Status
::
Code
::
ALREADY_EXISTS
)
{
// Find credential
size_t
idx
;
const
Status
&
match_status
=
GetLongestMatchingNameIndex
(
cache
,
path
,
idx
);
if
(
!
match_status
.
IsOk
())
{
return
ReturnErrorOrReload
<
CacheType
,
CredentialType
,
FileSystemType
>
(
cred_status
,
match_status
,
path
,
cache
,
file_system
);
}
// Find or lazy load file system
std
::
shared_ptr
<
FileSystemType
>
fs
=
std
::
get
<
2
>
(
cache
[
idx
]);
if
(
fs
==
nullptr
)
{
std
::
string
cred_name
=
std
::
get
<
0
>
(
cache
[
idx
]);
CredentialType
cred
=
std
::
get
<
1
>
(
cache
[
idx
]);
fs
=
std
::
make_shared
<
FileSystemType
>
(
path
,
cred
);
cache
[
idx
]
=
std
::
make_tuple
(
cred_name
,
cred
,
fs
);
}
// Check client
const
Status
&
client_status
=
fs
->
CheckClient
(
path
);
if
(
!
client_status
.
IsOk
())
{
return
ReturnErrorOrReload
<
CacheType
,
CredentialType
,
FileSystemType
>
(
cred_status
,
client_status
,
path
,
cache
,
file_system
);
}
// Return client
file_system
=
fs
;
return
Status
::
Success
;
}
return
cred_status
;
}
template
<
class
CacheType
,
class
CredentialType
,
class
FileSystemType
>
Status
FileSystemManager
::
ReturnErrorOrReload
(
const
Status
&
load_status
,
const
Status
&
error_status
,
const
std
::
string
&
path
,
CacheType
&
cache
,
std
::
shared_ptr
<
FileSystem
>&
file_system
)
{
if
(
load_status
.
StatusCode
()
==
Status
::
Code
::
ALREADY_EXISTS
)
{
return
error_status
;
}
LoadCredentials
(
true
);
// flush cache
return
GetFileSystem
<
CacheType
,
CredentialType
,
FileSystemType
>
(
path
,
cache
,
file_system
);
}
// return status meaning:
// - SUCCESS, "" -> loaded credential from file
// - ALREADY_EXISTS, "Cached" -> credential already loaded
Status
FileSystemManager
::
LoadCredentials
(
bool
flush_cache
)
{
// prevent concurrent access into class variables
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
// check if credential is already cached
if
(
is_cached_
&&
!
flush_cache
)
{
return
Status
(
Status
::
Code
::
ALREADY_EXISTS
,
"Cached"
);
}
const
char
*
file_path_c_str
=
std
::
getenv
(
"TRITON_CLOUD_CREDENTIAL_PATH"
);
if
(
file_path_c_str
!=
nullptr
)
{
// Load from credential file
std
::
string
file_path
=
std
::
string
(
file_path_c_str
);
LOG_VERBOSE
(
1
)
<<
"Reading cloud credential from "
<<
file_path
;
triton
::
common
::
TritonJson
::
Value
creds_json
;
std
::
string
cred_file_content
;
RETURN_IF_ERROR
(
local_fs_
->
ReadTextFile
(
file_path
,
&
cred_file_content
));
RETURN_IF_ERROR
(
creds_json
.
Parse
(
cred_file_content
));
#ifdef TRITON_ENABLE_GCS
// load GCS credentials
LoadCredential
<
std
::
vector
<
std
::
tuple
<
std
::
string
,
GCSCredential
,
std
::
shared_ptr
<
GCSFileSystem
>>>
,
GCSCredential
,
GCSFileSystem
>
(
creds_json
,
"gs"
,
gs_cache_
);
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
// load S3 credentials
LoadCredential
<
std
::
vector
<
std
::
tuple
<
std
::
string
,
S3Credential
,
std
::
shared_ptr
<
S3FileSystem
>>>
,
S3Credential
,
S3FileSystem
>
(
creds_json
,
"s3"
,
s3_cache_
);
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
// load AS credentials
LoadCredential
<
std
::
vector
<
std
::
tuple
<
std
::
string
,
ASCredential
,
std
::
shared_ptr
<
ASFileSystem
>>>
,
ASCredential
,
ASFileSystem
>
(
creds_json
,
"as"
,
as_cache_
);
#endif // TRITON_ENABLE_AZURE_STORAGE
}
else
{
// Load from environment variables
LOG_VERBOSE
(
1
)
<<
"TRITON_CLOUD_CREDENTIAL_PATH environment variable is "
"not set, reading from environment variables"
;
#ifdef TRITON_ENABLE_GCS
// load GCS credentials
gs_cache_
.
clear
();
gs_cache_
.
push_back
(
std
::
make_tuple
(
""
,
GCSCredential
(),
std
::
shared_ptr
<
GCSFileSystem
>
()));
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
// load S3 credentials
s3_cache_
.
clear
();
s3_cache_
.
push_back
(
std
::
make_tuple
(
""
,
S3Credential
(),
std
::
shared_ptr
<
S3FileSystem
>
()));
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
// load AS credentials
as_cache_
.
clear
();
as_cache_
.
push_back
(
std
::
make_tuple
(
""
,
ASCredential
(),
std
::
shared_ptr
<
ASFileSystem
>
()));
#endif // TRITON_ENABLE_AZURE_STORAGE
}
is_cached_
=
true
;
return
Status
::
Success
;
}
template
<
class
CacheType
,
class
CredentialType
,
class
FileSystemType
>
void
FileSystemManager
::
LoadCredential
(
triton
::
common
::
TritonJson
::
Value
&
creds_json
,
const
char
*
fs_type
,
CacheType
&
cache
)
{
cache
.
clear
();
triton
::
common
::
TritonJson
::
Value
creds_fs_json
;
if
(
creds_json
.
Find
(
fs_type
,
&
creds_fs_json
))
{
std
::
vector
<
std
::
string
>
cred_names
;
creds_fs_json
.
Members
(
&
cred_names
);
for
(
size_t
i
=
0
;
i
<
cred_names
.
size
();
i
++
)
{
std
::
string
cred_name
=
cred_names
[
i
];
triton
::
common
::
TritonJson
::
Value
cred_json
;
creds_fs_json
.
Find
(
cred_name
.
c_str
(),
&
cred_json
);
cache
.
push_back
(
std
::
make_tuple
(
cred_name
,
CredentialType
(
cred_json
),
std
::
shared_ptr
<
FileSystemType
>
()));
}
SortCache
(
cache
);
}
}
template
<
class
CredentialType
,
class
FileSystemType
>
void
FileSystemManager
::
SortCache
(
std
::
vector
<
std
::
tuple
<
std
::
string
,
CredentialType
,
std
::
shared_ptr
<
FileSystemType
>>>&
cache
)
{
std
::
sort
(
cache
.
begin
(),
cache
.
end
(),
[](
std
::
tuple
<
std
::
string
,
CredentialType
,
std
::
shared_ptr
<
FileSystemType
>>
a
,
std
::
tuple
<
std
::
string
,
CredentialType
,
std
::
shared_ptr
<
FileSystemType
>>
b
)
{
return
std
::
get
<
0
>
(
a
).
size
()
>=
std
::
get
<
0
>
(
b
).
size
();
});
}
template
<
class
CredentialType
,
class
FileSystemType
>
Status
FileSystemManager
::
GetLongestMatchingNameIndex
(
const
std
::
vector
<
std
::
tuple
<
std
::
string
,
CredentialType
,
std
::
shared_ptr
<
FileSystemType
>>>&
cache
,
const
std
::
string
&
path
,
size_t
&
idx
)
{
for
(
size_t
i
=
0
;
i
<
cache
.
size
();
i
++
)
{
if
(
!
path
.
rfind
(
std
::
get
<
0
>
(
cache
[
i
]),
0
))
{
idx
=
i
;
LOG_VERBOSE
(
1
)
<<
"Using credential "
+
std
::
get
<
0
>
(
cache
[
i
])
+
" for path "
+
path
;
return
Status
::
Success
;
}
}
return
Status
(
Status
::
Code
::
NOT_FOUND
,
"Cannot match credential for path "
+
path
);
}
static
FileSystemManager
fsm_
;
}
// namespace
// FIXME: Windows support '/'? If so, the below doesn't need to change
bool
IsAbsolutePath
(
const
std
::
string
&
path
)
{
return
!
path
.
empty
()
&&
(
path
[
0
]
==
'/'
);
}
std
::
string
JoinPath
(
std
::
initializer_list
<
std
::
string
>
segments
)
{
std
::
string
joined
;
for
(
const
auto
&
seg
:
segments
)
{
if
(
joined
.
empty
())
{
joined
=
seg
;
}
else
if
(
IsAbsolutePath
(
seg
))
{
if
(
joined
[
joined
.
size
()
-
1
]
==
'/'
)
{
joined
.
append
(
seg
.
substr
(
1
));
}
else
{
joined
.
append
(
seg
);
}
}
else
{
// !IsAbsolutePath(seg)
if
(
joined
[
joined
.
size
()
-
1
]
!=
'/'
)
{
joined
.
append
(
"/"
);
}
joined
.
append
(
seg
);
}
}
return
joined
;
}
std
::
string
BaseName
(
const
std
::
string
&
path
)
{
if
(
path
.
empty
())
{
return
path
;
}
size_t
last
=
path
.
size
()
-
1
;
while
((
last
>
0
)
&&
(
path
[
last
]
==
'/'
))
{
last
-=
1
;
}
if
(
path
[
last
]
==
'/'
)
{
return
std
::
string
();
}
const
size_t
idx
=
path
.
find_last_of
(
"/"
,
last
);
if
(
idx
==
std
::
string
::
npos
)
{
return
path
.
substr
(
0
,
last
+
1
);
}
return
path
.
substr
(
idx
+
1
,
last
-
idx
);
}
std
::
string
DirName
(
const
std
::
string
&
path
)
{
if
(
path
.
empty
())
{
return
path
;
}
size_t
last
=
path
.
size
()
-
1
;
while
((
last
>
0
)
&&
(
path
[
last
]
==
'/'
))
{
last
-=
1
;
}
if
(
path
[
last
]
==
'/'
)
{
return
std
::
string
(
"/"
);
}
const
size_t
idx
=
path
.
find_last_of
(
"/"
,
last
);
if
(
idx
==
std
::
string
::
npos
)
{
return
std
::
string
(
"."
);
}
if
(
idx
==
0
)
{
return
std
::
string
(
"/"
);
}
return
path
.
substr
(
0
,
idx
);
}
Status
FileExists
(
const
std
::
string
&
path
,
bool
*
exists
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
return
fs
->
FileExists
(
path
,
exists
);
}
Status
IsDirectory
(
const
std
::
string
&
path
,
bool
*
is_dir
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
return
fs
->
IsDirectory
(
path
,
is_dir
);
}
Status
FileModificationTime
(
const
std
::
string
&
path
,
int64_t
*
mtime_ns
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
return
fs
->
FileModificationTime
(
path
,
mtime_ns
);
}
Status
GetDirectoryContents
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
contents
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
return
fs
->
GetDirectoryContents
(
path
,
contents
);
}
Status
GetDirectorySubdirs
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
subdirs
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
return
fs
->
GetDirectorySubdirs
(
path
,
subdirs
);
}
Status
GetDirectoryFiles
(
const
std
::
string
&
path
,
const
bool
skip_hidden_files
,
std
::
set
<
std
::
string
>*
files
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
std
::
set
<
std
::
string
>
all_files
;
RETURN_IF_ERROR
(
fs
->
GetDirectoryFiles
(
path
,
&
all_files
));
// Remove the hidden files
for
(
auto
f
:
all_files
)
{
if
((
f
[
0
]
!=
'.'
)
||
(
!
skip_hidden_files
))
{
files
->
insert
(
f
);
}
}
return
Status
::
Success
;
}
Status
ReadTextFile
(
const
std
::
string
&
path
,
std
::
string
*
contents
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
return
fs
->
ReadTextFile
(
path
,
contents
);
}
Status
ReadTextProto
(
const
std
::
string
&
path
,
google
::
protobuf
::
Message
*
msg
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
std
::
string
contents
;
RETURN_IF_ERROR
(
fs
->
ReadTextFile
(
path
,
&
contents
));
if
(
!
google
::
protobuf
::
TextFormat
::
ParseFromString
(
contents
,
msg
))
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to read text proto from "
+
path
);
}
return
Status
::
Success
;
}
Status
LocalizePath
(
const
std
::
string
&
path
,
std
::
shared_ptr
<
LocalizedPath
>*
localized
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
return
fs
->
LocalizePath
(
path
,
localized
);
}
Status
WriteTextProto
(
const
std
::
string
&
path
,
const
google
::
protobuf
::
Message
&
msg
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
std
::
string
prototxt
;
if
(
!
google
::
protobuf
::
TextFormat
::
PrintToString
(
msg
,
&
prototxt
))
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"failed to write text proto to "
+
path
);
}
return
fs
->
WriteTextFile
(
path
,
prototxt
);
}
Status
WriteBinaryFile
(
const
std
::
string
&
path
,
const
char
*
contents
,
const
size_t
content_len
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
return
fs
->
WriteBinaryFile
(
path
,
contents
,
content_len
);
}
Status
ReadBinaryProto
(
const
std
::
string
&
path
,
google
::
protobuf
::
MessageLite
*
msg
)
{
std
::
string
msg_str
;
RETURN_IF_ERROR
(
ReadTextFile
(
path
,
&
msg_str
));
google
::
protobuf
::
io
::
CodedInputStream
coded_stream
(
reinterpret_cast
<
const
uint8_t
*>
(
msg_str
.
c_str
()),
msg_str
.
size
());
coded_stream
.
SetTotalBytesLimit
(
INT_MAX
);
if
(
!
msg
->
ParseFromCodedStream
(
&
coded_stream
))
{
return
Status
(
Status
::
Code
::
INTERNAL
,
"Can't parse "
+
path
+
" as binary proto"
);
}
return
Status
::
Success
;
}
Status
MakeDirectory
(
const
std
::
string
&
dir
,
const
bool
recursive
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
dir
,
fs
));
return
fs
->
MakeDirectory
(
dir
,
recursive
);
}
Status
MakeTemporaryDirectory
(
const
FileSystemType
type
,
std
::
string
*
temp_dir
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
type
,
fs
));
return
fs
->
MakeTemporaryDirectory
(
temp_dir
);
}
Status
DeletePath
(
const
std
::
string
&
path
)
{
std
::
shared_ptr
<
FileSystem
>
fs
;
RETURN_IF_ERROR
(
fsm_
.
GetFileSystem
(
path
,
fs
));
return
fs
->
DeletePath
(
path
);
}
Status
GetFileSystemType
(
const
std
::
string
&
path
,
FileSystemType
*
type
)
{
if
(
path
.
empty
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"Can not infer filesystem type from empty path"
);
}
#ifdef TRITON_ENABLE_GCS
// Check if this is a GCS path (gs://$BUCKET_NAME)
if
(
!
path
.
rfind
(
"gs://"
,
0
))
{
*
type
=
FileSystemType
::
GCS
;
return
Status
::
Success
;
}
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
// Check if this is an S3 path (s3://$BUCKET_NAME)
if
(
!
path
.
rfind
(
"s3://"
,
0
))
{
*
type
=
FileSystemType
::
S3
;
return
Status
::
Success
;
}
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
// Check if this is an Azure Storage path
if
(
!
path
.
rfind
(
"as://"
,
0
))
{
*
type
=
FileSystemType
::
AS
;
return
Status
::
Success
;
}
#endif // TRITON_ENABLE_AZURE_STORAGE
// Assume path is for local filesystem
*
type
=
FileSystemType
::
LOCAL
;
return
Status
::
Success
;
}
const
std
::
string
&
FileSystemTypeString
(
const
FileSystemType
type
)
{
static
const
std
::
string
local_str
(
"LOCAL"
);
static
const
std
::
string
gcs_str
(
"GCS"
);
static
const
std
::
string
s3_str
(
"S3"
);
static
const
std
::
string
as_str
(
"AS"
);
static
const
std
::
string
unknown_str
(
"UNKNOWN"
);
switch
(
type
)
{
case
FileSystemType
::
LOCAL
:
return
local_str
;
case
FileSystemType
::
GCS
:
return
gcs_str
;
case
FileSystemType
::
S3
:
return
s3_str
;
case
FileSystemType
::
AS
:
return
as_str
;
default:
return
unknown_str
;
}
}
}}
// namespace triton::core
3rdparty/core-r22.12/src/filesystem.h
0 → 100644
View file @
b30f3cdb
// Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#ifdef _WIN32
// Remove GetObject definition from windows.h, which can cause
// a naming collision when GetObject is called.
// https://github.com/Tencent/rapidjson/issues/1448
#undef GetObject
#endif // _WIN32
#include <string>
#include "google/protobuf/message.h"
#include "status.h"
namespace
triton
{
namespace
core
{
enum
class
FileSystemType
{
LOCAL
,
GCS
,
S3
,
AS
};
// This class stores the paths of local temporary files needed for loading
// models from Cloud repositories and performs necessary cleanup after the
// models are loaded.
class
LocalizedPath
{
public:
// Create an object for a path that is already local.
LocalizedPath
(
const
std
::
string
&
original_path
)
:
original_path_
(
original_path
)
{
}
// Create an object for a remote path. Store both the original path and the
// temporary local path.
LocalizedPath
(
const
std
::
string
&
original_path
,
const
std
::
string
&
local_path
)
:
original_path_
(
original_path
),
local_path_
(
local_path
)
{
}
// Destructor. Remove temporary local storage associated with the object.
// If the local path is a directory, delete the directory.
// If the local path is a file, delete the directory containing the file.
~
LocalizedPath
();
// Return the localized path represented by this object.
const
std
::
string
&
Path
()
const
{
return
(
local_path_
.
empty
())
?
original_path_
:
local_path_
;
}
// Maintain a vector of LocalizedPath that should be kept available in the
// tmp directory for the lifetime of this object
// FIXME: Remove when no longer required
std
::
vector
<
std
::
shared_ptr
<
LocalizedPath
>>
other_localized_path
;
private:
std
::
string
original_path_
;
std
::
string
local_path_
;
};
/// Is a path an absolute path?
/// \param path The path.
/// \return true if absolute path, false if relative path.
bool
IsAbsolutePath
(
const
std
::
string
&
path
);
/// Join path segments into a longer path
/// \param segments The path segments.
/// \return the path formed by joining the segments.
std
::
string
JoinPath
(
std
::
initializer_list
<
std
::
string
>
segments
);
/// Get the basename of a path.
/// \param path The path.
/// \return the last segment of the path.
std
::
string
BaseName
(
const
std
::
string
&
path
);
/// Get the dirname of a path.
/// \param path The path.
/// \return all but the last segment of the path.
std
::
string
DirName
(
const
std
::
string
&
path
);
/// Does a file or directory exist?
/// \param path The path to check for existance.
/// \param exists Returns true if file/dir exists
/// \return Error status if unable to perform the check
Status
FileExists
(
const
std
::
string
&
path
,
bool
*
exists
);
/// Is a path a directory?
/// \param path The path to check.
/// \param is_dir Returns true if path represents a directory
/// \return Error status
Status
IsDirectory
(
const
std
::
string
&
path
,
bool
*
is_dir
);
/// Get file modification time in nanoseconds.
/// A file is considered modified in Triton when its binary content has changed
/// including the action of replacing it with another file.
/// \param path The path.
/// \param mtime_ns Returns the file modification time. For some filesystems a
/// file/folder may not have a modification time, in that case return 0.
/// \return Error status
Status
FileModificationTime
(
const
std
::
string
&
path
,
int64_t
*
mtime_ns
);
/// Get the contents of a directory.
/// \param path The directory path.
/// \param subdirs Returns the directory contents.
/// \return Error status
Status
GetDirectoryContents
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
contents
);
/// Get the sub-directories of a path.
/// \param path The path.
/// \param subdirs Returns the names of the sub-directories.
/// \return Error status
Status
GetDirectorySubdirs
(
const
std
::
string
&
path
,
std
::
set
<
std
::
string
>*
subdirs
);
/// Get the files contained in a directory.
/// \param path The directory.
/// \param skip_hidden_files Ignores the hidden files in the directory.
/// \param files Returns the names of the files.
/// \return Error status
Status
GetDirectoryFiles
(
const
std
::
string
&
path
,
const
bool
skip_hidden_files
,
std
::
set
<
std
::
string
>*
files
);
/// Read a text file into a string.
/// \param path The path of the file.
/// \param contents Returns the contents of the file.
/// \return Error status
Status
ReadTextFile
(
const
std
::
string
&
path
,
std
::
string
*
contents
);
/// Create an object representing a local copy of a path.
/// \param path The path of the directory or file.
/// \param localized Returns the LocalizedPath object
/// representing the local copy of the path.
/// \return Error status
Status
LocalizePath
(
const
std
::
string
&
path
,
std
::
shared_ptr
<
LocalizedPath
>*
localized
);
/// Write a string to a file.
/// \param path The path of the file.
/// \param contents The contents to write to the file.
/// \return Error status
Status
WriteTextFile
(
const
std
::
string
&
path
,
const
std
::
string
&
contents
);
/// Write binary to a file.
/// \param path The path of the file.
/// \param contents The contents to write to the file.
/// \param content_len The size of the content.
/// \return Error status
Status
WriteBinaryFile
(
const
std
::
string
&
path
,
const
char
*
contents
,
const
size_t
content_len
);
/// Read a prototext file.
/// \param path The path of the file.
/// \param msg Returns the protobuf message for the file.
/// \return Error status
Status
ReadTextProto
(
const
std
::
string
&
path
,
google
::
protobuf
::
Message
*
msg
);
/// Write a prototext file.
/// \param path The path of the file.
/// \param msg The protobuf to write.
/// \return Error status
Status
WriteTextProto
(
const
std
::
string
&
path
,
const
google
::
protobuf
::
Message
&
msg
);
/// Read a binary protobuf file.
/// \param path The path of the file.
/// \param msg Returns the protobuf message for the file.
/// \return Error status
Status
ReadBinaryProto
(
const
std
::
string
&
path
,
google
::
protobuf
::
MessageLite
*
msg
);
/// Create a directory of the specified path.
/// \param dir The path to the directory.
/// \param recursive Whether the parent directories will be created
/// if not exist.
/// \return Error status if the directory can't be created
Status
MakeDirectory
(
const
std
::
string
&
dir
,
const
bool
recursive
);
/// Create a temporary directory of the specified filesystem type.
/// \param type The type of the filesystem.
/// \param temp_dir Returns the path to the temporary directory.
/// \return Error status
Status
MakeTemporaryDirectory
(
const
FileSystemType
type
,
std
::
string
*
temp_dir
);
/// Delete a path.
/// \param path The path to the directory or file.
/// \return Error status
Status
DeletePath
(
const
std
::
string
&
path
);
/// Infer the filesystem type from the given path.
/// \param path The path to infer the filesystem type from.
/// \param type Returns the filesystem type of the path.
/// \return Error status
Status
GetFileSystemType
(
const
std
::
string
&
path
,
FileSystemType
*
type
);
/// Return the string representation of the filesystem type.
/// \param type The filesystem type.
/// \return The string representation of the type.
const
std
::
string
&
FileSystemTypeString
(
const
FileSystemType
type
);
}}
// namespace triton::core
3rdparty/core-r22.12/src/infer_parameter.cc
0 → 100644
View file @
b30f3cdb
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_parameter.h"
namespace
triton
{
namespace
core
{
const
void
*
InferenceParameter
::
ValuePointer
()
const
{
switch
(
type_
)
{
case
TRITONSERVER_PARAMETER_STRING
:
return
reinterpret_cast
<
const
void
*>
(
value_string_
.
c_str
());
case
TRITONSERVER_PARAMETER_INT
:
return
reinterpret_cast
<
const
void
*>
(
&
value_int64_
);
case
TRITONSERVER_PARAMETER_BOOL
:
return
reinterpret_cast
<
const
void
*>
(
&
value_bool_
);
case
TRITONSERVER_PARAMETER_BYTES
:
return
reinterpret_cast
<
const
void
*>
(
value_bytes_
);
default:
break
;
}
return
nullptr
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
InferenceParameter
&
parameter
)
{
out
<<
"[0x"
<<
std
::
addressof
(
parameter
)
<<
"] "
<<
"name: "
<<
parameter
.
Name
()
<<
", type: "
<<
TRITONSERVER_ParameterTypeString
(
parameter
.
Type
())
<<
", value: "
;
return
out
;
}
}}
// namespace triton::core
3rdparty/core-r22.12/src/infer_parameter.h
0 → 100644
View file @
b30f3cdb
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <iostream>
#include <string>
#include "tritonserver_apis.h"
namespace triton { namespace core {

//
// An inference parameter.
//
class InferenceParameter {
 public:
  InferenceParameter(const char* name, const char* value)
      : name_(name), type_(TRITONSERVER_PARAMETER_STRING), value_string_(value)
  {
    byte_size_ = value_string_.size();
  }

  InferenceParameter(const char* name, const int64_t value)
      : name_(name), type_(TRITONSERVER_PARAMETER_INT), value_int64_(value),
        byte_size_(sizeof(int64_t))
  {
  }

  InferenceParameter(const char* name, const bool value)
      : name_(name), type_(TRITONSERVER_PARAMETER_BOOL), value_bool_(value),
        byte_size_(sizeof(bool))
  {
  }

  InferenceParameter(const char* name, const void* ptr, const uint64_t size)
      : name_(name), type_(TRITONSERVER_PARAMETER_BYTES), value_bytes_(ptr),
        byte_size_(size)
  {
  }

  // The name of the parameter.
  const std::string& Name() const { return name_; }

  // Data type of the parameter.
  TRITONSERVER_ParameterType Type() const { return type_; }

  // Return a pointer to the parameter, or a pointer to the data content
  // if type_ is TRITONSERVER_PARAMETER_BYTES. This returned pointer must be
  // cast correctly based on 'type_'.
  //   TRITONSERVER_PARAMETER_STRING -> const char*
  //   TRITONSERVER_PARAMETER_INT -> int64_t*
  //   TRITONSERVER_PARAMETER_BOOL -> bool*
  //   TRITONSERVER_PARAMETER_BYTES -> const void*
  const void* ValuePointer() const;

  // Return the data byte size of the parameter.
  uint64_t ValueByteSize() const { return byte_size_; }

  // Return the parameter value string, the return value is valid only if
  // Type() returns TRITONSERVER_PARAMETER_STRING.
  const std::string& ValueString() const { return value_string_; }

 private:
  friend std::ostream& operator<<(
      std::ostream& out, const InferenceParameter& parameter);

  std::string name_;
  TRITONSERVER_ParameterType type_;
  std::string value_string_;
  int64_t value_int64_;
  bool value_bool_;
  const void* value_bytes_;
  uint64_t byte_size_;
};

std::ostream& operator<<(
    std::ostream& out, const InferenceParameter& parameter);

}}  // namespace triton::core
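For orientation only (not part of the committed sources): a minimal, hypothetical sketch of how a caller might read a typed value back through InferenceParameter::ValuePointer(), following the casting rules documented above. The helper name PrintParameter is an assumption made for illustration.

// Hypothetical helper, not part of the Triton sources above. It interprets an
// InferenceParameter according to Type() and prints it.
#include <cstdint>
#include <iostream>

void
PrintParameter(const triton::core::InferenceParameter& param)
{
  const void* value = param.ValuePointer();
  switch (param.Type()) {
    case TRITONSERVER_PARAMETER_STRING:
      std::cout << param.Name() << " = "
                << reinterpret_cast<const char*>(value) << std::endl;
      break;
    case TRITONSERVER_PARAMETER_INT:
      std::cout << param.Name() << " = "
                << *reinterpret_cast<const int64_t*>(value) << std::endl;
      break;
    case TRITONSERVER_PARAMETER_BOOL:
      std::cout << param.Name() << " = "
                << (*reinterpret_cast<const bool*>(value) ? "true" : "false")
                << std::endl;
      break;
    default:
      // TRITONSERVER_PARAMETER_BYTES: only the raw pointer and
      // ValueByteSize() are meaningful.
      std::cout << param.Name() << " = <" << param.ValueByteSize() << " bytes>"
                << std::endl;
      break;
  }
}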
3rdparty/core-r22.12/src/infer_request.cc
0 → 100644
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_request.h"
#include <algorithm>
#include <deque>
#include "model.h"
#include "model_config_utils.h"
#include "server.h"
#include "triton/common/logging.h"
#ifdef TRITON_ENABLE_TRACING
#include "cuda_utils.h"
#endif // TRITON_ENABLE_TRACING
namespace triton { namespace core {

namespace {

// Utilities for Null request feature.
TRITONSERVER_Error*
NullResponseAlloc(
    TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
    size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type,
    int64_t preferred_memory_type_id, void* userp, void** buffer,
    void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type,
    int64_t* actual_memory_type_id)
{
  return TRITONSERVER_ErrorNew(
      TRITONSERVER_ERROR_INTERNAL,
      "unexpected allocation for null request, no output should be requested.");
}

TRITONSERVER_Error*
NullResponseRelease(
    TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp,
    size_t byte_size, TRITONSERVER_MemoryType memory_type,
    int64_t memory_type_id)
{
  return TRITONSERVER_ErrorNew(
      TRITONSERVER_ERROR_INTERNAL,
      "unexpected release for null request, no output should be requested.");
}

ResponseAllocator null_allocator = ResponseAllocator(
    NullResponseAlloc, NullResponseRelease, nullptr /* start_fn */);

void
NullResponseComplete(
    TRITONSERVER_InferenceResponse* iresponse, const uint32_t flags,
    void* userp)
{
  if (iresponse != nullptr) {
    LOG_TRITONSERVER_ERROR(
        TRITONSERVER_InferenceResponseDelete(iresponse),
        "deleting null response");
  }
}

void
NullRequestComplete(
    TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
    LOG_TRITONSERVER_ERROR(
        TRITONSERVER_InferenceRequestDelete(request), "deleting null request");
  }
}

}  // namespace

InferenceRequest::InferenceRequest(
    const std::shared_ptr<Model>& model, const int64_t requested_model_version)
    : InferenceRequest(model.get(), requested_model_version)
{
  model_shared_ = model;
}

InferenceRequest::InferenceRequest(
    Model* model, const int64_t requested_model_version)
    : needs_normalization_(true), model_raw_(model),
      requested_model_version_(requested_model_version), flags_(0),
      correlation_id_(0), batch_size_(0), timeout_us_(0), collect_stats_(true)
{
  SetPriority(0);
}

const std::string&
InferenceRequest::ModelName() const
{
  return model_raw_->Name();
}

int64_t
InferenceRequest::ActualModelVersion() const
{
  return model_raw_->Version();
}

void
InferenceRequest::SetPriority(uint32_t p)
{
  if ((p == 0) || (p > model_raw_->MaxPriorityLevel())) {
    priority_ = model_raw_->DefaultPriorityLevel();
  } else {
    priority_ = p;
  }
}
#ifdef TRITON_ENABLE_TRACING
Status
InferenceRequest
::
TraceInputTensors
(
TRITONSERVER_InferenceTraceActivity
activity
,
const
std
::
string
&
msg
)
{
const
auto
&
inputs
=
this
->
ImmutableInputs
();
TRITONSERVER_MemoryType
dst_memory_type
=
TRITONSERVER_MEMORY_CPU
;
int64_t
dst_memory_type_id
=
0
;
for
(
const
auto
&
pr
:
inputs
)
{
InferenceRequest
::
Input
*
ti
=
pr
.
second
;
// input data
const
std
::
string
&
name
=
ti
->
Name
();
TRITONSERVER_DataType
datatype
=
DataTypeToTriton
(
ti
->
DType
());
uint64_t
byte_size
=
ti
->
Data
()
->
TotalByteSize
();
const
int64_t
*
shape
=
ti
->
ShapeWithBatchDim
().
data
();
uint32_t
dim_count
=
ti
->
ShapeWithBatchDim
().
size
();
uint32_t
buffer_count
=
ti
->
DataBufferCount
();
// chunk buffer
Status
status
;
const
void
*
buffer
;
uint64_t
buffer_size
;
TRITONSERVER_MemoryType
src_memory_type
;
int64_t
src_memory_type_id
;
bool
cuda_used
;
if
(
buffer_count
==
0
)
{
LOG_STATUS_ERROR
(
status
,
LogRequest
()
+
TRITONSERVER_InferenceTraceActivityString
(
activity
)
+
": "
+
msg
+
": tensor: "
+
name
+
": no buffer chunk"
);
continue
;
}
if
(
buffer_count
==
1
)
{
status
=
ti
->
DataBuffer
(
0
,
&
buffer
,
&
buffer_size
,
&
src_memory_type
,
&
src_memory_type_id
);
if
(
!
status
.
IsOk
())
{
LOG_STATUS_ERROR
(
status
,
LogRequest
()
+
TRITONSERVER_InferenceTraceActivityString
(
activity
)
+
": "
+
msg
+
": tensor: "
+
name
+
": fail to get data buffer: "
+
status
.
Message
());
return
status
;
}
if
(
buffer_size
!=
byte_size
)
{
LOG_STATUS_ERROR
(
status
,
LogRequest
()
+
TRITONSERVER_InferenceTraceActivityString
(
activity
)
+
": "
+
msg
+
": tensor: "
+
name
+
": truncated buffer"
);
continue
;
}
INFER_TRACE_TENSOR_ACTIVITY
(
this
->
trace_
,
activity
,
name
.
c_str
(),
datatype
,
const_cast
<
void
*>
(
buffer
),
buffer_size
,
shape
,
dim_count
,
src_memory_type
,
src_memory_type_id
);
continue
;
}
// input buffer
std
::
vector
<
char
>
in_buffer
(
byte_size
);
char
*
base
=
in_buffer
.
data
();
size_t
offset
=
0
;
for
(
uint32_t
b
=
0
;
b
<
buffer_count
;
++
b
)
{
status
=
ti
->
DataBuffer
(
b
,
&
buffer
,
&
buffer_size
,
&
src_memory_type
,
&
src_memory_type_id
);
if
(
!
status
.
IsOk
())
{
LOG_STATUS_ERROR
(
status
,
LogRequest
()
+
TRITONSERVER_InferenceTraceActivityString
(
activity
)
+
": "
+
msg
+
": tensor: "
+
name
+
": fail to get data buffer: "
+
status
.
Message
());
return
status
;
}
status
=
CopyBuffer
(
"InferenceRequest TraceInputTensors"
,
src_memory_type
,
src_memory_type_id
,
dst_memory_type
,
dst_memory_type_id
,
buffer_size
,
buffer
,
base
+
offset
,
nullptr
,
&
cuda_used
);
if
(
!
status
.
IsOk
())
{
LOG_STATUS_ERROR
(
status
,
LogRequest
()
+
TRITONSERVER_InferenceTraceActivityString
(
activity
)
+
": "
+
msg
+
": tensor: "
+
name
+
": fail to copy buffer: "
+
status
.
Message
());
return
status
;
}
offset
+=
buffer_size
;
}
INFER_TRACE_TENSOR_ACTIVITY
(
this
->
trace_
,
activity
,
name
.
c_str
(),
datatype
,
static_cast
<
void
*>
(
base
),
byte_size
,
shape
,
dim_count
,
dst_memory_type
,
dst_memory_type_id
);
}
return
Status
::
Success
;
}
#endif // TRITON_ENABLE_TRACING
Status
InferenceRequest::OutputBufferProperties(
    const char* name, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
    int64_t* memory_type_id)
{
  const auto allocator = response_factory_->Allocator();
  if ((allocator == nullptr) || (allocator->QueryFn() == nullptr)) {
    return Status(
        Status::Code::UNAVAILABLE,
        (LogRequest() + "Output properties are not available").c_str());
  } else {
    RETURN_IF_TRITONSERVER_ERROR(allocator->QueryFn()(
        reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
            const_cast<ResponseAllocator*>(allocator)),
        response_factory_->AllocatorUserp(), name, byte_size, memory_type,
        memory_type_id));
  }
  return Status::Success;
}

Status
InferenceRequest::Run(std::unique_ptr<InferenceRequest>& request)
{
  return request->model_raw_->Enqueue(request);
}

void
InferenceRequest::RespondIfError(
    std::unique_ptr<InferenceRequest>& request, const Status& status,
    const bool release_request)
{
  if (status.IsOk()) {
    return;
  }

  // Use the response factory to create a response, set the status,
  // and send it. If something goes wrong all we can do is log the
  // error. Because this is sending an error we assume that this is
  // the last response for the request and so set the FINAL flag.
  std::unique_ptr<InferenceResponse> response;
  LOG_STATUS_ERROR(
      request->response_factory_->CreateResponse(&response),
      (request->LogRequest() + "failed to create error response").c_str());
  LOG_STATUS_ERROR(
      InferenceResponse::SendWithStatus(
          std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL, status),
      (request->LogRequest() + "failed to send error response").c_str());

  // If releasing the request then invoke the release callback which
  // gives ownership to the callback. So can't access 'request' after
  // this point.
  if (release_request) {
    InferenceRequest::Release(
        std::move(request), TRITONSERVER_REQUEST_RELEASE_ALL);
  }
}

void
InferenceRequest::RespondIfError(
    std::vector<std::unique_ptr<InferenceRequest>>& requests,
    const Status& status, const bool release_requests)
{
  if (status.IsOk()) {
    return;
  }

  for (auto& request : requests) {
    RespondIfError(request, status, release_requests);
  }
}

void
InferenceRequest::Release(
    std::unique_ptr<InferenceRequest>&& request, const uint32_t release_flags)
{
  // Invoke the release callbacks added internally before releasing the
  // request to user provided callback.
  for (auto it = request->release_callbacks_.rbegin();
       it != request->release_callbacks_.rend(); it++) {
    (*it)();
  }
  request->release_callbacks_.clear();

#ifdef TRITON_ENABLE_TRACING
  // If tracing then record request end and release the trace.
  // This must be before the request callback to ensure the trace
  // is properly layered, as the request may be nested in an ensemble
  // and the callback may interact with upper level trace.
  if (request->trace_ != nullptr) {
    request->trace_->ReportNow(TRITONSERVER_TRACE_REQUEST_END);
    request->ReleaseTrace();
  }
#endif  // TRITON_ENABLE_TRACING

  void* userp = request->release_userp_;
  auto& release_fn = request->release_fn_;
  release_fn(
      reinterpret_cast<TRITONSERVER_InferenceRequest*>(request.release()),
      release_flags, userp);
}
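For orientation only (not part of the committed sources): Release() hands ownership of the request back through the registered release callback, so a user-side callback typically mirrors NullRequestComplete above and deletes the request once TRITONSERVER_REQUEST_RELEASE_ALL is set. A minimal, hypothetical sketch; the name MyRequestComplete is an assumption.

// Hypothetical user-side release callback, not part of the sources above.
// Ownership of 'request' is transferred here by InferenceRequest::Release().
void
MyRequestComplete(
    TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
    // The server will not reuse this request; delete it.
    LOG_TRITONSERVER_ERROR(
        TRITONSERVER_InferenceRequestDelete(request), "deleting request");
  }
}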
InferenceRequest
*
InferenceRequest
::
CopyAsNull
(
const
InferenceRequest
&
from
)
{
// Create a copy of 'from' request with artifical inputs and no requested
// outputs. Maybe more efficient to share inputs and other metadata,
// but that binds the Null request with 'from' request's lifecycle.
std
::
unique_ptr
<
InferenceRequest
>
lrequest
(
new
InferenceRequest
(
from
.
model_raw_
,
from
.
requested_model_version_
));
lrequest
->
needs_normalization_
=
false
;
lrequest
->
batch_size_
=
from
.
batch_size_
;
lrequest
->
collect_stats_
=
false
;
// Three passes: first to construct input for the shape tensors inputs, second
// to obtain the max input byte size for allocating a large enough buffer for
// all non shape tensor inputs; third to construct the inputs for these
// tensors.
// First pass
for
(
const
auto
&
input
:
from
.
OriginalInputs
())
{
// Handle only shape tensors in this pass
if
(
!
input
.
second
.
IsShapeTensor
())
{
continue
;
}
// Prepare the memory to hold input data
size_t
byte_size
=
input
.
second
.
Data
()
->
TotalByteSize
();
auto
mem_type
=
TRITONSERVER_MEMORY_CPU
;
int64_t
mem_id
=
0
;
std
::
shared_ptr
<
MutableMemory
>
data
=
std
::
make_shared
<
AllocatedMemory
>
(
byte_size
,
mem_type
,
mem_id
);
// Get the source buffer. Assumes shape tensors be in a single buffer on the
// CPU
const
auto
&
from_data
=
input
.
second
.
Data
();
size_t
from_data_byte_size
;
TRITONSERVER_MemoryType
from_data_memory_type
;
int64_t
from_data_memory_id
;
const
char
*
from_data_buffer
=
from_data
->
BufferAt
(
0
/* idx */
,
&
from_data_byte_size
,
&
from_data_memory_type
,
&
from_data_memory_id
);
if
(
from_data_byte_size
!=
byte_size
)
{
LOG_WARNING
<<
lrequest
->
LogRequest
()
<<
"The byte size of shape tensor to be copied does not match"
;
}
// Copy the shape values to the input buffer
std
::
memcpy
(
data
->
MutableBuffer
(),
from_data_buffer
,
from_data_byte_size
);
Input
*
new_input
;
lrequest
->
AddOriginalInput
(
input
.
first
,
input
.
second
.
DType
(),
input
.
second
.
Shape
(),
&
new_input
);
// Must normalize shape here...
*
new_input
->
MutableShape
()
=
input
.
second
.
Shape
();
*
new_input
->
MutableShapeWithBatchDim
()
=
input
.
second
.
ShapeWithBatchDim
();
new_input
->
SetData
(
data
);
}
// Second pass
size_t
max_byte_size
=
0
;
size_t
max_str_byte_size
=
0
;
const
std
::
string
*
max_input_name
;
for
(
const
auto
&
input
:
from
.
OriginalInputs
())
{
// Skip shape tensors in this pass
if
(
input
.
second
.
IsShapeTensor
())
{
continue
;
}
if
(
input
.
second
.
DType
()
==
inference
::
DataType
::
TYPE_STRING
)
{
int64_t
element_count
=
triton
::
common
::
GetElementCount
(
input
.
second
.
Shape
());
size_t
str_byte_size
=
static_cast
<
size_t
>
(
4
*
element_count
);
max_str_byte_size
=
std
::
max
(
str_byte_size
,
max_str_byte_size
);
if
(
str_byte_size
>
max_byte_size
)
{
max_byte_size
=
str_byte_size
;
max_input_name
=
&
(
input
.
first
);
}
}
else
{
if
(
input
.
second
.
Data
()
->
TotalByteSize
()
>=
max_byte_size
)
{
max_byte_size
=
input
.
second
.
Data
()
->
TotalByteSize
();
max_input_name
=
&
(
input
.
first
);
}
}
}
// Third pass
// [DLIS-1268] should use one growable static buffer for all null requests
auto
mem_type
=
TRITONSERVER_MEMORY_CPU
;
int64_t
mem_id
=
0
;
std
::
shared_ptr
<
MutableMemory
>
data
=
std
::
make_shared
<
AllocatedMemory
>
(
max_byte_size
,
mem_type
,
mem_id
);
auto
data_base
=
data
->
BufferAt
(
0
,
&
max_byte_size
,
&
mem_type
,
&
mem_id
);
// Zero initialization is only required when there is a TYPE_BYTES tensor in
// the request. Only set the required number of bytes to zero.
if
(
max_str_byte_size
>
0
)
{
std
::
fill
(
data
->
MutableBuffer
(),
data
->
MutableBuffer
()
+
max_str_byte_size
,
0
);
}
for
(
const
auto
&
input
:
from
.
OriginalInputs
())
{
// skip shape tensors in this pass
if
(
input
.
second
.
IsShapeTensor
())
{
continue
;
}
Input
*
new_input
;
lrequest
->
AddOriginalInput
(
input
.
first
,
input
.
second
.
DType
(),
input
.
second
.
Shape
(),
&
new_input
);
// Must normalize shape here...
*
new_input
->
MutableShape
()
=
input
.
second
.
Shape
();
*
new_input
->
MutableShapeWithBatchDim
()
=
input
.
second
.
ShapeWithBatchDim
();
// Note that the input that have max byte size will be responsible for
// holding the artifical data, while other inputs will hold a reference to
// it with byte size that matches 'from'
if
(
input
.
first
==
*
max_input_name
)
{
new_input
->
SetData
(
data
);
}
else
{
if
(
inference
::
DataType
::
TYPE_STRING
==
input
.
second
.
DType
())
{
new_input
->
AppendData
(
data_base
,
triton
::
common
::
GetElementCount
(
input
.
second
.
Shape
())
*
4
,
mem_type
,
mem_id
);
}
else
{
new_input
->
AppendData
(
data_base
,
input
.
second
.
Data
()
->
TotalByteSize
(),
mem_type
,
mem_id
);
}
}
}
// No outputs were requested and thus there should be no allocations.
lrequest
->
SetResponseCallback
(
&
null_allocator
,
nullptr
,
NullResponseComplete
,
nullptr
);
lrequest
->
SetReleaseCallback
(
NullRequestComplete
,
nullptr
);
// Must normalize inputs here...
for
(
auto
&
pr
:
lrequest
->
original_inputs_
)
{
lrequest
->
inputs_
.
emplace
(
std
::
make_pair
(
pr
.
second
.
Name
(),
std
::
addressof
(
pr
.
second
)));
}
return
lrequest
.
release
();
}
Status
InferenceRequest
::
MutableOriginalInput
(
const
std
::
string
&
name
,
InferenceRequest
::
Input
**
input
)
{
auto
itr
=
original_inputs_
.
find
(
name
);
if
(
itr
==
original_inputs_
.
end
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"input '"
+
name
+
"' does not exist in request"
);
}
*
input
=
&
(
itr
->
second
);
return
Status
::
Success
;
}
Status
InferenceRequest
::
ImmutableInput
(
const
std
::
string
&
name
,
const
InferenceRequest
::
Input
**
input
)
const
{
auto
itr
=
inputs_
.
find
(
name
);
if
(
itr
==
inputs_
.
end
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"input '"
+
name
+
"' does not exist in request"
);
}
*
input
=
itr
->
second
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
AddOriginalInput
(
const
std
::
string
&
name
,
const
inference
::
DataType
datatype
,
const
int64_t
*
shape
,
const
uint64_t
dim_count
,
InferenceRequest
::
Input
**
input
)
{
const
auto
&
pr
=
original_inputs_
.
emplace
(
std
::
piecewise_construct
,
std
::
forward_as_tuple
(
name
),
std
::
forward_as_tuple
(
name
,
datatype
,
shape
,
dim_count
));
if
(
!
pr
.
second
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"input '"
+
name
+
"' already exists in request"
);
}
if
(
input
!=
nullptr
)
{
*
input
=
std
::
addressof
(
pr
.
first
->
second
);
}
needs_normalization_
=
true
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
AddOriginalInput
(
const
std
::
string
&
name
,
const
inference
::
DataType
datatype
,
const
std
::
vector
<
int64_t
>&
shape
,
InferenceRequest
::
Input
**
input
)
{
return
AddOriginalInput
(
name
,
datatype
,
&
shape
[
0
],
shape
.
size
(),
input
);
}
Status
InferenceRequest
::
AddRawInput
(
const
std
::
string
&
name
,
InferenceRequest
::
Input
**
input
)
{
if
(
original_inputs_
.
size
()
!=
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"raw input '"
+
name
+
"' can't be added to request with other inputs"
);
}
const
auto
&
pr
=
original_inputs_
.
emplace
(
std
::
piecewise_construct
,
std
::
forward_as_tuple
(
name
),
std
::
forward_as_tuple
());
if
(
!
pr
.
second
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"input '"
+
name
+
"' already exists in request"
);
}
if
(
input
!=
nullptr
)
{
*
input
=
std
::
addressof
(
pr
.
first
->
second
);
}
raw_input_name_
=
name
;
needs_normalization_
=
true
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
RemoveOriginalInput
(
const
std
::
string
&
name
)
{
if
(
original_inputs_
.
erase
(
name
)
!=
1
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"input '"
+
name
+
"' does not exist in request"
);
}
if
(
name
==
raw_input_name_
)
{
raw_input_name_
.
clear
();
}
needs_normalization_
=
true
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
RemoveAllOriginalInputs
()
{
original_inputs_
.
clear
();
raw_input_name_
.
clear
();
needs_normalization_
=
true
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
AddOverrideInput
(
const
std
::
string
&
name
,
const
inference
::
DataType
datatype
,
const
int64_t
batch_size
,
const
std
::
vector
<
int64_t
>&
shape
,
std
::
shared_ptr
<
InferenceRequest
::
Input
>*
input
)
{
std
::
shared_ptr
<
Input
>
i
=
std
::
make_shared
<
Input
>
(
name
,
datatype
,
shape
);
*
(
i
->
MutableShape
())
=
i
->
OriginalShape
();
if
(
batch_size
>
0
)
{
*
(
i
->
MutableShapeWithBatchDim
())
=
{
batch_size
};
i
->
MutableShapeWithBatchDim
()
->
insert
(
i
->
MutableShapeWithBatchDim
()
->
end
(),
i
->
OriginalShape
().
begin
(),
i
->
OriginalShape
().
end
());
}
else
{
*
(
i
->
MutableShapeWithBatchDim
())
=
i
->
OriginalShape
();
}
RETURN_IF_ERROR
(
AddOverrideInput
(
i
));
if
(
input
!=
nullptr
)
{
*
input
=
std
::
move
(
i
);
}
return
Status
::
Success
;
}
Status
InferenceRequest
::
AddOverrideInput
(
const
std
::
shared_ptr
<
InferenceRequest
::
Input
>&
input
)
{
LOG_VERBOSE
(
1
)
<<
LogRequest
()
<<
"adding input override for "
<<
input
->
Name
()
<<
": "
<<
*
this
;
const
auto
&
pr
=
override_inputs_
.
emplace
(
std
::
make_pair
(
input
->
Name
(),
input
));
if
(
!
pr
.
second
)
{
pr
.
first
->
second
=
input
;
}
// Add or replace this override in the inputs...
const
auto
res
=
inputs_
.
emplace
(
std
::
make_pair
(
input
->
Name
(),
input
.
get
()));
if
(
!
res
.
second
)
{
res
.
first
->
second
=
input
.
get
();
}
LOG_VERBOSE
(
1
)
<<
LogRequest
()
<<
"added input override for "
<<
input
->
Name
()
<<
": "
<<
*
this
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
AddOriginalRequestedOutput
(
const
std
::
string
&
name
)
{
original_requested_outputs_
.
insert
(
name
);
needs_normalization_
=
true
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
LoadInputStates
()
{
// Add the input states to the inference request.
if
(
sequence_states_
!=
nullptr
)
{
if
(
sequence_states_
->
IsNullRequest
())
{
sequence_states_
=
SequenceStates
::
CopyAsNull
(
sequence_states_
->
NullSequenceStates
());
}
for
(
auto
&
input_state_pair
:
sequence_states_
->
InputStates
())
{
auto
&
input_state
=
input_state_pair
.
second
;
std
::
shared_ptr
<
InferenceRequest
::
Input
>
input
=
std
::
make_shared
<
InferenceRequest
::
Input
>
(
input_state
->
Name
(),
input_state
->
DType
(),
input_state
->
Shape
());
*
input
->
MutableShapeWithBatchDim
()
=
input_state
->
Shape
();
input
->
SetData
(
input_state
->
Data
());
AddOverrideInput
(
input
);
}
}
return
Status
::
Success
;
}
Status
InferenceRequest
::
RemoveOriginalRequestedOutput
(
const
std
::
string
&
name
)
{
original_requested_outputs_
.
erase
(
name
);
needs_normalization_
=
true
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
RemoveAllOriginalRequestedOutputs
()
{
original_requested_outputs_
.
clear
();
needs_normalization_
=
true
;
return
Status
::
Success
;
}
Status
InferenceRequest::PrepareForInference()
{
  // Remove override inputs as those are added during any previous
  // inference execution.
  inputs_.clear();
  override_inputs_.clear();

  // Renormalize if anything has changed in the inference request in a
  // way that could impact renormalization.
  if (needs_normalization_) {
    RETURN_IF_ERROR(Normalize());
    needs_normalization_ = false;
  }

  // Initially show the actual inputs to be only the original
  // inputs. If overrides are added later they will be added to
  // 'inputs_'.
  for (auto& pr : original_inputs_) {
    inputs_.emplace(
        std::make_pair(pr.second.Name(), std::addressof(pr.second)));
  }

  // Clear the timestamps
  queue_start_ns_ = 0;
  batcher_start_ns_ = 0;
#ifdef TRITON_ENABLE_STATS
  request_start_ns_ = 0;
#endif  // TRITON_ENABLE_STATS

  LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;

  return Status::Success;
}
Status
InferenceRequest
::
Normalize
()
{
const
inference
::
ModelConfig
&
model_config
=
model_raw_
->
Config
();
// Fill metadata for raw input
if
(
!
raw_input_name_
.
empty
())
{
const
bool
has_multiple_inputs
=
(
original_inputs_
.
size
()
!=
1
)
||
(
model_config
.
input_size
()
!=
1
);
if
(
has_multiple_inputs
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"Raw request must only have 1 input (found "
+
std
::
to_string
(
original_inputs_
.
size
())
+
") to be deduced but got "
+
std
::
to_string
(
model_config
.
input_size
())
+
" inputs in '"
+
ModelName
()
+
"' model configuration"
);
}
auto
it
=
original_inputs_
.
begin
();
if
(
raw_input_name_
!=
it
->
first
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"Unexpected reference name for raw input '"
+
raw_input_name_
+
"' got '"
+
it
->
first
+
"'"
);
}
const
auto
&
config_input
=
model_config
.
input
(
0
);
auto
&
raw_input
=
it
->
second
;
std
::
vector
<
int64_t
>
shape
;
if
(
model_config
.
max_batch_size
()
!=
0
)
{
shape
.
emplace_back
(
1
);
}
int64_t
dynamic_axis
=
-
1
;
size_t
element_cnt
=
1
;
for
(
const
auto
&
dim
:
config_input
.
dims
())
{
if
(
dim
==
triton
::
common
::
WILDCARD_DIM
)
{
if
(
dynamic_axis
!=
-
1
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"The shape of the raw input '"
+
config_input
.
name
()
+
"' can not be deduced because there are more than one "
"variable-sized dimension"
);
}
dynamic_axis
=
shape
.
size
();
}
else
{
element_cnt
*=
(
size_t
)
dim
;
}
shape
.
emplace_back
(
dim
);
}
if
((
config_input
.
data_type
()
==
inference
::
DataType
::
TYPE_STRING
))
{
const
bool
has_one_element
=
(
dynamic_axis
==
-
1
)
&&
(
element_cnt
==
1
);
if
(
!
has_one_element
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"For BYTE datatype raw input, the "
"model must have input shape [1]"
);
}
// In the case of BYTE data type, we will prepend the byte size to follow
// the Triton convention.
raw_input_size_
=
raw_input
.
Data
()
->
TotalByteSize
();
RETURN_IF_ERROR
(
raw_input
.
PrependData
(
&
raw_input_size_
,
sizeof
(
uint32_t
),
TRITONSERVER_MEMORY_CPU
,
0
));
// Limit the BYTE raw input not to have host policy specific input for
// simplicity, such case won't happen given the current protocol spec.
// Will need to extend Input::PrependData() if needed.
if
(
!
raw_input
.
HostPolicyData
().
empty
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"Raw input with data associated "
"with a host policy setting is not "
"currently supported"
);
}
}
else
if
(
dynamic_axis
!=
-
1
)
{
shape
[
dynamic_axis
]
=
raw_input
.
Data
()
->
TotalByteSize
()
/
element_cnt
/
triton
::
common
::
GetDataTypeByteSize
(
config_input
.
data_type
());
}
raw_input
.
SetMetadata
(
config_input
.
name
(),
config_input
.
data_type
(),
shape
);
}
// Initialize the requested outputs to be used during inference. If
// original_requested_outputs_ is empty assume all outputs specified
// in model config are being requested.
requested_outputs_
.
clear
();
if
(
original_requested_outputs_
.
size
()
==
0
)
{
for
(
const
auto
&
output
:
model_config
.
output
())
{
requested_outputs_
.
insert
(
output
.
name
());
}
}
else
{
// Validate if the original requested output name exists in the
// model configuration.
for
(
const
auto
&
output_name
:
original_requested_outputs_
)
{
const
inference
::
ModelOutput
*
output_config
;
RETURN_IF_ERROR
(
model_raw_
->
GetOutput
(
output_name
,
&
output_config
));
}
}
// Make sure that the request is providing the number of inputs
// as is expected by the model.
if
((
original_inputs_
.
size
()
>
(
size_t
)
model_config
.
input_size
())
||
(
original_inputs_
.
size
()
<
model_raw_
->
RequiredInputCount
()))
{
// If no input is marked as optional, then use exact match error message
// for consistency / backward compatibility
if
((
size_t
)
model_config
.
input_size
()
==
model_raw_
->
RequiredInputCount
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"expected "
+
std
::
to_string
(
model_config
.
input_size
())
+
" inputs but got "
+
std
::
to_string
(
original_inputs_
.
size
())
+
" inputs for model '"
+
ModelName
()
+
"'"
);
}
else
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"expected number of inputs between "
+
std
::
to_string
(
model_raw_
->
RequiredInputCount
())
+
" and "
+
std
::
to_string
(
model_config
.
input_size
())
+
" but got "
+
std
::
to_string
(
original_inputs_
.
size
())
+
" inputs for model '"
+
ModelName
()
+
"'"
);
}
}
// Determine the batch size and shape of each input.
if
(
model_config
.
max_batch_size
()
==
0
)
{
// Model does not support Triton-style batching so set as
// batch-size 0 and leave the tensor shapes as they are.
batch_size_
=
0
;
for
(
auto
&
pr
:
original_inputs_
)
{
auto
&
input
=
pr
.
second
;
*
input
.
MutableShape
()
=
input
.
OriginalShape
();
}
}
else
{
// Model does support Triton-style batching so each input tensor
// must have the same first dimension which is the batch
// size. Adjust the shape of the input tensors to remove the batch
// dimension.
batch_size_
=
0
;
for
(
auto
&
pr
:
original_inputs_
)
{
auto
&
input
=
pr
.
second
;
// For a shape tensor, keep the tensor's shape as it is and mark
// that the input is a shape tensor.
const
inference
::
ModelInput
*
input_config
;
RETURN_IF_ERROR
(
model_raw_
->
GetInput
(
input
.
Name
(),
&
input_config
));
if
(
input_config
->
is_shape_tensor
())
{
*
input
.
MutableShape
()
=
input
.
OriginalShape
();
input
.
SetIsShapeTensor
(
true
);
continue
;
}
if
(
input
.
OriginalShape
().
size
()
==
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"input '"
+
input
.
Name
()
+
"' has no shape but model requires batch dimension for '"
+
ModelName
()
+
"'"
);
}
if
(
batch_size_
==
0
)
{
batch_size_
=
input
.
OriginalShape
()[
0
];
}
else
if
(
input
.
OriginalShape
()[
0
]
!=
batch_size_
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"input '"
+
input
.
Name
()
+
"' batch size does not match other inputs for '"
+
ModelName
()
+
"'"
);
}
input
.
MutableShape
()
->
assign
(
input
.
OriginalShape
().
begin
()
+
1
,
input
.
OriginalShape
().
end
());
}
}
// Make sure request batch-size doesn't exceed what is supported by
// the model.
if
((
int
)
batch_size_
>
model_config
.
max_batch_size
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"inference request batch-size must be <= "
+
std
::
to_string
(
model_config
.
max_batch_size
())
+
" for '"
+
ModelName
()
+
"'"
);
}
// Verify that each input shape is valid for the model, make
// adjustments for reshapes and find the total tensor size.
for
(
auto
&
pr
:
original_inputs_
)
{
const
inference
::
ModelInput
*
input_config
;
RETURN_IF_ERROR
(
model_raw_
->
GetInput
(
pr
.
second
.
Name
(),
&
input_config
));
auto
&
input
=
pr
.
second
;
auto
shape
=
input
.
MutableShape
();
if
(
input
.
DType
()
!=
input_config
->
data_type
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"inference input data-type is '"
+
std
::
string
(
triton
::
common
::
DataTypeToProtocolString
(
input
.
DType
()))
+
"', model expects '"
+
std
::
string
(
triton
::
common
::
DataTypeToProtocolString
(
input_config
->
data_type
()))
+
"' for '"
+
ModelName
()
+
"'"
);
}
// Validate input shape
{
bool
match_config
=
true
;
const
auto
&
config_dims
=
input_config
->
dims
();
const
auto
&
input_dims
=
*
shape
;
if
(
config_dims
.
size
()
!=
(
int64_t
)
input_dims
.
size
())
{
match_config
=
false
;
}
else
{
for
(
int
i
=
0
;
i
<
config_dims
.
size
();
++
i
)
{
if
(
input_dims
[
i
]
==
triton
::
common
::
WILDCARD_DIM
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"All input dimensions should be specified for input '"
+
pr
.
first
+
"' for model '"
+
ModelName
()
+
"', got "
+
triton
::
common
::
DimsListToString
(
input
.
OriginalShape
()));
}
else
if
(
(
config_dims
[
i
]
!=
triton
::
common
::
WILDCARD_DIM
)
&&
(
config_dims
[
i
]
!=
input_dims
[
i
]))
{
match_config
=
false
;
break
;
}
}
}
if
(
!
match_config
)
{
triton
::
common
::
DimsList
full_dims
;
if
(
model_config
.
max_batch_size
()
>
0
)
{
full_dims
.
Add
(
triton
::
common
::
WILDCARD_DIM
);
}
for
(
int
i
=
0
;
i
<
input_config
->
dims_size
();
++
i
)
{
full_dims
.
Add
(
input_config
->
dims
(
i
));
}
return
Status
(
Status
::
Code
::
INVALID_ARG
,
LogRequest
()
+
"unexpected shape for input '"
+
pr
.
first
+
"' for model '"
+
ModelName
()
+
"'. Expected "
+
triton
::
common
::
DimsListToString
(
full_dims
)
+
", got "
+
triton
::
common
::
DimsListToString
(
input
.
OriginalShape
()));
}
}
// If there is a reshape for this input then adjust them to
// match the reshape. As reshape may have variable-size
// dimensions, we need to record corresponding value so that we
// can set the value correctly for reshape.
if
(
input_config
->
has_reshape
())
{
std
::
deque
<
int64_t
>
variable_size_values
;
for
(
int64_t
idx
=
0
;
idx
<
input_config
->
dims_size
();
idx
++
)
{
if
(
input_config
->
dims
(
idx
)
==
-
1
)
{
variable_size_values
.
push_back
((
*
shape
)[
idx
]);
}
}
shape
->
clear
();
for
(
const
auto
&
dim
:
input_config
->
reshape
().
shape
())
{
if
(
dim
==
-
1
)
{
shape
->
push_back
(
variable_size_values
.
front
());
variable_size_values
.
pop_front
();
}
else
{
shape
->
push_back
(
dim
);
}
}
}
// Create shape with batch dimension.
// FIXME, should not need this!!
if
(
batch_size_
==
0
)
{
*
input
.
MutableShapeWithBatchDim
()
=
*
shape
;
}
else
{
input
.
MutableShapeWithBatchDim
()
->
clear
();
input
.
MutableShapeWithBatchDim
()
->
push_back
(
batch_size_
);
for
(
int64_t
d
:
*
shape
)
{
input
.
MutableShapeWithBatchDim
()
->
push_back
(
d
);
}
}
}
return
Status
::
Success
;
}
#ifdef TRITON_ENABLE_STATS
void
InferenceRequest
::
ReportStatistics
(
MetricModelReporter
*
metric_reporter
,
bool
success
,
const
uint64_t
compute_start_ns
,
const
uint64_t
compute_input_end_ns
,
const
uint64_t
compute_output_start_ns
,
const
uint64_t
compute_end_ns
)
{
if
(
!
collect_stats_
)
{
return
;
}
#ifdef TRITON_ENABLE_TRACING
if
(
trace_
!=
nullptr
)
{
trace_
->
Report
(
TRITONSERVER_TRACE_COMPUTE_START
,
compute_start_ns
);
trace_
->
Report
(
TRITONSERVER_TRACE_COMPUTE_INPUT_END
,
compute_input_end_ns
);
trace_
->
Report
(
TRITONSERVER_TRACE_COMPUTE_OUTPUT_START
,
compute_output_start_ns
);
trace_
->
Report
(
TRITONSERVER_TRACE_COMPUTE_END
,
compute_end_ns
);
}
#endif // TRITON_ENABLE_TRACING
INFER_STATS_DECL_TIMESTAMP
(
request_end_ns
);
if
(
success
)
{
model_raw_
->
MutableStatsAggregator
()
->
UpdateSuccess
(
metric_reporter
,
std
::
max
(
1U
,
batch_size_
),
request_start_ns_
,
queue_start_ns_
,
compute_start_ns
,
compute_input_end_ns
,
compute_output_start_ns
,
compute_end_ns
,
request_end_ns
);
if
(
secondary_stats_aggregator_
!=
nullptr
)
{
secondary_stats_aggregator_
->
UpdateSuccess
(
nullptr
/* metric_reporter */
,
std
::
max
(
1U
,
batch_size_
),
request_start_ns_
,
queue_start_ns_
,
compute_start_ns
,
compute_input_end_ns
,
compute_output_start_ns
,
compute_end_ns
,
request_end_ns
);
}
}
else
{
model_raw_
->
MutableStatsAggregator
()
->
UpdateFailure
(
metric_reporter
,
request_start_ns_
,
request_end_ns
);
if
(
secondary_stats_aggregator_
!=
nullptr
)
{
secondary_stats_aggregator_
->
UpdateFailure
(
nullptr
/* metric_reporter */
,
request_start_ns_
,
request_end_ns
);
}
}
}
void
InferenceRequest
::
ReportStatisticsWithDuration
(
MetricModelReporter
*
metric_reporter
,
bool
success
,
const
uint64_t
compute_start_ns
,
const
uint64_t
compute_input_duration_ns
,
const
uint64_t
compute_infer_duration_ns
,
const
uint64_t
compute_output_duration_ns
)
{
if
(
!
collect_stats_
)
{
return
;
}
INFER_STATS_DECL_TIMESTAMP
(
request_end_ns
);
if
(
success
)
{
model_raw_
->
MutableStatsAggregator
()
->
UpdateSuccessWithDuration
(
metric_reporter
,
std
::
max
(
1U
,
batch_size_
),
request_start_ns_
,
queue_start_ns_
,
compute_start_ns
,
request_end_ns
,
compute_input_duration_ns
,
compute_infer_duration_ns
,
compute_output_duration_ns
);
if
(
secondary_stats_aggregator_
!=
nullptr
)
{
secondary_stats_aggregator_
->
UpdateSuccessWithDuration
(
nullptr
/* metric_reporter */
,
std
::
max
(
1U
,
batch_size_
),
request_start_ns_
,
queue_start_ns_
,
compute_start_ns
,
request_end_ns
,
compute_input_duration_ns
,
compute_infer_duration_ns
,
compute_output_duration_ns
);
}
}
else
{
model_raw_
->
MutableStatsAggregator
()
->
UpdateFailure
(
metric_reporter
,
request_start_ns_
,
request_end_ns
);
if
(
secondary_stats_aggregator_
!=
nullptr
)
{
secondary_stats_aggregator_
->
UpdateFailure
(
nullptr
/* metric_reporter */
,
request_start_ns_
,
request_end_ns
);
}
}
}
void
InferenceRequest
::
ReportStatisticsCacheHit
(
MetricModelReporter
*
metric_reporter
)
{
// Capture end of request time
INFER_STATS_DECL_TIMESTAMP
(
request_end_ns
);
if
(
cache_lookup_start_ns_
>=
cache_lookup_end_ns_
)
{
LOG_WARNING
<<
LogRequest
()
<<
"Cache lookup timestamps were not set correctly. Cache "
"lookup duration stats may be incorrect."
;
}
const
uint64_t
cache_lookup_duration_ns
=
cache_lookup_end_ns_
-
cache_lookup_start_ns_
;
// Cache hit is always success
model_raw_
->
MutableStatsAggregator
()
->
UpdateSuccessCacheHit
(
metric_reporter
,
std
::
max
(
1U
,
batch_size_
),
request_start_ns_
,
queue_start_ns_
,
cache_lookup_start_ns_
,
request_end_ns
,
cache_lookup_duration_ns
);
if
(
secondary_stats_aggregator_
!=
nullptr
)
{
secondary_stats_aggregator_
->
UpdateSuccessCacheHit
(
nullptr
/* metric_reporter */
,
std
::
max
(
1U
,
batch_size_
),
request_start_ns_
,
queue_start_ns_
,
cache_lookup_start_ns_
,
request_end_ns
,
cache_lookup_duration_ns
);
}
}
void
InferenceRequest
::
ReportStatisticsCacheMiss
(
MetricModelReporter
*
metric_reporter
)
{
if
(
cache_lookup_start_ns_
>=
cache_lookup_end_ns_
)
{
LOG_WARNING
<<
LogRequest
()
<<
"Cache lookup timestamps were not set correctly. Cache "
"lookup duration stats may be incorrect."
;
}
if
(
cache_insertion_start_ns_
>=
cache_insertion_end_ns_
)
{
LOG_WARNING
<<
LogRequest
()
<<
"Cache insertion timestamps were not set correctly. Cache "
"insertion duration stats may be incorrect."
;
}
const
uint64_t
cache_lookup_duration_ns
=
cache_lookup_end_ns_
-
cache_lookup_start_ns_
;
const
uint64_t
cache_insertion_duration_ns
=
cache_insertion_end_ns_
-
cache_insertion_start_ns_
;
model_raw_
->
MutableStatsAggregator
()
->
UpdateSuccessCacheMiss
(
metric_reporter
,
cache_lookup_duration_ns
,
cache_insertion_duration_ns
);
if
(
secondary_stats_aggregator_
!=
nullptr
)
{
secondary_stats_aggregator_
->
UpdateSuccessCacheMiss
(
nullptr
/* metric_reporter */
,
cache_lookup_duration_ns
,
cache_insertion_duration_ns
);
}
}
#endif // TRITON_ENABLE_STATS
//
// Input
//
InferenceRequest
::
Input
::
Input
()
:
is_shape_tensor_
(
false
),
data_
(
new
MemoryReference
),
has_host_policy_specific_data_
(
false
)
{
}
InferenceRequest
::
Input
::
Input
(
const
std
::
string
&
name
,
const
inference
::
DataType
datatype
,
const
int64_t
*
shape
,
const
uint64_t
dim_count
)
:
name_
(
name
),
datatype_
(
datatype
),
original_shape_
(
shape
,
shape
+
dim_count
),
is_shape_tensor_
(
false
),
data_
(
new
MemoryReference
),
has_host_policy_specific_data_
(
false
)
{
}
InferenceRequest
::
Input
::
Input
(
const
std
::
string
&
name
,
const
inference
::
DataType
datatype
,
const
std
::
vector
<
int64_t
>&
shape
)
:
name_
(
name
),
datatype_
(
datatype
),
original_shape_
(
shape
),
is_shape_tensor_
(
false
),
data_
(
new
MemoryReference
),
has_host_policy_specific_data_
(
false
)
{
}
void
InferenceRequest
::
Input
::
SetMetadata
(
const
std
::
string
&
name
,
const
inference
::
DataType
&
dt
,
const
std
::
vector
<
int64_t
>&
shape
)
{
name_
=
name
;
datatype_
=
dt
;
original_shape_
=
shape
;
}
Status
InferenceRequest
::
Input
::
SetIsShapeTensor
(
const
bool
is_shape_tensor
)
{
is_shape_tensor_
=
is_shape_tensor
;
return
Status
::
Success
;
}
const
std
::
shared_ptr
<
Memory
>&
InferenceRequest
::
Input
::
Data
(
const
std
::
string
&
host_policy_name
)
const
{
auto
device_data
=
host_policy_data_map_
.
find
(
host_policy_name
);
if
(
device_data
==
host_policy_data_map_
.
end
())
{
// Fall back on default data if there is no data that has been added for
// this host policy
return
data_
;
}
return
device_data
->
second
;
}
Status
InferenceRequest
::
Input
::
AppendData
(
const
void
*
base
,
size_t
byte_size
,
TRITONSERVER_MemoryType
memory_type
,
int64_t
memory_type_id
)
{
if
(
byte_size
>
0
)
{
std
::
static_pointer_cast
<
MemoryReference
>
(
data_
)
->
AddBuffer
(
static_cast
<
const
char
*>
(
base
),
byte_size
,
memory_type
,
memory_type_id
);
}
return
Status
::
Success
;
}
Status
InferenceRequest
::
Input
::
AppendDataWithBufferAttributes
(
const
void
*
base
,
BufferAttributes
*
buffer_attributes
)
{
if
(
buffer_attributes
->
ByteSize
()
>
0
)
{
std
::
static_pointer_cast
<
MemoryReference
>
(
data_
)
->
AddBuffer
(
static_cast
<
const
char
*>
(
base
),
buffer_attributes
);
}
return
Status
::
Success
;
}
Status
InferenceRequest
::
Input
::
AppendDataWithHostPolicy
(
const
void
*
base
,
size_t
byte_size
,
TRITONSERVER_MemoryType
memory_type
,
int64_t
memory_type_id
,
const
char
*
host_policy_name
)
{
auto
device_data
=
host_policy_data_map_
.
find
(
host_policy_name
);
has_host_policy_specific_data_
=
true
;
if
(
device_data
==
host_policy_data_map_
.
end
())
{
auto
insert_pair
=
host_policy_data_map_
.
insert
(
std
::
make_pair
(
std
::
string
(
host_policy_name
),
new
MemoryReference
));
device_data
=
insert_pair
.
first
;
}
if
(
byte_size
>
0
)
{
std
::
static_pointer_cast
<
MemoryReference
>
(
device_data
->
second
)
->
AddBuffer
(
static_cast
<
const
char
*>
(
base
),
byte_size
,
memory_type
,
memory_type_id
);
}
return
Status
::
Success
;
}
Status
InferenceRequest
::
Input
::
PrependData
(
const
void
*
base
,
size_t
byte_size
,
TRITONSERVER_MemoryType
memory_type
,
int64_t
memory_type_id
)
{
if
(
byte_size
>
0
)
{
std
::
static_pointer_cast
<
MemoryReference
>
(
data_
)
->
AddBufferFront
(
static_cast
<
const
char
*>
(
base
),
byte_size
,
memory_type
,
memory_type_id
);
}
return
Status
::
Success
;
}
Status
InferenceRequest
::
Input
::
SetData
(
const
std
::
shared_ptr
<
Memory
>&
data
)
{
if
(
data_
->
TotalByteSize
()
!=
0
)
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"input '"
+
name_
+
"' already has data, can't overwrite"
);
}
data_
=
data
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
Input
::
SetData
(
const
std
::
string
&
host_policy_name
,
const
std
::
shared_ptr
<
Memory
>&
data
)
{
if
(
host_policy_data_map_
.
find
(
host_policy_name
)
!=
host_policy_data_map_
.
end
())
{
return
Status
(
Status
::
Code
::
INVALID_ARG
,
"input '"
+
name_
+
"' already has data for host policy '"
+
host_policy_name
+
"', can't overwrite"
);
}
host_policy_data_map_
.
emplace
(
host_policy_name
,
data
);
return
Status
::
Success
;
}
Status
InferenceRequest
::
Input
::
RemoveAllData
()
{
data_
=
std
::
make_shared
<
MemoryReference
>
();
host_policy_data_map_
.
clear
();
has_host_policy_specific_data_
=
false
;
return
Status
::
Success
;
}
Status
InferenceRequest
::
Input
::
DataBuffer
(
const
size_t
idx
,
const
void
**
base
,
size_t
*
byte_size
,
TRITONSERVER_MemoryType
*
memory_type
,
int64_t
*
memory_type_id
)
const
{
*
base
=
data_
->
BufferAt
(
idx
,
byte_size
,
memory_type
,
memory_type_id
);
return
Status
::
Success
;
}
Status
InferenceRequest
::
Input
::
DataBufferAttributes
(
const
size_t
idx
,
const
void
**
base
,
BufferAttributes
**
buffer_attributes
)
const
{
*
base
=
data_
->
BufferAt
(
idx
,
buffer_attributes
);
return
Status
::
Success
;
}
Status
InferenceRequest
::
Input
::
DataBufferForHostPolicy
(
const
size_t
idx
,
const
void
**
base
,
size_t
*
byte_size
,
TRITONSERVER_MemoryType
*
memory_type
,
int64_t
*
memory_type_id
,
const
std
::
string
&
host_policy_name
)
const
{
auto
device_data
=
host_policy_data_map_
.
find
(
host_policy_name
);
if
(
device_data
==
host_policy_data_map_
.
end
())
{
// Return data buffer if there is no host-policy specific buffer available
*
base
=
data_
->
BufferAt
(
idx
,
byte_size
,
memory_type
,
memory_type_id
);
}
else
{
*
base
=
device_data
->
second
->
BufferAt
(
idx
,
byte_size
,
memory_type
,
memory_type_id
);
}
return
Status
::
Success
;
}
size_t
InferenceRequest
::
Input
::
DataBufferCountForHostPolicy
(
const
std
::
string
&
host_policy_name
)
const
{
auto
policy_data
=
host_policy_data_map_
.
find
(
host_policy_name
);
if
(
policy_data
!=
host_policy_data_map_
.
end
())
{
return
policy_data
->
second
->
BufferCount
();
}
return
data_
->
BufferCount
();
}
InferenceRequest
::
SequenceId
::
SequenceId
()
:
sequence_label_
(
""
),
sequence_index_
(
0
),
id_type_
(
InferenceRequest
::
SequenceId
::
DataType
::
UINT64
)
{
}
InferenceRequest
::
SequenceId
::
SequenceId
(
const
std
::
string
&
sequence_label
)
:
sequence_label_
(
sequence_label
),
sequence_index_
(
0
),
id_type_
(
InferenceRequest
::
SequenceId
::
DataType
::
STRING
)
{
}
InferenceRequest
::
SequenceId
::
SequenceId
(
uint64_t
sequence_index
)
:
sequence_label_
(
""
),
sequence_index_
(
sequence_index
),
id_type_
(
InferenceRequest
::
SequenceId
::
DataType
::
UINT64
)
{
}
InferenceRequest
::
SequenceId
&
InferenceRequest
::
SequenceId
::
operator
=
(
const
std
::
string
&
rhs
)
{
sequence_label_
=
rhs
;
sequence_index_
=
0
;
id_type_
=
InferenceRequest
::
SequenceId
::
DataType
::
STRING
;
return
*
this
;
}
InferenceRequest
::
SequenceId
&
InferenceRequest
::
SequenceId
::
operator
=
(
const
uint64_t
rhs
)
{
sequence_label_
=
""
;
sequence_index_
=
rhs
;
id_type_
=
InferenceRequest
::
SequenceId
::
DataType
::
UINT64
;
return
*
this
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
InferenceRequest
&
request
)
{
out
<<
"[0x"
<<
std
::
addressof
(
request
)
<<
"] "
<<
"request id: "
<<
request
.
Id
()
<<
", model: "
<<
request
.
ModelName
()
<<
", requested version: "
<<
request
.
RequestedModelVersion
()
<<
", actual version: "
<<
request
.
ActualModelVersion
()
<<
", flags: 0x"
<<
std
::
hex
<<
request
.
Flags
()
<<
std
::
dec
<<
", correlation id: "
<<
request
.
CorrelationId
()
<<
", batch size: "
<<
request
.
BatchSize
()
<<
", priority: "
<<
request
.
Priority
()
<<
", timeout (us): "
<<
request
.
TimeoutMicroseconds
()
<<
std
::
endl
;
out
<<
"original inputs:"
<<
std
::
endl
;
for
(
const
auto
&
itr
:
request
.
OriginalInputs
())
{
out
<<
"[0x"
<<
std
::
addressof
(
itr
.
second
)
<<
"] "
<<
itr
.
second
<<
std
::
endl
;
}
out
<<
"override inputs:"
<<
std
::
endl
;
for
(
const
auto
&
itr
:
request
.
OverrideInputs
())
{
out
<<
"[0x"
<<
itr
.
second
.
get
()
<<
"] "
<<
*
itr
.
second
<<
std
::
endl
;
}
out
<<
"inputs:"
<<
std
::
endl
;
for
(
const
auto
&
itr
:
request
.
ImmutableInputs
())
{
out
<<
"[0x"
<<
itr
.
second
<<
"] "
<<
*
itr
.
second
<<
std
::
endl
;
}
out
<<
"original requested outputs:"
<<
std
::
endl
;
for
(
const
auto
&
name
:
request
.
OriginalRequestedOutputs
())
{
out
<<
name
<<
std
::
endl
;
}
out
<<
"requested outputs:"
<<
std
::
endl
;
for
(
const
auto
&
name
:
request
.
ImmutableRequestedOutputs
())
{
out
<<
name
<<
std
::
endl
;
}
return
out
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
InferenceRequest
::
Input
&
input
)
{
out
<<
"input: "
<<
input
.
Name
()
<<
", type: "
<<
triton
::
common
::
DataTypeToProtocolString
(
input
.
DType
())
<<
", original shape: "
<<
triton
::
common
::
DimsListToString
(
input
.
OriginalShape
())
<<
", batch + shape: "
<<
triton
::
common
::
DimsListToString
(
input
.
ShapeWithBatchDim
())
<<
", shape: "
<<
triton
::
common
::
DimsListToString
(
input
.
Shape
());
if
(
input
.
IsShapeTensor
())
{
out
<<
", is_shape_tensor: True"
;
}
return
out
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
InferenceRequest
::
SequenceId
&
sequence_id
)
{
switch
(
sequence_id
.
Type
())
{
case
InferenceRequest
::
SequenceId
::
DataType
::
STRING
:
out
<<
sequence_id
.
StringValue
();
break
;
case
InferenceRequest
::
SequenceId
::
DataType
::
UINT64
:
out
<<
sequence_id
.
UnsignedIntValue
();
break
;
default:
out
<<
sequence_id
.
UnsignedIntValue
();
break
;
}
return
out
;
}
bool
operator
==
(
const
InferenceRequest
::
SequenceId
lhs
,
const
InferenceRequest
::
SequenceId
rhs
)
{
if
(
lhs
.
Type
()
==
rhs
.
Type
())
{
switch
(
lhs
.
Type
())
{
case
InferenceRequest
::
SequenceId
::
DataType
::
STRING
:
return
lhs
.
StringValue
()
==
rhs
.
StringValue
();
case
InferenceRequest
::
SequenceId
::
DataType
::
UINT64
:
return
lhs
.
UnsignedIntValue
()
==
rhs
.
UnsignedIntValue
();
default:
return
lhs
.
UnsignedIntValue
()
==
rhs
.
UnsignedIntValue
();
}
}
else
{
return
false
;
}
}
bool
operator
!=
(
const
InferenceRequest
::
SequenceId
lhs
,
const
InferenceRequest
::
SequenceId
rhs
)
{
return
!
(
lhs
==
rhs
);
}
}}
// namespace triton::core
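For orientation only (not part of the committed sources): a minimal, hypothetical sketch of how the classes in this file fit together when submitting a request in-process, using only calls that appear above (AddOriginalInput, AppendData, AddOriginalRequestedOutput, SetResponseCallback, SetReleaseCallback, Run). The tensor name, shape, allocator, callbacks, and the use of -1 as "latest version" are assumptions for illustration.

// Hypothetical sketch, not part of the sources above.
triton::core::Status
SubmitExample(
    const std::shared_ptr<triton::core::Model>& model,
    const float* input_buffer, triton::core::ResponseAllocator* allocator)
{
  std::unique_ptr<triton::core::InferenceRequest> request(
      new triton::core::InferenceRequest(model, -1 /* assumed: latest */));

  triton::core::InferenceRequest::Input* input = nullptr;
  RETURN_IF_ERROR(request->AddOriginalInput(
      "INPUT0", inference::DataType::TYPE_FP32, {1, 16}, &input));
  RETURN_IF_ERROR(input->AppendData(
      input_buffer, 16 * sizeof(float), TRITONSERVER_MEMORY_CPU, 0));

  RETURN_IF_ERROR(request->AddOriginalRequestedOutput("OUTPUT0"));

  // MyResponseComplete / MyRequestComplete are assumed callbacks with the
  // same signatures as NullResponseComplete / NullRequestComplete above.
  request->SetResponseCallback(allocator, nullptr, MyResponseComplete, nullptr);
  request->SetReleaseCallback(MyRequestComplete, nullptr);

  // Run() enqueues the request on the model's scheduler; completion is
  // reported through the callbacks registered above.
  return triton::core::InferenceRequest::Run(request);
}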
3rdparty/core-r22.12/src/infer_request.h
0 → 100644
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "buffer_attributes.h"
#include "infer_response.h"
#include "infer_stats.h"
#include "infer_trace.h"
#include "memory.h"
#include "response_allocator.h"
#include "sequence_state.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
namespace
triton
{
namespace
core
{
class
Model
;
class
InferenceServer
;
class
MetricModelReporter
;
//
// An inference request. A request can be used multiple times for
// inference but before each inference run, PrepareForInference() must
// be called to verify and prepare the request. Verification involves
// ensuring that any changes made since the last inference are
// valid. Preparing involves removing/resetting any state left over
// from the previous inference.
//
class
InferenceRequest
{
public:
// Input tensor
class
Input
{
public:
Input
();
Input
(
const
std
::
string
&
name
,
const
inference
::
DataType
datatype
,
const
std
::
vector
<
int64_t
>&
shape
);
Input
(
const
std
::
string
&
name
,
const
inference
::
DataType
datatype
,
const
int64_t
*
shape
,
const
uint64_t
dim_count
);
// Set the name, data type and original shape of the input tensor.
void
SetMetadata
(
const
std
::
string
&
name
,
const
inference
::
DataType
&
dt
,
const
std
::
vector
<
int64_t
>&
shape
);
// The name of the input tensor. There is no mutable operator for
// the name because it is used in a InferenceRequest map and a
// mutable method would allow it to get out-of-sync.
const
std
::
string
&
Name
()
const
{
return
name_
;
}
// Data type of the input tensor.
inference
::
DataType
DType
()
const
{
return
datatype_
;
}
// The original shape of the input tensor.
const
std
::
vector
<int64_t>& OriginalShape() const { return original_shape_; }

  // The shape of the input tensor after normalization. This shape
  // is the original shape modified as required/expected by
  // inference processing.
  const std::vector<int64_t>& Shape() const { return shape_; }
  std::vector<int64_t>* MutableShape() { return &shape_; }

  // FIXME. Should not need these functions. All shapes kept here
  // should include the batch dimension instead of breaking the same
  // into batch + shape.
  const std::vector<int64_t>& ShapeWithBatchDim() const
  {
    return shape_with_batch_dim_;
  }
  std::vector<int64_t>* MutableShapeWithBatchDim()
  {
    return &shape_with_batch_dim_;
  }

  // Return true if host-specific data was added for this input
  bool HasHostPolicySpecificData() const
  {
    return has_host_policy_specific_data_;
  }

  // Whether or not the input is a tensorrt shape tensor
  bool IsShapeTensor() const { return is_shape_tensor_; }

  // Set the input to be treated as a shape tensor.
  Status SetIsShapeTensor(const bool is_shape_tensor);

  // The data for this input.
  const std::shared_ptr<Memory>& Data() const { return data_; }

  // The data for this input for a specific device
  const std::shared_ptr<Memory>& Data(
      const std::string& host_policy_name) const;

  // Return all host policy data set for this input
  const std::map<std::string, std::shared_ptr<Memory>>& HostPolicyData() const
  {
    return host_policy_data_map_;
  }

  // Set the data for this input. Error if input already has some
  // data.
  Status SetData(const std::shared_ptr<Memory>& data);

  // Set the data associated with the host policy for this input.
  // Return error if input already has some data.
  Status SetData(
      const std::string& host_policy_name,
      const std::shared_ptr<Memory>& data);

  // Append a new buffer of data to this input.
  Status AppendData(
      const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
      int64_t memory_type_id);
  Status AppendDataWithHostPolicy(
      const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
      int64_t memory_type_id, const char* host_policy_name);
  Status AppendDataWithBufferAttributes(
      const void* base, BufferAttributes* buffer_attributes);

  // Prepend a new buffer of data to this input.
  Status PrependData(
      const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
      int64_t memory_type_id);

  // Remove all existing data for the input.
  Status RemoveAllData();

  // Get the number of buffers containing the input tensor data.
  size_t DataBufferCount() const { return data_->BufferCount(); }

  // Get the number of buffers containing the input tensor data with
  // host policy. If there are no buffers corresponding to the specific
  // host policy, the number of buffers in the fallback input data is
  // returned.
  size_t DataBufferCountForHostPolicy(
      const std::string& host_policy_name) const;

  // Get the 'idx' buffer containing a contiguous chunk of bytes for
  // the input. Return an error if 'idx' refers to a buffer that does
  // not exist. Return a pointer to the chunk in 'base' and the
  // size of the chunk in 'byte_size'. 'memory_type' acts as
  // both input and output. On input 'memory_type' is the buffer
  // memory type preferred by the function caller. On return
  // 'memory_type' gives the actual memory type of the chunk pointed
  // to by 'base'. 'memory_type_id' acts as both input and
  // output. On input 'memory_type_id' is the buffer memory type id
  // preferred by the function caller. On return 'memory_type_id'
  // gives the actual memory type id of the chunk pointed to by
  // 'base'.
  Status DataBuffer(
      const size_t idx, const void** base, size_t* byte_size,
      TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) const;

  // Get the buffer attributes associated with 'idx' buffer.
  Status DataBufferAttributes(
      const size_t idx, const void** base,
      BufferAttributes** buffer_attributes) const;

  // Get the 'idx' buffer containing a contiguous chunk of bytes for
  // the input, using the data associated with 'host_policy_name'. The
  // in/out semantics of 'memory_type' and 'memory_type_id' are the
  // same as for DataBuffer() above.
  Status DataBufferForHostPolicy(
      const size_t idx, const void** base, size_t* byte_size,
      TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id,
      const std::string& host_policy_name) const;

 private:
  DISALLOW_COPY_AND_ASSIGN(Input);
  friend std::ostream& operator<<(
      std::ostream& out, const InferenceRequest::Input& input);

  std::string name_;
  inference::DataType datatype_;
  std::vector<int64_t> original_shape_;
  std::vector<int64_t> shape_;
  std::vector<int64_t> shape_with_batch_dim_;
  bool is_shape_tensor_;
  std::shared_ptr<Memory> data_;

  bool has_host_policy_specific_data_;
  // A map of host policy to input data memory
  std::map<std::string, std::shared_ptr<Memory>> host_policy_data_map_;
};
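
  // Illustrative sketch (not part of this header): one plausible way a caller
  // could attach CPU data to an Input and then walk its buffers. 'request' and
  // the exact datatype enumerator spelling (inference::TYPE_FP32) are
  // assumptions, not taken from this commit.
#if 0
  InferenceRequest::Input* input = nullptr;
  request->AddOriginalInput(
      "INPUT0", inference::TYPE_FP32, std::vector<int64_t>{1, 3}, &input);

  float values[3] = {1.f, 2.f, 3.f};
  input->AppendData(
      values, sizeof(values), TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */);

  for (size_t idx = 0; idx < input->DataBufferCount(); ++idx) {
    const void* base = nullptr;
    size_t byte_size = 0;
    TRITONSERVER_MemoryType mtype = TRITONSERVER_MEMORY_CPU;  // preferred
    int64_t mtype_id = 0;                                     // preferred
    input->DataBuffer(idx, &base, &byte_size, &mtype, &mtype_id);
    // On return 'mtype' / 'mtype_id' describe where 'base' actually lives.
  }
#endif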
  // Sequence ID can be either a 64 bit integer or a string.
  // This class implements the SequenceId type
  class SequenceId {
   public:
    enum class DataType { UINT64, STRING };

    SequenceId();
    SequenceId(const std::string& sequence_label);
    SequenceId(uint64_t sequence_index);
    SequenceId& operator=(const SequenceId& rhs) = default;
    SequenceId& operator=(const std::string& rhs);
    SequenceId& operator=(const uint64_t rhs);

    // Functions that help determine exact type of sequence Id
    DataType Type() const { return id_type_; }
    bool InSequence() const
    {
      return ((sequence_label_ != "") || (sequence_index_ != 0));
    }

    // Get the value of the SequenceId based on the type
    const std::string& StringValue() const { return sequence_label_; }
    uint64_t UnsignedIntValue() const { return sequence_index_; }

   private:
    friend std::ostream& operator<<(
        std::ostream& out,
        const InferenceRequest::SequenceId& correlation_id);
    friend bool operator==(const SequenceId lhs, const SequenceId rhs);
    friend bool operator!=(const SequenceId lhs, const SequenceId rhs);

    std::string sequence_label_;
    uint64_t sequence_index_;
    DataType id_type_;
  };
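
  // Illustrative sketch (not part of this header): constructing a SequenceId
  // from either representation. The values are invented for the example.
#if 0
  InferenceRequest::SequenceId by_index(42UL);
  InferenceRequest::SequenceId by_label(std::string("session-abc"));
  // by_index.Type() == DataType::UINT64 and UnsignedIntValue() == 42;
  // by_label.Type() == DataType::STRING and StringValue() == "session-abc".
  // Both report InSequence() == true because the underlying value is non-empty
  // / non-zero.
#endif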
  // InferenceRequest
  //
  // The two constructors are identical except one takes model as a
  // shared pointer and the other as a raw pointer. The shared pointer
  // version is the primary one and acts to keep the model alive as
  // long as the request is in flight. The raw pointer version is used
  // only for cases where the model itself is issuing a request
  // (e.g. warmup) and no shared pointer version of the model exists
  // (because we aren't using shared_from_this).
  InferenceRequest(
      const std::shared_ptr<Model>& model,
      const int64_t requested_model_version);
  InferenceRequest(Model* model, const int64_t requested_model_version);

  const std::string& ModelName() const;
  int64_t RequestedModelVersion() const { return requested_model_version_; }
  int64_t ActualModelVersion() const;

  const std::string& Id() const { return id_; }
  void SetId(const std::string& i) { id_ = i; }

  // Return string for logging request ID
  std::string LogRequest() const
  {
    std::string id = Id();
    if (id.empty()) {
      id = "<id_unknown>";
    }
    return std::string("[request id: ") + id + "] ";
  }

  // Flags for the request, union of TRITONSERVER_RequestFlag.
  uint32_t Flags() const { return flags_; }
  void SetFlags(uint32_t f) { flags_ = f; }

  const SequenceId& CorrelationId() const { return correlation_id_; }
  void SetCorrelationId(const SequenceId& c) { correlation_id_ = c; }

  // The batch size of the request, as understood by Triton. A
  // batch-size of 0 indicates that the model doesn't support batching
  // in a way that Triton understands. Batch size is not set
  // explicitly so there is no setter for it. It is set when the
  // request is normalized.
  uint32_t BatchSize() const { return batch_size_; }

  uint32_t Priority() const { return priority_; }
  void SetPriority(uint32_t p);

  uint64_t TimeoutMicroseconds() const { return timeout_us_; }
  void SetTimeoutMicroseconds(uint64_t t) { timeout_us_ = t; }

  uint64_t CacheKey() const { return cache_key_; }
  // It is up to the user to update the cache_key_ if modifying any hashable
  // fields of the request after cache_key_is_set_ has been set to true.
  void SetCacheKey(uint64_t key)
  {
    cache_key_ = key;
    cache_key_is_set_ = true;
  }
  bool CacheKeyIsSet() const { return cache_key_is_set_; }

#ifdef TRITON_ENABLE_TRACING
  const std::shared_ptr<InferenceTraceProxy>& Trace() const { return trace_; }
  std::shared_ptr<InferenceTraceProxy>* MutableTrace() { return &trace_; }
  void SetTrace(const std::shared_ptr<InferenceTraceProxy>& trace)
  {
    trace_ = trace;
    response_factory_->SetTrace(trace);
  }
  void ReleaseTrace()
  {
    trace_ = nullptr;
    response_factory_->ReleaseTrace();
  }

  Status TraceInputTensors(
      TRITONSERVER_InferenceTraceActivity activity, const std::string& msg);
#endif  // TRITON_ENABLE_TRACING
  // The original inputs are the inputs added to the request before
  // the inference execution (that is before
  // TRITONSERVER_ServerInferAsync is called). Once execution has
  // started the original inputs should not be modified until
  // execution completes (and those modifications will apply to the
  // next inference execution).
  Status MutableOriginalInput(const std::string& name, Input** input);
  std::unordered_map<std::string, Input>* MutableOriginalInputs()
  {
    return &original_inputs_;
  }
  const std::unordered_map<std::string, Input>& OriginalInputs() const
  {
    return original_inputs_;
  }

  // The override inputs are the inputs added to the request after
  // inference execution has started (that is after
  // TRITONSERVER_ServerInferAsync or equivalent is called). During
  // inference processing, if Triton needs to change an original input
  // it will add an override instead of changing the original. Triton
  // will also use an override if it needs to add a new input to the
  // request. Overrides are recorded as shared_ptr so that the same
  // override can be used efficiently multiple times or even in
  // multiple requests simultaneously. Must be careful not to modify
  // an override input if it is being shared unless you want that
  // change to be reflected in all requests that hold that override
  // input. Override inputs within a specific request are not
  // persisted across inference calls.
  std::unordered_map<std::string, std::shared_ptr<Input>>*
  MutableOverrideInputs()
  {
    return &override_inputs_;
  }
  const std::unordered_map<std::string, std::shared_ptr<Input>>&
  OverrideInputs() const
  {
    return override_inputs_;
  }

  // Get an input taking into account both original inputs and
  // overrides. If an override input is available use it, otherwise
  // use the original input. Accessing inputs via this method is not
  // valid until after PrepareForInference is called.
  Status ImmutableInput(const std::string& name, const Input** input) const;
  const std::unordered_map<std::string, Input*>& ImmutableInputs() const
  {
    return inputs_;
  }

  // The original requested outputs are the requested outputs added to
  // the request before the inference execution (that is before
  // TRITONSERVER_ServerInferAsync is called). Once execution has
  // started the original requested outputs should not be modified
  // until execution completes (and those modifications will apply to
  // the next inference execution).
  const std::set<std::string>& OriginalRequestedOutputs() const
  {
    return original_requested_outputs_;
  }

  // Get the requested outputs that should be used during
  // inference. Accessing outputs via this method is not valid until
  // after PrepareForInference is called.
  const std::set<std::string>& ImmutableRequestedOutputs() const
  {
    return (requested_outputs_.empty()) ? original_requested_outputs_
                                        : requested_outputs_;
  }

  // Get the response factory.
  const std::shared_ptr<InferenceResponseFactory>& ResponseFactory() const
  {
    return response_factory_;
  }
  // Add an original input to the request. If 'input' is non-null
  // return a pointer to the newly added input.
  Status AddOriginalInput(
      const std::string& name, const inference::DataType datatype,
      const int64_t* shape, const uint64_t dim_count,
      Input** input = nullptr);
  Status AddOriginalInput(
      const std::string& name, const inference::DataType datatype,
      const std::vector<int64_t>& shape, Input** input = nullptr);

  // Add an original raw input to the request. If 'input' is non-null
  // return a pointer to the newly added input.
  Status AddRawInput(const std::string& name, Input** input = nullptr);

  // Remove a single original input or all inputs.
  Status RemoveOriginalInput(const std::string& name);
  Status RemoveAllOriginalInputs();

  // Add an override input to the request. If 'input' is non-null
  // return a pointer to the newly added input.
  // FIXME passing batch size is special handling for backend API.
  // For override input, the 'shape' is without batch dimension for
  // backends that implemented w/o backend API (which need correct
  // input.Shape()), but backend API uses input.ShapeWithBatchDim().
  Status AddOverrideInput(
      const std::string& name, const inference::DataType datatype,
      const int64_t batch_size, const std::vector<int64_t>& shape,
      std::shared_ptr<Input>* input = nullptr);

  // Add an override input to the request.
  Status AddOverrideInput(const std::shared_ptr<Input>& input);

  // Request an original requested output.
  Status AddOriginalRequestedOutput(const std::string& name);

  // Remove a single original requested output or all requested
  // outputs.
  Status RemoveOriginalRequestedOutput(const std::string& name);
  Status RemoveAllOriginalRequestedOutputs();

  // Initialize the release callback for the request.
  Status SetReleaseCallback(
      TRITONSERVER_InferenceRequestReleaseFn_t release_fn,
      void* release_userp)
  {
    release_fn_ = release_fn;
    release_userp_ = release_userp;
    return Status::Success;
  }

  // Initialize the response factory that is to be used with any
  // responses produced for this request.
  Status SetResponseCallback(
      const ResponseAllocator* allocator, void* alloc_userp,
      TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
      void* response_userp)
  {
    response_factory_.reset(new InferenceResponseFactory(
        model_shared_, id_, allocator, alloc_userp, response_fn,
        response_userp, response_delegator_));
    return Status::Success;
  }

  // Returns the preferred memory type and memory type ID of the output buffer
  // for the request. 'name' and 'byte_size' are optional and set to nullptr
  // if not specified; if provided, they give the allocator more information.
  // 'memory_type' and 'memory_type_id' are also used as input to provide types
  // preferred by the caller.
  // Status::Code::UNAVAILABLE will be returned if output properties are not
  // available.
  Status OutputBufferProperties(
      const char* name, size_t* byte_size,
      TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id);

  // Add a callback to be invoked on releasing the request object from Triton.
  // Multiple callbacks can be added by calling this function in order,
  // and they will be invoked in reversed order.
  Status AddInternalReleaseCallback(std::function<void()>&& callback)
  {
    release_callbacks_.emplace_back(std::move(callback));
    return Status::Success;
  }

  // Add a delegator to be invoked on sending the responses of this request.
  // The response will be passed to 'delegator' and 'delegator' must call
  // InferenceResponse::Send() to send the response.
  Status SetResponseDelegator(
      std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>&&
          delegator)
  {
    response_delegator_ = std::move(delegator);
    return response_factory_->SetResponseDelegator(response_delegator_);
  }

  Status SetSequenceStates(
      const std::shared_ptr<SequenceStates>& sequence_states)
  {
    sequence_states_ = sequence_states;
    return Status::Success;
  }

  Status LoadInputStates();

  const std::shared_ptr<SequenceStates>& GetSequenceStates() const
  {
    return sequence_states_;
  }

  // Prepare this request for inference.
  Status PrepareForInference();
  // Run this inference request using the model associated with the
  // request. If Status::Success is returned then the call has taken
  // ownership of the request object and so 'request' will be
  // nullptr. If non-success is returned then the caller still retains
  // ownership of 'request'.
  static Status Run(std::unique_ptr<InferenceRequest>& request);

  // Send an error response for this request. If 'status' is Success
  // then no response is sent and the request is not released (even if
  // 'release_request' is true). Because this is sending an error it
  // is assumed that this is the last response for the request and so
  // the FINAL flag is set in the response callback. If
  // 'release_request' is true then the release callback is called for
  // this request and ownership is given to the callback. Thus, if
  // 'release_request' is true 'request' is returned as nullptr.
  static void RespondIfError(
      std::unique_ptr<InferenceRequest>& request, const Status& status,
      const bool release_request = false);

  // Send an error response to a set of 'requests'. If 'status' is
  // Success then no responses are sent and the requests are not
  // released (even if 'release_request' is true). Because this is
  // sending an error it is assumed that this is the last response for
  // the requests and so the FINAL flag is set in the response
  // callbacks. If 'release_request' is true then the release callback
  // is called for each request, and the request ownership is given to
  // the callback. Thus, if 'release_request' is true 'requests' is
  // returned with all nullptrs.
  static void RespondIfError(
      std::vector<std::unique_ptr<InferenceRequest>>& requests,
      const Status& status, const bool release_requests = false);

  // Release the request. Call the release callback and transfer
  // ownership of the request to the callback. On return 'request' is
  // nullptr.
  static void Release(
      std::unique_ptr<InferenceRequest>&& request,
      const uint32_t release_flags);

  // Create a copy of 'from' suitable for use as a "null" request as
  // required for the direct sequence batcher. The returned copy will
  // contain only the minimum content required for a null request.
  // The statistics of the copy will not be collected.
  static InferenceRequest* CopyAsNull(const InferenceRequest& from);
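
  // Illustrative sketch (not part of this header) of the ownership convention
  // described above: Run() consumes the request only on success. 'model' is an
  // assumed std::shared_ptr<Model> obtained elsewhere.
#if 0
  std::unique_ptr<InferenceRequest> request(
      new InferenceRequest(model, -1 /* latest version */));
  // ... add inputs, requested outputs, callbacks ...
  Status status = InferenceRequest::Run(request);
  if (!status.IsOk()) {
    // 'request' is still owned here; send the error (FINAL flag is implied)
    // and hand ownership to the release callback.
    InferenceRequest::RespondIfError(request, status, true /* release */);
  }
  // On success 'request' is nullptr and Triton owns the object.
#endif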
  uint64_t QueueStartNs() const { return queue_start_ns_; }
  uint64_t CaptureQueueStartNs()
  {
    queue_start_ns_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
                          std::chrono::steady_clock::now().time_since_epoch())
                          .count();
    return queue_start_ns_;
  }

  uint64_t CacheLookupStartNs() const { return cache_lookup_start_ns_; }
  uint64_t CaptureCacheLookupStartNs()
  {
    cache_lookup_start_ns_ =
        std::chrono::duration_cast<std::chrono::nanoseconds>(
            std::chrono::steady_clock::now().time_since_epoch())
            .count();
    return cache_lookup_start_ns_;
  }

  uint64_t CacheLookupEndNs() const { return cache_lookup_end_ns_; }
  uint64_t CaptureCacheLookupEndNs()
  {
    cache_lookup_end_ns_ =
        std::chrono::duration_cast<std::chrono::nanoseconds>(
            std::chrono::steady_clock::now().time_since_epoch())
            .count();
    return cache_lookup_end_ns_;
  }

  uint64_t CacheInsertionStartNs() const { return cache_insertion_start_ns_; }
  uint64_t CaptureCacheInsertionStartNs()
  {
    cache_insertion_start_ns_ =
        std::chrono::duration_cast<std::chrono::nanoseconds>(
            std::chrono::steady_clock::now().time_since_epoch())
            .count();
    return cache_insertion_start_ns_;
  }

  uint64_t CacheInsertionEndNs() const { return cache_insertion_end_ns_; }
  uint64_t CaptureCacheInsertionEndNs()
  {
    cache_insertion_end_ns_ =
        std::chrono::duration_cast<std::chrono::nanoseconds>(
            std::chrono::steady_clock::now().time_since_epoch())
            .count();
    return cache_insertion_end_ns_;
  }

  uint64_t BatcherStartNs() const { return batcher_start_ns_; }
  uint64_t CaptureBatcherStartNs()
  {
    batcher_start_ns_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::steady_clock::now().time_since_epoch())
                            .count();
    return batcher_start_ns_;
  }
#ifdef TRITON_ENABLE_STATS
  uint64_t RequestStartNs() const { return request_start_ns_; }
  uint64_t CaptureRequestStartNs()
  {
    request_start_ns_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::steady_clock::now().time_since_epoch())
                            .count();
    return request_start_ns_;
  }

  // Report the statistics to stats collectors associated with the request.
  // Duration and timestamps provide two granularities for stats collectors.
  void ReportStatistics(
      MetricModelReporter* metric_reporter, bool success,
      const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
      const uint64_t compute_output_start_ns, const uint64_t compute_end_ns);

  // Report the statistics to stats collectors associated with the request.
  // Duration and timestamps provide two granularities for stats collectors.
  void ReportStatisticsWithDuration(
      MetricModelReporter* metric_reporter, bool success,
      const uint64_t compute_start_ns,
      const uint64_t compute_input_duration_ns,
      const uint64_t compute_infer_duration_ns,
      const uint64_t compute_output_duration_ns);

  // Report the statistics to stats collectors associated with the request on
  // response cache hits.
  void ReportStatisticsCacheHit(MetricModelReporter* metric_reporter);

  // Report the statistics to stats collectors associated with the request on
  // response cache misses and update request duration to include cache
  // insertion time.
  void ReportStatisticsCacheMiss(MetricModelReporter* metric_reporter);

  // Statistics for each request are aggregated into the corresponding
  // model's statistics. Optionally this function may be used to
  // add an additional aggregator where statistics are also aggregated.
  void SetSecondaryStatsAggregator(
      InferenceStatsAggregator* secondary_stats_aggregator)
  {
    secondary_stats_aggregator_ = secondary_stats_aggregator;
  }
#endif  // TRITON_ENABLE_STATS
 private:
  DISALLOW_COPY_AND_ASSIGN(InferenceRequest);
  friend std::ostream& operator<<(
      std::ostream& out, const InferenceRequest& request);

  Status Normalize();

  // Has anything in the request potentially changed in a way that
  // causes normalization to be required when preparing the request
  // for inference.
  bool needs_normalization_;

  // The model associated with this request. For most requests
  // model_shared_ will be non-null and will act to keep the model
  // alive as long as this request is live. In this case model_raw_
  // will be the raw pointer from the shared pointer. For cases where
  // the model itself created the request (like running requests for
  // warmup), model_shared_ will be nullptr, but model_raw_ will
  // still be defined. Thus model_raw_ is always defined and should
  // always be used to access the model.
  std::shared_ptr<Model> model_shared_;
  Model* model_raw_;

  // The model version as requested and based on version policy the
  // specific version that is actually used for inference.
  int64_t requested_model_version_;
  int64_t actual_model_version_;

  std::string id_;

  uint32_t flags_;
  SequenceId correlation_id_;
  uint32_t batch_size_;
  uint32_t priority_;
  uint64_t timeout_us_;

  uint64_t cache_key_ = 0;
  // Helper to determine if request was successfully hashed
  // and cache_key_ field is valid
  bool cache_key_is_set_ = false;

  std::unordered_map<std::string, Input> original_inputs_;
  std::unordered_map<std::string, std::shared_ptr<Input>> override_inputs_;
  std::unordered_map<std::string, Input*> inputs_;
  std::set<std::string> original_requested_outputs_;

  std::string raw_input_name_;
  uint32_t raw_input_size_;

  // requested_outputs_ is to be used post-normalization. It will be
  // empty unless it differs from original_requested_outputs_, so it
  // should typically be accessed through ImmutableRequestedOutputs.
  std::set<std::string> requested_outputs_;

  // The release function and user pointer for this request.
  TRITONSERVER_InferenceRequestReleaseFn_t release_fn_;
  void* release_userp_;

  // Additional release callbacks invoked before 'release_fn_'.
  std::vector<std::function<void()>> release_callbacks_;

  // Delegator to be invoked on sending responses.
  std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
      response_delegator_;

  // The response factory associated with this request.
  std::shared_ptr<InferenceResponseFactory> response_factory_;

  // Request timestamps. Queue start is needed for schedulers even
  // when statistics are not being collected.
  uint64_t queue_start_ns_;

  // Cache lookup start/end timestamps. Cache manages its own stats even
  // when statistics are not being collected.
  uint64_t cache_lookup_start_ns_;
  uint64_t cache_lookup_end_ns_;

  // Cache insertion start/end timestamps. Cache manages its own stats even
  // when statistics are not being collected.
  uint64_t cache_insertion_start_ns_;
  uint64_t cache_insertion_end_ns_;

  // Dedicated timestamp for batcher internal which can diverge from
  // queue start timestamp to provide accurate queue time without affecting
  // batcher functionalities.
  uint64_t batcher_start_ns_;

  // Whether the stats of the request should be collected.
  bool collect_stats_;

#ifdef TRITON_ENABLE_STATS
  uint64_t request_start_ns_;
  InferenceStatsAggregator* secondary_stats_aggregator_ = nullptr;
#endif  // TRITON_ENABLE_STATS

#ifdef TRITON_ENABLE_TRACING
  // Inference trace associated with this request.
  std::shared_ptr<InferenceTraceProxy> trace_;
#endif  // TRITON_ENABLE_TRACING

  // Sequence I/O states used for implicit state.
  std::shared_ptr<SequenceStates> sequence_states_;
};
std::ostream& operator<<(std::ostream& out, const InferenceRequest& request);
std::ostream& operator<<(
    std::ostream& out, const InferenceRequest::Input& input);
std::ostream& operator<<(
    std::ostream& out, const InferenceRequest::SequenceId& sequence_id);
bool operator==(
    const InferenceRequest::SequenceId lhs,
    const InferenceRequest::SequenceId rhs);
}}  // namespace triton::core

namespace std {
using namespace triton::core;
template <>
class hash<InferenceRequest::SequenceId> {
 public:
  size_t operator()(const InferenceRequest::SequenceId& sequence_id) const
  {
    if (sequence_id.Type() ==
        InferenceRequest::SequenceId::DataType::STRING) {
      return std::hash<std::string>{}(sequence_id.StringValue());
    }
    return std::hash<uint64_t>{}(sequence_id.UnsignedIntValue());
  }
};
}  // namespace std
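
// Illustrative sketch (not part of this header): the std::hash specialization
// above, together with the operator== declared earlier, lets SequenceId act as
// a key in unordered containers, e.g. for per-sequence bookkeeping. The map
// name and values are invented for the example.
#if 0
std::unordered_map<InferenceRequest::SequenceId, uint64_t> inflight_per_seq;
inflight_per_seq[InferenceRequest::SequenceId(7UL)] += 1;
inflight_per_seq[InferenceRequest::SequenceId(std::string("stream-a"))] += 1;
#endif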
3rdparty/core-r22.12/src/infer_response.cc  (new file, 0 → 100644)
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_response.h"
#include "model.h"
#include "model_config_utils.h"
#include "server.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
//
// InferenceResponseFactory
//
Status
InferenceResponseFactory::CreateResponse(
    std::unique_ptr<InferenceResponse>* response) const
{
  response->reset(new InferenceResponse(
      model_, id_, allocator_, alloc_userp_, response_fn_, response_userp_,
      response_delegator_));
#ifdef TRITON_ENABLE_TRACING
  (*response)->SetTrace(trace_);
#endif  // TRITON_ENABLE_TRACING

  return Status::Success;
}

Status
InferenceResponseFactory::SendFlags(const uint32_t flags) const
{
  if (response_delegator_ != nullptr) {
    std::unique_ptr<InferenceResponse> response(
        new InferenceResponse(response_fn_, response_userp_));
    response_delegator_(std::move(response), flags);
  } else {
    void* userp = response_userp_;
    response_fn_(nullptr /* response */, flags, userp);
  }
  return Status::Success;
}
//
// InferenceResponse
//
InferenceResponse::InferenceResponse(
    const std::shared_ptr<Model>& model, const std::string& id,
    const ResponseAllocator* allocator, void* alloc_userp,
    TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
    void* response_userp,
    const std::function<void(
        std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator)
    : model_(model), id_(id), allocator_(allocator), alloc_userp_(alloc_userp),
      response_fn_(response_fn), response_userp_(response_userp),
      response_delegator_(delegator), null_response_(false)
{
  // If the allocator has a start_fn then invoke it.
  TRITONSERVER_ResponseAllocatorStartFn_t start_fn = allocator_->StartFn();
  if (start_fn != nullptr) {
    LOG_TRITONSERVER_ERROR(
        start_fn(
            reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
                const_cast<ResponseAllocator*>(allocator_)),
            alloc_userp_),
        "response allocation start failed");
  }
}

InferenceResponse::InferenceResponse(
    TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
    void* response_userp)
    : response_fn_(response_fn), response_userp_(response_userp),
      null_response_(true)
{
}
const std::string&
InferenceResponse::ModelName() const
{
  static const std::string unknown("<unknown>");
  return (model_ == nullptr) ? unknown : model_->Name();
}

int64_t
InferenceResponse::ActualModelVersion() const
{
  return (model_ == nullptr) ? -1 : model_->Version();
}
Status
InferenceResponse::AddParameter(const char* name, const char* value)
{
  parameters_.emplace_back(name, value);
  return Status::Success;
}

Status
InferenceResponse::AddParameter(const char* name, const int64_t value)
{
  parameters_.emplace_back(name, value);
  return Status::Success;
}

Status
InferenceResponse::AddParameter(const char* name, const bool value)
{
  parameters_.emplace_back(name, value);
  return Status::Success;
}
Status
InferenceResponse::AddOutput(
    const std::string& name, const inference::DataType datatype,
    const std::vector<int64_t>& shape, InferenceResponse::Output** output)
{
  outputs_.emplace_back(name, datatype, shape, allocator_, alloc_userp_);

  LOG_VERBOSE(1) << "add response output: " << outputs_.back();

  if (model_ != nullptr) {
    const inference::ModelOutput* output_config;
    RETURN_IF_ERROR(model_->GetOutput(name, &output_config));
    if (output_config->has_reshape()) {
      const bool has_batch_dim = (model_->Config().max_batch_size() > 0);
      outputs_.back().Reshape(has_batch_dim, output_config);
    }
  }

  if (output != nullptr) {
    *output = std::addressof(outputs_.back());
  }

  return Status::Success;
}

Status
InferenceResponse::AddOutput(
    const std::string& name, const inference::DataType datatype,
    std::vector<int64_t>&& shape, InferenceResponse::Output** output)
{
  outputs_.emplace_back(
      name, datatype, std::move(shape), allocator_, alloc_userp_);

  LOG_VERBOSE(1) << "add response output: " << outputs_.back();

  if (model_ != nullptr) {
    const inference::ModelOutput* output_config;
    RETURN_IF_ERROR(model_->GetOutput(name, &output_config));
    if (output_config->has_reshape()) {
      const bool has_batch_dim = (model_->Config().max_batch_size() > 0);
      outputs_.back().Reshape(has_batch_dim, output_config);
    }
  }

  if (output != nullptr) {
    *output = std::addressof(outputs_.back());
  }

  return Status::Success;
}
Status
InferenceResponse::ClassificationLabel(
    const InferenceResponse::Output& output, const uint32_t class_index,
    const char** label) const
{
  const auto& label_provider = model_->GetLabelProvider();
  const std::string& l = label_provider->GetLabel(output.Name(), class_index);
  if (l.empty()) {
    *label = nullptr;
  } else {
    *label = l.c_str();
  }
  return Status::Success;
}
Status
InferenceResponse::Send(
    std::unique_ptr<InferenceResponse>&& response, const uint32_t flags)
{
#ifdef TRITON_ENABLE_TRACING
  response->TraceOutputTensors(
      TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT, "InferenceResponse Send");
#endif  // TRITON_ENABLE_TRACING

  if (response->response_delegator_ != nullptr) {
    auto ldelegator = std::move(response->response_delegator_);
    ldelegator(std::move(response), flags);
    return Status::Success;
  }
  void* userp = response->response_userp_;
  if (response->null_response_) {
    response->response_fn_(nullptr /* response */, flags, userp);
  } else {
    auto& response_fn = response->response_fn_;
    response_fn(
        reinterpret_cast<TRITONSERVER_InferenceResponse*>(response.release()),
        flags, userp);
  }
  return Status::Success;
}

Status
InferenceResponse::SendWithStatus(
    std::unique_ptr<InferenceResponse>&& response, const uint32_t flags,
    const Status& status)
{
  response->status_ = status;
  return InferenceResponse::Send(std::move(response), flags);
}
#ifdef TRITON_ENABLE_TRACING
Status
InferenceResponse::TraceOutputTensors(
    TRITONSERVER_InferenceTraceActivity activity, const std::string& msg)
{
  const auto& outputs = this->Outputs();
  uint32_t output_count = outputs.size();

  for (uint32_t idx = 0; idx < output_count; ++idx) {
    const Output& output = outputs[idx];

    // output data
    const char* cname = output.Name().c_str();
    TRITONSERVER_DataType datatype = DataTypeToTriton(output.DType());
    const std::vector<int64_t>& oshape = output.Shape();
    const int64_t* shape = &oshape[0];
    uint64_t dim_count = oshape.size();
    const void* base;
    size_t byte_size;
    TRITONSERVER_MemoryType memory_type;
    int64_t memory_type_id;
    void* userp;

    Status status = output.DataBuffer(
        &base, &byte_size, &memory_type, &memory_type_id, &userp);
    if (!status.IsOk()) {
      LOG_STATUS_ERROR(
          status,
          std::string(TRITONSERVER_InferenceTraceActivityString(activity)) +
              ": " + msg + ": fail to get data buffer: " + status.Message());
      return status;
    }

    INFER_TRACE_TENSOR_ACTIVITY(
        this->trace_, activity, cname, datatype, base, byte_size, shape,
        dim_count, memory_type, memory_type_id);
  }

  return Status::Success;
}
#endif  // TRITON_ENABLE_TRACING
//
// InferenceResponse::Output
//
InferenceResponse::Output::~Output()
{
  Status status = ReleaseDataBuffer();
  if (!status.IsOk()) {
    LOG_ERROR << "failed to release buffer for output '" << name_
              << "': " << status.AsString();
  }
}
void
InferenceResponse::Output::Reshape(
    const bool has_batch_dim, const inference::ModelOutput* output_config)
{
  std::deque<int64_t> variable_size_values;

  const int64_t batch_dim =
      (has_batch_dim && (shape_.size() > 0)) ? shape_[0] : -1;
  const size_t batch_dim_offset = (has_batch_dim) ? 1 : 0;

  const auto& from_shape = output_config->reshape().shape();
  const auto& to_shape = output_config->dims();
  for (int64_t idx = 0; idx < from_shape.size(); idx++) {
    if (from_shape[idx] == -1) {
      variable_size_values.push_back(shape_[idx + batch_dim_offset]);
    }
  }

  shape_.clear();
  if (batch_dim >= 0) {
    shape_.push_back(batch_dim);
  }

  for (const auto& dim : to_shape) {
    if (dim == -1) {
      shape_.push_back(variable_size_values.front());
      variable_size_values.pop_front();
    } else {
      shape_.push_back(dim);
    }
  }
}
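
// Worked example of the mapping above (illustrative values, not taken from
// this commit): with a batching model (max_batch_size > 0), a configuration of
//   reshape { shape: [ -1 ] }   dims: [ -1, 1 ]
// and a runtime output shape_ of [ 8, 4 ] (batch 8 followed by the reshaped
// layout), the variable-size value 4 is captured from the position after the
// batch dimension and shape_ is rebuilt as [ 8, 4, 1 ].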
Status
InferenceResponse::Output::DataBuffer(
    const void** buffer, size_t* buffer_byte_size,
    TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id,
    void** userp) const
{
  *buffer = allocated_buffer_;
  *buffer_byte_size = buffer_attributes_.ByteSize();
  *memory_type = buffer_attributes_.MemoryType();
  *memory_type_id = buffer_attributes_.MemoryTypeId();
  *userp = allocated_userp_;
  return Status::Success;
}
Status
InferenceResponse::Output::AllocateDataBuffer(
    void** buffer, size_t buffer_byte_size,
    TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id)
{
  if (allocated_buffer_ != nullptr) {
    return Status(
        Status::Code::ALREADY_EXISTS,
        "allocated buffer for output '" + name_ + "' already exists");
  }

  TRITONSERVER_MemoryType actual_memory_type = *memory_type;
  int64_t actual_memory_type_id = *memory_type_id;
  void* alloc_buffer_userp = nullptr;

  RETURN_IF_TRITONSERVER_ERROR(allocator_->AllocFn()(
      reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
          const_cast<ResponseAllocator*>(allocator_)),
      name_.c_str(), buffer_byte_size, *memory_type, *memory_type_id,
      alloc_userp_, buffer, &alloc_buffer_userp, &actual_memory_type,
      &actual_memory_type_id));

  // Only call the buffer attributes API if it is set.
  if (allocator_->BufferAttributesFn() != nullptr) {
    RETURN_IF_TRITONSERVER_ERROR(allocator_->BufferAttributesFn()(
        reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
            const_cast<ResponseAllocator*>(allocator_)),
        name_.c_str(),
        reinterpret_cast<TRITONSERVER_BufferAttributes*>(&buffer_attributes_),
        alloc_userp_, alloc_buffer_userp));
  }

  allocated_buffer_ = *buffer;
  buffer_attributes_.SetByteSize(buffer_byte_size);
  buffer_attributes_.SetMemoryType(actual_memory_type);
  buffer_attributes_.SetMemoryTypeId(actual_memory_type_id);
  allocated_userp_ = alloc_buffer_userp;

  *memory_type = actual_memory_type;
  *memory_type_id = actual_memory_type_id;

  return Status::Success;
}
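
// Illustrative sketch (not part of this file): typical single-buffer lifecycle
// for an output. 'output' is assumed to have been obtained from
// InferenceResponse::AddOutput; the byte size is invented for the example.
#if 0
void* buffer = nullptr;
TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;  // preferred
int64_t memory_type_id = 0;                                     // preferred
Status status = output->AllocateDataBuffer(
    &buffer, 4 * sizeof(float), &memory_type, &memory_type_id);
// On return 'memory_type' / 'memory_type_id' describe what was actually
// provided by the client allocator. A second AllocateDataBuffer without an
// intervening ReleaseDataBuffer returns ALREADY_EXISTS.
#endif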
Status
InferenceResponse::Output::ReleaseDataBuffer()
{
  TRITONSERVER_Error* err = nullptr;

  if (allocated_buffer_ != nullptr) {
    err = allocator_->ReleaseFn()(
        reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
            const_cast<ResponseAllocator*>(allocator_)),
        allocated_buffer_, allocated_userp_, buffer_attributes_.ByteSize(),
        buffer_attributes_.MemoryType(), buffer_attributes_.MemoryTypeId());
  }

  allocated_buffer_ = nullptr;
  buffer_attributes_.SetByteSize(0);
  buffer_attributes_.SetMemoryType(TRITONSERVER_MEMORY_CPU);
  buffer_attributes_.SetMemoryTypeId(0);
  allocated_userp_ = nullptr;

  RETURN_IF_TRITONSERVER_ERROR(err);

  return Status::Success;
}
std::ostream&
operator<<(std::ostream& out, const InferenceResponse& response)
{
  out << "[0x" << std::addressof(response) << "] "
      << "response id: " << response.Id() << ", model: "
      << response.ModelName()
      << ", actual version: " << response.ActualModelVersion() << std::endl;

  out << "status:" << response.ResponseStatus().AsString() << std::endl;

  out << "outputs:" << std::endl;
  for (const auto& output : response.Outputs()) {
    out << "[0x" << std::addressof(output) << "] " << output << std::endl;
  }

  return out;
}

std::ostream&
operator<<(std::ostream& out, const InferenceResponse::Output& output)
{
  out << "output: " << output.Name() << ", type: "
      << triton::common::DataTypeToProtocolString(output.DType())
      << ", shape: " << triton::common::DimsListToString(output.Shape());
  return out;
}

}}  // namespace triton::core
3rdparty/core-r22.12/src/infer_response.h  (new file, 0 → 100644)
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <deque>
#include <functional>
#include <string>
#include <vector>
#include "buffer_attributes.h"
#include "constants.h"
#include "infer_parameter.h"
#include "infer_trace.h"
#include "response_allocator.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {

class Model;
class InferenceResponse;
//
// An inference response factory.
//
class InferenceResponseFactory {
 public:
  InferenceResponseFactory() = default;

  InferenceResponseFactory(
      const std::shared_ptr<Model>& model, const std::string& id,
      const ResponseAllocator* allocator, void* alloc_userp,
      TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
      void* response_userp,
      const std::function<void(
          std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator)
      : model_(model), id_(id), allocator_(allocator),
        alloc_userp_(alloc_userp), response_fn_(response_fn),
        response_userp_(response_userp), response_delegator_(delegator)
  {
  }

  const ResponseAllocator* Allocator() { return allocator_; }
  void* AllocatorUserp() { return alloc_userp_; }

  Status SetResponseDelegator(
      const std::function<void(
          std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator)
  {
    response_delegator_ = delegator;
    return Status::Success;
  }

  // Create a new response.
  Status CreateResponse(std::unique_ptr<InferenceResponse>* response) const;

  // Send a "null" response with 'flags'.
  Status SendFlags(const uint32_t flags) const;

#ifdef TRITON_ENABLE_TRACING
  const std::shared_ptr<InferenceTraceProxy>& Trace() const { return trace_; }
  void SetTrace(const std::shared_ptr<InferenceTraceProxy>& trace)
  {
    trace_ = trace;
  }
  void ReleaseTrace() { trace_ = nullptr; }
#endif  // TRITON_ENABLE_TRACING

 private:
  // The model associated with this factory. For normal
  // requests/responses this will always be defined and acts to keep
  // the model loaded as long as this factory is live. It may be
  // nullptr for cases where the model itself created the request
  // (like running requests for warmup) and so must protect any uses
  // to handle the nullptr case.
  std::shared_ptr<Model> model_;

  // The ID of the corresponding request that should be included in every
  // response. This is a property that can be optionally provided by the user.
  std::string id_;

  // The response allocator and user pointer. The 'allocator_' is a
  // raw pointer because it is owned by the client, and the client is
  // responsible for ensuring that the lifetime of the allocator
  // extends longer than any request or response that depends on the
  // allocator.
  const ResponseAllocator* allocator_;
  void* alloc_userp_;

  // The response callback function and user pointer.
  TRITONSERVER_InferenceResponseCompleteFn_t response_fn_;
  void* response_userp_;

  // Delegator to be invoked on sending responses.
  std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
      response_delegator_;

#ifdef TRITON_ENABLE_TRACING
  // Inference trace associated with this response.
  std::shared_ptr<InferenceTraceProxy> trace_;
#endif  // TRITON_ENABLE_TRACING
};
//
// An inference response.
//
class InferenceResponse {
 public:
  // Output tensor
  class Output {
   public:
    Output(
        const std::string& name, const inference::DataType datatype,
        const std::vector<int64_t>& shape, const ResponseAllocator* allocator,
        void* alloc_userp)
        : name_(name), datatype_(datatype), shape_(shape),
          allocator_(allocator), alloc_userp_(alloc_userp),
          allocated_buffer_(nullptr)
    {
    }
    Output(
        const std::string& name, const inference::DataType datatype,
        std::vector<int64_t>&& shape, const ResponseAllocator* allocator,
        void* alloc_userp)
        : name_(name), datatype_(datatype), shape_(std::move(shape)),
          allocator_(allocator), alloc_userp_(alloc_userp),
          allocated_buffer_(nullptr)
    {
    }

    ~Output();

    // The name of the output tensor.
    const std::string& Name() const { return name_; }

    // Data type of the output tensor.
    inference::DataType DType() const { return datatype_; }

    // The shape of the output tensor.
    const std::vector<int64_t>& Shape() const { return shape_; }

    BufferAttributes* GetBufferAttributes() { return &buffer_attributes_; }

    // Reshape the output tensor. This function must only be called
    // for outputs that have a reshape specified in the model
    // configuration.
    void Reshape(
        const bool has_batch_dim,
        const inference::ModelOutput* output_config);

    // Get information about the buffer allocated for this output
    // tensor's data. If no buffer is allocated 'buffer' will return
    // nullptr and the other returned values will be undefined.
    Status DataBuffer(
        const void** buffer, size_t* buffer_byte_size,
        TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id,
        void** userp) const;

    // Allocate the buffer that should be used for this output
    // tensor's data. 'buffer' must return a buffer of size
    // 'buffer_byte_size'. 'memory_type' acts as both input and
    // output. On input gives the buffer memory type preferred by the
    // caller and on return holds the actual memory type of
    // 'buffer'. 'memory_type_id' acts as both input and output. On
    // input gives the buffer memory type id preferred by the caller
    // and returns the actual memory type id of 'buffer'. Only a
    // single buffer may be allocated for the output at any time, so
    // multiple calls to AllocateDataBuffer without an intervening
    // ReleaseDataBuffer call will result in an error.
    Status AllocateDataBuffer(
        void** buffer, const size_t buffer_byte_size,
        TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id);

    // Release the buffer that was previously allocated by
    // AllocateDataBuffer(). Do nothing if AllocateDataBuffer() has
    // not been called.
    Status ReleaseDataBuffer();

   private:
    DISALLOW_COPY_AND_ASSIGN(Output);
    friend std::ostream& operator<<(
        std::ostream& out, const InferenceResponse::Output& output);

    std::string name_;
    inference::DataType datatype_;
    std::vector<int64_t> shape_;

    // The response allocator and user pointer.
    const ResponseAllocator* allocator_;
    void* alloc_userp_;

    // Information about the buffer allocated by
    // AllocateDataBuffer(). This information is needed by
    // DataBuffer() and ReleaseDataBuffer().
    void* allocated_buffer_;
    BufferAttributes buffer_attributes_;
    void* allocated_userp_;
  };
  // InferenceResponse
  InferenceResponse(
      const std::shared_ptr<Model>& model, const std::string& id,
      const ResponseAllocator* allocator, void* alloc_userp,
      TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
      void* response_userp,
      const std::function<void(
          std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator);

  // A "null" InferenceResponse is a special instance of InferenceResponse
  // which contains only the minimal information needed to call
  // InferenceResponse::Send. nullptr will be passed as the response in
  // 'response_fn'.
  InferenceResponse(
      TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
      void* response_userp);

  const std::string& Id() const { return id_; }
  const std::string& ModelName() const;
  int64_t ActualModelVersion() const;
  const Status& ResponseStatus() const { return status_; }

  // The response parameters.
  const std::deque<InferenceParameter>& Parameters() const
  {
    return parameters_;
  }

  // Add a parameter to the response.
  Status AddParameter(const char* name, const char* value);
  Status AddParameter(const char* name, const int64_t value);
  Status AddParameter(const char* name, const bool value);

  // The response outputs.
  const std::deque<Output>& Outputs() const { return outputs_; }

  // Add an output to the response. If 'output' is non-null
  // return a pointer to the newly added output.
  Status AddOutput(
      const std::string& name, const inference::DataType datatype,
      const std::vector<int64_t>& shape, Output** output = nullptr);
  Status AddOutput(
      const std::string& name, const inference::DataType datatype,
      std::vector<int64_t>&& shape, Output** output = nullptr);

  // Get the classification label associated with an output. Return
  // 'label' == nullptr if no label.
  Status ClassificationLabel(
      const Output& output, const uint32_t class_index,
      const char** label) const;

  // Send the response with success status. Calling this function
  // releases ownership of the response object and gives it to the
  // callback function.
  static Status Send(
      std::unique_ptr<InferenceResponse>&& response, const uint32_t flags);

  // Send the response with explicit status. Calling this function
  // releases ownership of the response object and gives it to the
  // callback function.
  static Status SendWithStatus(
      std::unique_ptr<InferenceResponse>&& response, const uint32_t flags,
      const Status& status);

#ifdef TRITON_ENABLE_TRACING
  const std::shared_ptr<InferenceTraceProxy>& Trace() const { return trace_; }
  void SetTrace(const std::shared_ptr<InferenceTraceProxy>& trace)
  {
    trace_ = trace;
  }
  void ReleaseTrace() { trace_ = nullptr; }
#endif  // TRITON_ENABLE_TRACING

 private:
  DISALLOW_COPY_AND_ASSIGN(InferenceResponse);
  friend std::ostream& operator<<(
      std::ostream& out, const InferenceResponse& response);

#ifdef TRITON_ENABLE_TRACING
  Status TraceOutputTensors(
      TRITONSERVER_InferenceTraceActivity activity, const std::string& msg);
#endif  // TRITON_ENABLE_TRACING

  // The model associated with this response. For normal
  // requests/responses this will always be defined and acts to keep
  // the model loaded as long as this response is live. It may be
  // nullptr for cases where the model itself created the request
  // (like running requests for warmup) and so must protect any uses
  // to handle the nullptr case.
  std::shared_ptr<Model> model_;

  // The ID of the corresponding request that should be included in
  // every response.
  std::string id_;

  // Error status for the response.
  Status status_;

  // The parameters of the response. Use a deque so that there is no
  // reallocation.
  std::deque<InferenceParameter> parameters_;

  // The result tensors. Use a deque so that there is no reallocation.
  std::deque<Output> outputs_;

  // The response allocator and user pointer.
  const ResponseAllocator* allocator_;
  void* alloc_userp_;

  // The response callback function and user pointer.
  TRITONSERVER_InferenceResponseCompleteFn_t response_fn_;
  void* response_userp_;

  // Delegator to be invoked on sending responses.
  std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
      response_delegator_;

  bool null_response_;

#ifdef TRITON_ENABLE_TRACING
  // Inference trace associated with this response.
  std::shared_ptr<InferenceTraceProxy> trace_;
#endif  // TRITON_ENABLE_TRACING
};
std::ostream& operator<<(std::ostream& out, const InferenceResponse& response);
std::ostream& operator<<(
    std::ostream& out, const InferenceResponse::Output& output);

}}  // namespace triton::core
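
// Illustrative sketch (not part of this header): one plausible scheduler-side
// flow for assembling and sending a response. 'factory' is an assumed
// InferenceResponseFactory already configured with the client's allocator and
// completion callback; the output name, datatype enumerator spelling
// (inference::TYPE_FP32) and shape are invented for the example.
#if 0
std::unique_ptr<InferenceResponse> response;
RETURN_IF_ERROR(factory->CreateResponse(&response));

InferenceResponse::Output* output = nullptr;
RETURN_IF_ERROR(response->AddOutput(
    "OUTPUT0", inference::TYPE_FP32, std::vector<int64_t>{1, 4}, &output));

void* buffer = nullptr;
TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
int64_t memory_type_id = 0;
RETURN_IF_ERROR(output->AllocateDataBuffer(
    &buffer, 4 * sizeof(float), &memory_type, &memory_type_id));
// ... fill 'buffer' with the result tensor ...

RETURN_IF_ERROR(InferenceResponse::Send(
    std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL));
#endif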
3rdparty/core-r22.12/src/infer_stats.cc  (new file, 0 → 100644)
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_stats.h"
#include <time.h>
#include "metric_model_reporter.h"
#include "metrics.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
#ifdef TRITON_ENABLE_STATS

void
InferenceStatsAggregator::UpdateFailure(
    MetricModelReporter* metric_reporter, const uint64_t request_start_ns,
    const uint64_t request_end_ns)
{
  std::lock_guard<std::mutex> lock(mu_);

  infer_stats_.failure_count_++;
  infer_stats_.failure_duration_ns_ += (request_end_ns - request_start_ns);

#ifdef TRITON_ENABLE_METRICS
  if (metric_reporter != nullptr) {
    metric_reporter->MetricInferenceFailure().Increment(1);
  }
#endif  // TRITON_ENABLE_METRICS
}
void
InferenceStatsAggregator::UpdateSuccess(
    MetricModelReporter* metric_reporter, const size_t batch_size,
    const uint64_t request_start_ns, const uint64_t queue_start_ns,
    const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
    const uint64_t compute_output_start_ns, const uint64_t compute_end_ns,
    const uint64_t request_end_ns)
{
  const uint64_t compute_input_duration_ns =
      compute_input_end_ns - compute_start_ns;
  const uint64_t compute_infer_duration_ns =
      compute_output_start_ns - compute_input_end_ns;
  const uint64_t compute_output_duration_ns =
      compute_end_ns - compute_output_start_ns;
  UpdateSuccessWithDuration(
      metric_reporter, batch_size, request_start_ns, queue_start_ns,
      compute_start_ns, request_end_ns, compute_input_duration_ns,
      compute_infer_duration_ns, compute_output_duration_ns);
}
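
// Worked example of the decomposition above (illustrative numbers, not from
// this commit): with compute_start = 100us, compute_input_end = 120us,
// compute_output_start = 170us and compute_end = 180us (nanoseconds in
// practice), the derived durations are input = 20us, infer = 50us and
// output = 10us; UpdateSuccessWithDuration then computes the request and
// queue durations from the remaining timestamps.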
void
InferenceStatsAggregator::UpdateSuccessWithDuration(
    MetricModelReporter* metric_reporter, const size_t batch_size,
    const uint64_t request_start_ns, const uint64_t queue_start_ns,
    const uint64_t compute_start_ns, const uint64_t request_end_ns,
    const uint64_t compute_input_duration_ns,
    const uint64_t compute_infer_duration_ns,
    const uint64_t compute_output_duration_ns)
{
  const uint64_t request_duration_ns = request_end_ns - request_start_ns;
  const uint64_t queue_duration_ns = compute_start_ns - queue_start_ns;

  std::lock_guard<std::mutex> lock(mu_);

  inference_count_ += batch_size;
  infer_stats_.success_count_++;
  infer_stats_.request_duration_ns_ += request_duration_ns;
  infer_stats_.queue_duration_ns_ += queue_duration_ns;
  infer_stats_.compute_input_duration_ns_ += compute_input_duration_ns;
  infer_stats_.compute_infer_duration_ns_ += compute_infer_duration_ns;
  infer_stats_.compute_output_duration_ns_ += compute_output_duration_ns;

#ifdef TRITON_ENABLE_METRICS
  if (metric_reporter != nullptr) {
    metric_reporter->MetricInferenceSuccess().Increment(1);
    metric_reporter->MetricInferenceCount().Increment(batch_size);
    metric_reporter->MetricInferenceRequestDuration().Increment(
        request_duration_ns / 1000);
    metric_reporter->MetricInferenceQueueDuration().Increment(
        queue_duration_ns / 1000);
    metric_reporter->MetricInferenceComputeInputDuration().Increment(
        compute_input_duration_ns / 1000);
    metric_reporter->MetricInferenceComputeInferDuration().Increment(
        compute_infer_duration_ns / 1000);
    metric_reporter->MetricInferenceComputeOutputDuration().Increment(
        compute_output_duration_ns / 1000);
  }
#endif  // TRITON_ENABLE_METRICS
}
// Currently cache hits will not go to the inference backend where metrics
// are typically updated, so this method allows us to update relevant metrics
// from a metric reporter rather than going through the backend.
void
InferenceStatsAggregator::UpdateSuccessCacheHit(
    MetricModelReporter* metric_reporter, const size_t batch_size,
    const uint64_t request_start_ns, const uint64_t queue_start_ns,
    const uint64_t cache_lookup_start_ns, const uint64_t request_end_ns,
    const uint64_t cache_hit_lookup_duration_ns)
{
  const uint64_t request_duration_ns = request_end_ns - request_start_ns;
  const uint64_t queue_duration_ns = cache_lookup_start_ns - queue_start_ns;

  std::lock_guard<std::mutex> lock(mu_);

  infer_stats_.success_count_++;
  infer_stats_.request_duration_ns_ += request_duration_ns;
  infer_stats_.queue_duration_ns_ += queue_duration_ns;
  infer_stats_.cache_hit_count_++;
  infer_stats_.cache_hit_lookup_duration_ns_ += cache_hit_lookup_duration_ns;

#ifdef TRITON_ENABLE_METRICS
  if (metric_reporter != nullptr) {
    metric_reporter->MetricInferenceSuccess().Increment(1);
    metric_reporter->MetricInferenceRequestDuration().Increment(
        request_duration_ns / 1000);
    metric_reporter->MetricInferenceQueueDuration().Increment(
        queue_duration_ns / 1000);
    metric_reporter->MetricCacheHitCount().Increment(1);
    metric_reporter->MetricCacheHitLookupDuration().Increment(
        cache_hit_lookup_duration_ns / 1000);
  }
#endif  // TRITON_ENABLE_METRICS
}
// Cache misses will go to the inference backend where metrics are typically
// updated, but cache insertion happens after the inference backend finishes.
// So we use this method to update cache miss stats and adjust the request
// duration to include cache insertion time.
void
InferenceStatsAggregator::UpdateSuccessCacheMiss(
    MetricModelReporter* metric_reporter,
    const uint64_t cache_miss_lookup_duration_ns,
    const uint64_t cache_miss_insertion_duration_ns)
{
  std::lock_guard<std::mutex> lock(mu_);

  const uint64_t cache_miss_duration_ns =
      cache_miss_lookup_duration_ns + cache_miss_insertion_duration_ns;
  infer_stats_.request_duration_ns_ += cache_miss_duration_ns;
  infer_stats_.cache_miss_count_++;
  infer_stats_.cache_miss_lookup_duration_ns_ += cache_miss_lookup_duration_ns;
  infer_stats_.cache_miss_insertion_duration_ns_ +=
      cache_miss_insertion_duration_ns;

#ifdef TRITON_ENABLE_METRICS
  if (metric_reporter != nullptr) {
    // Add cache insertion time to request duration since insertion
    // happens after inference backend sets the request duration, and
    // cache lookup time was already included before the inference backend
    // was called
    metric_reporter->MetricInferenceRequestDuration().Increment(
        cache_miss_duration_ns / 1000);
    metric_reporter->MetricCacheMissCount().Increment(1);
    metric_reporter->MetricCacheMissLookupDuration().Increment(
        cache_miss_lookup_duration_ns / 1000);
    metric_reporter->MetricCacheMissInsertionDuration().Increment(
        cache_miss_insertion_duration_ns / 1000);
  }
#endif  // TRITON_ENABLE_METRICS
}
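A rough sketch of how the two cache paths described above fit together from a caller's point of view; the cache bookkeeping, timestamps, and the `stats` aggregator are assumptions for illustration, and only the two aggregator calls come from this file:

// Illustrative only. 'cache_hit', the timestamps and durations are assumed
// to be produced by the caller's response-cache logic.
if (cache_hit) {
  // The response is served from the cache and never reaches the backend,
  // so success, queue time and lookup time are all recorded here.
  stats.UpdateSuccessCacheHit(
      nullptr, batch_size, request_start_ns, queue_start_ns,
      cache_lookup_start_ns, request_end_ns, cache_hit_lookup_duration_ns);
} else {
  // The backend already recorded its own stats for this request; only add
  // the lookup and insertion time that happened outside the backend.
  stats.UpdateSuccessCacheMiss(
      nullptr, cache_miss_lookup_duration_ns,
      cache_miss_insertion_duration_ns);
}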
void
InferenceStatsAggregator::UpdateInferBatchStats(
    MetricModelReporter* metric_reporter, const size_t batch_size,
    const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
    const uint64_t compute_output_start_ns, const uint64_t compute_end_ns)
{
  auto compute_input_duration_ns = (compute_input_end_ns - compute_start_ns);
  auto compute_infer_duration_ns =
      (compute_output_start_ns - compute_input_end_ns);
  auto compute_output_duration_ns = (compute_end_ns - compute_output_start_ns);
  UpdateInferBatchStatsWithDuration(
      metric_reporter, batch_size, compute_input_duration_ns,
      compute_infer_duration_ns, compute_output_duration_ns);
}
void
InferenceStatsAggregator::UpdateInferBatchStatsWithDuration(
    MetricModelReporter* metric_reporter, size_t batch_size,
    const uint64_t compute_input_duration_ns,
    const uint64_t compute_infer_duration_ns,
    const uint64_t compute_output_duration_ns)
{
  uint64_t inference_ms =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::system_clock::now().time_since_epoch())
          .count();

  std::lock_guard<std::mutex> lock(mu_);

  if (inference_ms > last_inference_ms_) {
    last_inference_ms_ = inference_ms;
  }

  execution_count_++;

  auto it = batch_stats_.find(batch_size);
  if (it == batch_stats_.end()) {
    it = batch_stats_.emplace(batch_size, InferBatchStats()).first;
  }
  it->second.count_++;
  it->second.compute_input_duration_ns_ += compute_input_duration_ns;
  it->second.compute_infer_duration_ns_ += compute_infer_duration_ns;
  it->second.compute_output_duration_ns_ += compute_output_duration_ns;

#ifdef TRITON_ENABLE_METRICS
  if (metric_reporter != nullptr) {
    metric_reporter->MetricInferenceExecutionCount().Increment(1);
  }
#endif  // TRITON_ENABLE_METRICS
}

#endif  // TRITON_ENABLE_STATS

}}  // namespace triton::core
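Batch statistics are keyed by the executed batch size, so a caller records one entry per backend execution. A small hedged sketch (the `stats` aggregator and the timestamp variables are assumed to exist in the caller's scope):

// Illustrative only. One call per backend execution; the map keyed by
// batch size accumulates the count and per-phase durations for that size.
stats.UpdateInferBatchStats(
    nullptr /* metric_reporter */, 4 /* batch_size */, compute_start_ns,
    compute_input_end_ns, compute_output_start_ns, compute_end_ns);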
3rdparty/core-r22.12/src/infer_stats.h
0 → 100644
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <time.h>
#include <map>
#include <memory>
#include <mutex>
#include <vector>
#include "constants.h"
#include "infer_response.h"
#include "status.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {

class MetricModelReporter;

//
// InferenceStatsAggregator
//
// A statistics aggregator.
//
class InferenceStatsAggregator {
#ifdef TRITON_ENABLE_STATS
public:
  struct InferStats {
    InferStats()
        : failure_count_(0), failure_duration_ns_(0), success_count_(0),
          request_duration_ns_(0), queue_duration_ns_(0),
          compute_input_duration_ns_(0), compute_infer_duration_ns_(0),
          compute_output_duration_ns_(0), cache_hit_count_(0),
          cache_hit_lookup_duration_ns_(0), cache_miss_count_(0),
          cache_miss_lookup_duration_ns_(0),
          cache_miss_insertion_duration_ns_(0)
    {
    }
    uint64_t failure_count_;
    uint64_t failure_duration_ns_;
    uint64_t success_count_;
    uint64_t request_duration_ns_;
    uint64_t queue_duration_ns_;
    uint64_t compute_input_duration_ns_;
    uint64_t compute_infer_duration_ns_;
    uint64_t compute_output_duration_ns_;
    // Cache hit stats
    uint64_t cache_hit_count_;
    uint64_t cache_hit_lookup_duration_ns_;
    // Cache miss stats
    uint64_t cache_miss_count_;
    uint64_t cache_miss_lookup_duration_ns_;
    uint64_t cache_miss_insertion_duration_ns_;
  };

  struct InferBatchStats {
    InferBatchStats()
        : count_(0), compute_input_duration_ns_(0),
          compute_infer_duration_ns_(0), compute_output_duration_ns_(0)
    {
    }
    uint64_t count_;
    uint64_t compute_input_duration_ns_;
    uint64_t compute_infer_duration_ns_;
    uint64_t compute_output_duration_ns_;
  };
  // Create an aggregator for model statistics
  InferenceStatsAggregator()
      : last_inference_ms_(0), inference_count_(0), execution_count_(0)
  {
  }

  uint64_t LastInferenceMs() const { return last_inference_ms_; }
  uint64_t InferenceCount() const { return inference_count_; }
  uint64_t ExecutionCount() const { return execution_count_; }

  const InferStats& ImmutableInferStats() const { return infer_stats_; }
  const std::map<size_t, InferBatchStats>& ImmutableInferBatchStats() const
  {
    return batch_stats_;
  }
  // Add durations to Infer stats for a failed inference request.
  void UpdateFailure(
      MetricModelReporter* metric_reporter, const uint64_t request_start_ns,
      const uint64_t request_end_ns);

  // Add durations to infer stats for a successful inference request.
  void UpdateSuccess(
      MetricModelReporter* metric_reporter, const size_t batch_size,
      const uint64_t request_start_ns, const uint64_t queue_start_ns,
      const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
      const uint64_t compute_output_start_ns, const uint64_t compute_end_ns,
      const uint64_t request_end_ns);

  // Add durations to infer stats for a successful inference request.
  void UpdateSuccessWithDuration(
      MetricModelReporter* metric_reporter, const size_t batch_size,
      const uint64_t request_start_ns, const uint64_t queue_start_ns,
      const uint64_t compute_start_ns, const uint64_t request_end_ns,
      const uint64_t compute_input_duration_ns,
      const uint64_t compute_infer_duration_ns,
      const uint64_t compute_output_duration_ns);

  // Add durations to infer stats for a successful cached response.
  void UpdateSuccessCacheHit(
      MetricModelReporter* metric_reporter, const size_t batch_size,
      const uint64_t request_start_ns, const uint64_t queue_start_ns,
      const uint64_t cache_lookup_start_ns, const uint64_t request_end_ns,
      const uint64_t cache_hit_lookup_duration_ns);

  // Add durations to infer stats for a cache miss and update request duration
  // to account for cache insertion after backend computes the response.
  void UpdateSuccessCacheMiss(
      MetricModelReporter* metric_reporter,
      const uint64_t cache_miss_lookup_duration_ns,
      const uint64_t cache_miss_insertion_duration_ns);
  // Add durations to batch infer stats for a batch execution.
  // 'success_request_count' is the number of success requests in the
  // batch that have infer_stats attached.
  void UpdateInferBatchStats(
      MetricModelReporter* metric_reporter, const size_t batch_size,
      const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
      const uint64_t compute_output_start_ns, const uint64_t compute_end_ns);

  // Add durations to batch infer stats for a batch execution.
  // 'success_request_count' is the number of success requests in the
  // batch that have infer_stats attached.
  void UpdateInferBatchStatsWithDuration(
      MetricModelReporter* metric_reporter, size_t batch_size,
      const uint64_t compute_input_duration_ns,
      const uint64_t compute_infer_duration_ns,
      const uint64_t compute_output_duration_ns);
 private:
  std::mutex mu_;

  uint64_t last_inference_ms_;
  uint64_t inference_count_;
  uint64_t execution_count_;
  InferStats infer_stats_;
  std::map<size_t, InferBatchStats> batch_stats_;
#endif  // TRITON_ENABLE_STATS
};
//
// Macros to set infer stats.
//
#ifdef TRITON_ENABLE_STATS
#define INFER_STATS_SET_TIMESTAMP(TS_NS) \
{ \
TS_NS = std::chrono::duration_cast<std::chrono::nanoseconds>( \
std::chrono::steady_clock::now().time_since_epoch()) \
.count(); \
}
#define INFER_STATS_DECL_TIMESTAMP(TS_NS) \
uint64_t TS_NS; \
INFER_STATS_SET_TIMESTAMP(TS_NS);
#else
#define INFER_STATS_DECL_TIMESTAMP(TS_NS)
#define INFER_STATS_SET_TIMESTAMP(TS_NS)
#endif // TRITON_ENABLE_STATS
}}  // namespace triton::core
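As a usage note, the timestamp macros above compile away when TRITON_ENABLE_STATS is off, so call sites can stay unconditional as long as uses of the declared variables are guarded the same way. A hedged sketch; the surrounding function is hypothetical:

// Hypothetical call site, for illustration only.
void HandleRequestExample()
{
  INFER_STATS_DECL_TIMESTAMP(request_start_ns);
  // ... queue, schedule and execute the request ...
  INFER_STATS_DECL_TIMESTAMP(request_end_ns);
#ifdef TRITON_ENABLE_STATS
  // request_start_ns / request_end_ns only exist when stats are enabled,
  // e.g. aggregator.UpdateFailure(nullptr, request_start_ns, request_end_ns);
#endif  // TRITON_ENABLE_STATS
}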
3rdparty/core-r22.12/src/infer_trace.cc
0 → 100644
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_trace.h"
namespace triton { namespace core {

#ifdef TRITON_ENABLE_TRACING

// Start the trace id at 1, because id 0 is reserved to indicate no
// parent.
std::atomic<uint64_t> InferenceTrace::next_id_(1);

InferenceTrace*
InferenceTrace::SpawnChildTrace()
{
  InferenceTrace* trace = new InferenceTrace(
      level_, id_, activity_fn_, tensor_activity_fn_, release_fn_, userp_);
  return trace;
}

void
InferenceTrace::Release()
{
  release_fn_(reinterpret_cast<TRITONSERVER_InferenceTrace*>(this), userp_);
}

std::shared_ptr<InferenceTraceProxy>
InferenceTraceProxy::SpawnChildTrace()
{
  std::shared_ptr<InferenceTraceProxy> strace_proxy =
      std::make_shared<InferenceTraceProxy>(trace_->SpawnChildTrace());
  return strace_proxy;
}

#endif  // TRITON_ENABLE_TRACING

}}  // namespace triton::core
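A brief sketch of how a child trace might be used when one traced request fans out to another model; only SpawnChildTrace(), the Set*/ReportNow() accessors, and the parent/child id relationship come from these files, while the helper name, the model name, and the chosen activity value are assumptions:

// Illustrative only; not part of this commit.
std::shared_ptr<triton::core::InferenceTraceProxy>
TraceChildStepExample(
    const std::shared_ptr<triton::core::InferenceTraceProxy>& parent)
{
  if (parent == nullptr) {
    return nullptr;  // the originating request is not being traced
  }
  // The child's ParentId() is the parent's Id(), since SpawnChildTrace()
  // passes id_ as the new trace's parent_id.
  auto child = parent->SpawnChildTrace();
  child->SetModelName("composing_model");  // assumed model name
  child->SetModelVersion(1);
  child->ReportNow(TRITONSERVER_TRACE_REQUEST_START);
  return child;
}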
3rdparty/core-r22.12/src/infer_trace.h
0 → 100644
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <atomic>
#include <chrono>
#include <memory>
#include "constants.h"
#include "status.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {

#ifdef TRITON_ENABLE_TRACING

//
// InferenceTrace
//
// Interface to TRITONSERVER_InferenceTrace to report trace events.
//
class InferenceTrace {
 public:
  InferenceTrace(
      const TRITONSERVER_InferenceTraceLevel level, const uint64_t parent_id,
      TRITONSERVER_InferenceTraceActivityFn_t activity_fn,
      TRITONSERVER_InferenceTraceTensorActivityFn_t tensor_activity_fn,
      TRITONSERVER_InferenceTraceReleaseFn_t release_fn, void* userp)
      : level_(level), id_(next_id_++), parent_id_(parent_id),
        activity_fn_(activity_fn), tensor_activity_fn_(tensor_activity_fn),
        release_fn_(release_fn), userp_(userp)
  {
  }

  InferenceTrace* SpawnChildTrace();
  int64_t Id() const { return id_; }
  int64_t ParentId() const { return parent_id_; }
  const std::string& ModelName() const { return model_name_; }
  int64_t ModelVersion() const { return model_version_; }

  void SetModelName(const std::string& n) { model_name_ = n; }
  void SetModelVersion(int64_t v) { model_version_ = v; }
  // Report trace activity.
  void Report(
      const TRITONSERVER_InferenceTraceActivity activity,
      uint64_t timestamp_ns)
  {
    if ((level_ & TRITONSERVER_TRACE_LEVEL_TIMESTAMPS) > 0) {
      activity_fn_(
          reinterpret_cast<TRITONSERVER_InferenceTrace*>(this), activity,
          timestamp_ns, userp_);
    }
  }

  // Report trace activity at the current time.
  void ReportNow(const TRITONSERVER_InferenceTraceActivity activity)
  {
    if ((level_ & TRITONSERVER_TRACE_LEVEL_TIMESTAMPS) > 0) {
      Report(
          activity,
          std::chrono::duration_cast<std::chrono::nanoseconds>(
              std::chrono::steady_clock::now().time_since_epoch())
              .count());
    }
  }
  // Report tensor trace activity.
  void ReportTensor(
      const TRITONSERVER_InferenceTraceActivity activity, const char* name,
      TRITONSERVER_DataType datatype, const void* base, size_t byte_size,
      const int64_t* shape, uint64_t dim_count,
      TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)
  {
    if ((level_ & TRITONSERVER_TRACE_LEVEL_TENSORS) > 0) {
      tensor_activity_fn_(
          reinterpret_cast<TRITONSERVER_InferenceTrace*>(this), activity, name,
          datatype, base, byte_size, shape, dim_count, memory_type,
          memory_type_id, userp_);
    }
  }
  // Release the trace. Call the trace release callback.
  void Release();

 private:
  const TRITONSERVER_InferenceTraceLevel level_;
  const uint64_t id_;
  const uint64_t parent_id_;

  TRITONSERVER_InferenceTraceActivityFn_t activity_fn_;
  TRITONSERVER_InferenceTraceTensorActivityFn_t tensor_activity_fn_;
  TRITONSERVER_InferenceTraceReleaseFn_t release_fn_;
  void* userp_;

  std::string model_name_;
  int64_t model_version_;

  // Maintain next id statically so that trace id is unique even
  // across traces
  static std::atomic<uint64_t> next_id_;
};
//
// InferenceTraceProxy
//
// Object attached as shared_ptr to InferenceRequest and
// InferenceResponse(s) being traced as part of a single inference
// request.
//
class InferenceTraceProxy {
 public:
  InferenceTraceProxy(InferenceTrace* trace) : trace_(trace) {}
  ~InferenceTraceProxy() { trace_->Release(); }

  int64_t Id() const { return trace_->Id(); }
  int64_t ParentId() const { return trace_->ParentId(); }
  const std::string& ModelName() const { return trace_->ModelName(); }
  int64_t ModelVersion() const { return trace_->ModelVersion(); }

  void SetModelName(const std::string& n) { trace_->SetModelName(n); }
  void SetModelVersion(int64_t v) { trace_->SetModelVersion(v); }
  void Report(
      const TRITONSERVER_InferenceTraceActivity activity,
      uint64_t timestamp_ns)
  {
    trace_->Report(activity, timestamp_ns);
  }

  void ReportNow(const TRITONSERVER_InferenceTraceActivity activity)
  {
    trace_->ReportNow(activity);
  }

  void ReportTensor(
      const TRITONSERVER_InferenceTraceActivity activity, const char* name,
      TRITONSERVER_DataType datatype, const void* base, size_t byte_size,
      const int64_t* shape, uint64_t dim_count,
      TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)
  {
    trace_->ReportTensor(
        activity, name, datatype, base, byte_size, shape, dim_count,
        memory_type, memory_type_id);
  }

  std::shared_ptr<InferenceTraceProxy> SpawnChildTrace();

 private:
  InferenceTrace* trace_;
};
#endif // TRITON_ENABLE_TRACING
//
// Macros to generate trace activity
//
#ifdef TRITON_ENABLE_TRACING
#define INFER_TRACE_ACTIVITY(T, A, TS_NS) \
{ \
const auto& trace = (T); \
const auto ts_ns = (TS_NS); \
if (trace != nullptr) { \
trace->Report(A, ts_ns); \
} \
}
#define INFER_TRACE_ACTIVITY_NOW(T, A) \
{ \
const auto& trace = (T); \
if (trace != nullptr) { \
trace->ReportNow(A); \
} \
}
#define INFER_TRACE_TENSOR_ACTIVITY(T, A, N, D, BA, BY, S, DI, MT, MTI) \
{ \
const auto& trace = (T); \
if (trace != nullptr) { \
trace->ReportTensor(A, N, D, BA, BY, S, DI, MT, MTI); \
} \
}
#else
#define INFER_TRACE_ACTIVITY(T, A, TS_NS)
#define INFER_TRACE_ACTIVITY_NOW(T, A)
#define INFER_TRACE_TENSOR_ACTIVITY(T, A, N, D, BA, BY, S, DI, MT, MTI)
#endif // TRITON_ENABLE_TRACING
}}  // namespace triton::core
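The macros above are the usual way call sites report trace points, since they collapse to nothing when tracing is compiled out and tolerate a null trace pointer. A short sketch with assumed names:

// Illustrative only. 'trace' is a std::shared_ptr<InferenceTraceProxy>
// carried by the request being processed; it may be nullptr when the
// request is not traced, which the macros already handle.
INFER_TRACE_ACTIVITY_NOW(trace, TRITONSERVER_TRACE_COMPUTE_START);
// ... run the model ...
INFER_TRACE_ACTIVITY_NOW(trace, TRITONSERVER_TRACE_COMPUTE_END);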
3rdparty/core-r22.12/src/instance_queue.cc
0 → 100644
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "instance_queue.h"
#include "triton/common/logging.h"
namespace triton { namespace core {

InstanceQueue::InstanceQueue(
    size_t max_batch_size, uint64_t max_queue_delay_ns)
    : max_batch_size_(max_batch_size), max_queue_delay_ns_(max_queue_delay_ns)
{
}

size_t
InstanceQueue::Size()
{
  return payload_queue_.size();
}

bool
InstanceQueue::Empty()
{
  return payload_queue_.empty();
}

void
InstanceQueue::Enqueue(const std::shared_ptr<Payload>& payload)
{
  payload_queue_.push_back(payload);
}
void
InstanceQueue::Dequeue(
    std::shared_ptr<Payload>* payload,
    std::vector<std::shared_ptr<Payload>>* merged_payloads)
{
  *payload = payload_queue_.front();
  payload_queue_.pop_front();
  {
    std::lock_guard<std::mutex> exec_lock(*((*payload)->GetExecMutex()));
    (*payload)->SetState(Payload::State::EXECUTING);
    if ((!payload_queue_.empty()) && (max_queue_delay_ns_ > 0) &&
        (max_batch_size_ > 1) && (!(*payload)->IsSaturated())) {
      bool continue_merge;
      do {
        continue_merge = false;
        uint64_t now_ns =
            std::chrono::duration_cast<std::chrono::nanoseconds>(
                std::chrono::steady_clock::now().time_since_epoch())
                .count();
        size_t batch_size = (*payload)->BatchSize();
        if ((!payload_queue_.empty()) &&
            (!payload_queue_.front()->IsSaturated()) &&
            (now_ns - payload_queue_.front()->BatcherStartNs()) >
                max_queue_delay_ns_) {
          std::lock_guard<std::mutex> exec_lock(
              *(payload_queue_.front()->GetExecMutex()));
          payload_queue_.front()->SetState(Payload::State::EXECUTING);
          size_t front_batch_size = payload_queue_.front()->BatchSize();
          if ((batch_size + front_batch_size) <= max_batch_size_) {
            const auto& status =
                (*payload)->MergePayload(payload_queue_.front());
            if (status.IsOk()) {
              merged_payloads->push_back(payload_queue_.front());
              payload_queue_.pop_front();
              continue_merge = true;
            }
          }
        }
      } while (continue_merge);
    }
  }
}

}}  // namespace triton::core
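A hedged sketch of the calling pattern Dequeue() appears to expect: the caller takes one payload to execute and receives any payloads that were merged into it because they had waited longer than max_queue_delay_ns and still fit within max_batch_size. The payload objects themselves come from the scheduler and are assumed here; note that Enqueue()/Dequeue() in the code above do not lock the queue themselves, so the caller is expected to serialize access.

// Illustrative only; not part of this commit.
triton::core::InstanceQueue queue(
    8 /* max_batch_size */, 100000 /* max_queue_delay_ns, 100 us */);
// ... queue.Enqueue(payload) called as payloads become ready ...
if (!queue.Empty()) {
  std::shared_ptr<triton::core::Payload> payload;
  std::vector<std::shared_ptr<triton::core::Payload>> merged_payloads;
  queue.Dequeue(&payload, &merged_payloads);
  // 'payload' now also carries the requests of every entry returned in
  // 'merged_payloads'; the merged entries are handed back for bookkeeping.
}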
3rdparty/core-r22.12/src/instance_queue.h
0 → 100644
// Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "payload.h"
namespace triton { namespace core {

//
// InstanceQueue
//
// A queue implementation holding Payloads ready to be scheduled on a
// model instance.
class InstanceQueue {
 public:
  explicit InstanceQueue(size_t max_batch_size, uint64_t max_queue_delay_ns);

  size_t Size();
  bool Empty();
  void Enqueue(const std::shared_ptr<Payload>& payload);
  void Dequeue(
      std::shared_ptr<Payload>* payload,
      std::vector<std::shared_ptr<Payload>>* merged_payloads);

 private:
  size_t max_batch_size_;
  uint64_t max_queue_delay_ns_;

  std::deque<std::shared_ptr<Payload>> payload_queue_;
  std::shared_ptr<Payload> staged_payload_;
  std::mutex mu_;
};

}}  // namespace triton::core