Commit 502f4fb9 authored by limm

add tools and service module

parent 68661967
# Copyright (c) OpenMMLab. All rights reserved.
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: inference.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0finference.proto\x12\x08mmdeploy\"\x91\x01\n\x05Model\x12\x11\n\x04name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x0f\n\x07weights\x18\x02 \x01(\x0c\x12+\n\x06\x64\x65vice\x18\x03 \x01(\x0e\x32\x16.mmdeploy.Model.DeviceH\x01\x88\x01\x01\"#\n\x06\x44\x65vice\x12\x07\n\x03\x43PU\x10\x00\x12\x07\n\x03GPU\x10\x01\x12\x07\n\x03\x44SP\x10\x02\x42\x07\n\x05_nameB\t\n\x07_device\"\x07\n\x05\x45mpty\"Q\n\x06Tensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x12\n\x05\x64type\x18\x02 \x01(\tH\x00\x88\x01\x01\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\r\n\x05shape\x18\x04 \x03(\x05\x42\x08\n\x06_dtype\",\n\nTensorList\x12\x1e\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x10.mmdeploy.Tensor\"E\n\x05Reply\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0c\n\x04info\x18\x02 \x01(\t\x12\x1e\n\x04\x64\x61ta\x18\x03 \x03(\x0b\x32\x10.mmdeploy.Tensor\"\x16\n\x05Names\x12\r\n\x05names\x18\x01 \x03(\t2\xfb\x01\n\tInference\x12*\n\x04\x45\x63ho\x12\x0f.mmdeploy.Empty\x1a\x0f.mmdeploy.Reply\"\x00\x12*\n\x04Init\x12\x0f.mmdeploy.Model\x1a\x0f.mmdeploy.Reply\"\x00\x12\x31\n\x0bOutputNames\x12\x0f.mmdeploy.Empty\x1a\x0f.mmdeploy.Names\"\x00\x12\x34\n\tInference\x12\x14.mmdeploy.TensorList\x1a\x0f.mmdeploy.Reply\"\x00\x12-\n\x07\x44\x65stroy\x12\x0f.mmdeploy.Empty\x1a\x0f.mmdeploy.Reply\"\x00\x42%\n\rmmdeploy.snpeB\x0bSNPEWrapperP\x01\xa2\x02\x04SNPEb\x06proto3'
)
_MODEL = DESCRIPTOR.message_types_by_name['Model']
_EMPTY = DESCRIPTOR.message_types_by_name['Empty']
_TENSOR = DESCRIPTOR.message_types_by_name['Tensor']
_TENSORLIST = DESCRIPTOR.message_types_by_name['TensorList']
_REPLY = DESCRIPTOR.message_types_by_name['Reply']
_NAMES = DESCRIPTOR.message_types_by_name['Names']
_MODEL_DEVICE = _MODEL.enum_types_by_name['Device']
Model = _reflection.GeneratedProtocolMessageType(
'Model',
(_message.Message, ),
{
'DESCRIPTOR': _MODEL,
'__module__': 'inference_pb2'
# @@protoc_insertion_point(class_scope:mmdeploy.Model)
})
_sym_db.RegisterMessage(Model)
Empty = _reflection.GeneratedProtocolMessageType(
'Empty',
(_message.Message, ),
{
'DESCRIPTOR': _EMPTY,
'__module__': 'inference_pb2'
# @@protoc_insertion_point(class_scope:mmdeploy.Empty)
})
_sym_db.RegisterMessage(Empty)
Tensor = _reflection.GeneratedProtocolMessageType(
'Tensor',
(_message.Message, ),
{
'DESCRIPTOR': _TENSOR,
'__module__': 'inference_pb2'
# @@protoc_insertion_point(class_scope:mmdeploy.Tensor)
})
_sym_db.RegisterMessage(Tensor)
TensorList = _reflection.GeneratedProtocolMessageType(
'TensorList',
(_message.Message, ),
{
'DESCRIPTOR': _TENSORLIST,
'__module__': 'inference_pb2'
# @@protoc_insertion_point(class_scope:mmdeploy.TensorList)
})
_sym_db.RegisterMessage(TensorList)
Reply = _reflection.GeneratedProtocolMessageType(
'Reply',
(_message.Message, ),
{
'DESCRIPTOR': _REPLY,
'__module__': 'inference_pb2'
# @@protoc_insertion_point(class_scope:mmdeploy.Reply)
})
_sym_db.RegisterMessage(Reply)
Names = _reflection.GeneratedProtocolMessageType(
'Names',
(_message.Message, ),
{
'DESCRIPTOR': _NAMES,
'__module__': 'inference_pb2'
# @@protoc_insertion_point(class_scope:mmdeploy.Names)
})
_sym_db.RegisterMessage(Names)
_INFERENCE = DESCRIPTOR.services_by_name['Inference']
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\rmmdeploy.snpeB\013SNPEWrapperP\001\242\002\004SNPE'
_MODEL._serialized_start = 30
_MODEL._serialized_end = 175
_MODEL_DEVICE._serialized_start = 120
_MODEL_DEVICE._serialized_end = 155
_EMPTY._serialized_start = 177
_EMPTY._serialized_end = 184
_TENSOR._serialized_start = 186
_TENSOR._serialized_end = 267
_TENSORLIST._serialized_start = 269
_TENSORLIST._serialized_end = 313
_REPLY._serialized_start = 315
_REPLY._serialized_end = 384
_NAMES._serialized_start = 386
_NAMES._serialized_end = 408
_INFERENCE._serialized_start = 411
_INFERENCE._serialized_end = 662
# @@protoc_insertion_point(module_scope)
# Copyright (c) OpenMMLab. All rights reserved.
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import inference_pb2 as inference__pb2
class InferenceStub(object):
"""The inference service definition."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Echo = channel.unary_unary(
'/mmdeploy.Inference/Echo',
request_serializer=inference__pb2.Empty.SerializeToString,
response_deserializer=inference__pb2.Reply.FromString,
)
self.Init = channel.unary_unary(
'/mmdeploy.Inference/Init',
request_serializer=inference__pb2.Model.SerializeToString,
response_deserializer=inference__pb2.Reply.FromString,
)
self.OutputNames = channel.unary_unary(
'/mmdeploy.Inference/OutputNames',
request_serializer=inference__pb2.Empty.SerializeToString,
response_deserializer=inference__pb2.Names.FromString,
)
self.Inference = channel.unary_unary(
'/mmdeploy.Inference/Inference',
request_serializer=inference__pb2.TensorList.SerializeToString,
response_deserializer=inference__pb2.Reply.FromString,
)
self.Destroy = channel.unary_unary(
'/mmdeploy.Inference/Destroy',
request_serializer=inference__pb2.Empty.SerializeToString,
response_deserializer=inference__pb2.Reply.FromString,
)
class InferenceServicer(object):
"""The inference service definition."""
def Echo(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Init(self, request, context):
"""Init Model with model file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def OutputNames(self, request, context):
"""Get output names."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Inference(self, request, context):
"""Inference with inputs."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Destroy(self, request, context):
"""Destroy handle."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_InferenceServicer_to_server(servicer, server):
rpc_method_handlers = {
'Echo':
grpc.unary_unary_rpc_method_handler(
servicer.Echo,
request_deserializer=inference__pb2.Empty.FromString,
response_serializer=inference__pb2.Reply.SerializeToString,
),
'Init':
grpc.unary_unary_rpc_method_handler(
servicer.Init,
request_deserializer=inference__pb2.Model.FromString,
response_serializer=inference__pb2.Reply.SerializeToString,
),
'OutputNames':
grpc.unary_unary_rpc_method_handler(
servicer.OutputNames,
request_deserializer=inference__pb2.Empty.FromString,
response_serializer=inference__pb2.Names.SerializeToString,
),
'Inference':
grpc.unary_unary_rpc_method_handler(
servicer.Inference,
request_deserializer=inference__pb2.TensorList.FromString,
response_serializer=inference__pb2.Reply.SerializeToString,
),
'Destroy':
grpc.unary_unary_rpc_method_handler(
servicer.Destroy,
request_deserializer=inference__pb2.Empty.FromString,
response_serializer=inference__pb2.Reply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'mmdeploy.Inference', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler, ))
# This class is part of an EXPERIMENTAL API.
class Inference(object):
"""The inference service definition."""
@staticmethod
def Echo(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request, target, '/mmdeploy.Inference/Echo',
inference__pb2.Empty.SerializeToString,
inference__pb2.Reply.FromString, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout,
metadata)
@staticmethod
def Init(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request, target, '/mmdeploy.Inference/Init',
inference__pb2.Model.SerializeToString,
inference__pb2.Reply.FromString, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout,
metadata)
@staticmethod
def OutputNames(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request, target, '/mmdeploy.Inference/OutputNames',
inference__pb2.Empty.SerializeToString,
inference__pb2.Names.FromString, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout,
metadata)
@staticmethod
def Inference(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request, target, '/mmdeploy.Inference/Inference',
inference__pb2.TensorList.SerializeToString,
inference__pb2.Reply.FromString, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout,
metadata)
@staticmethod
def Destroy(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request, target, '/mmdeploy.Inference/Destroy',
inference__pb2.Empty.SerializeToString,
inference__pb2.Reply.FromString, options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout,
metadata)
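# ---- Usage sketch (not part of the generated code) ----
# A minimal client example, assuming the C++ inference_server added in this
# commit is listening on its default port 60000. The target address, the .dlc
# file name and the input tensor name/shape below are placeholders for
# illustration, not values defined anywhere in this change.
import struct

import grpc

import inference_pb2
import inference_pb2_grpc


def run_client(target='127.0.0.1:60000', dlc_path='model.dlc'):
    channel = grpc.insecure_channel(target)
    stub = inference_pb2_grpc.InferenceStub(channel)
    # Health check: Echo takes an Empty and returns a Reply.
    print(stub.Echo(inference_pb2.Empty()).info)
    # Init uploads the SNPE .dlc weights and selects a runtime device.
    with open(dlc_path, 'rb') as f:
        model = inference_pb2.Model(
            name='model', weights=f.read(), device=inference_pb2.Model.CPU)
    reply = stub.Init(model)
    assert reply.status == 0, reply.info
    # Query the model's output tensor names.
    names = stub.OutputNames(inference_pb2.Empty()).names
    # Run inference on a single float32 tensor; the payload is raw float
    # bytes, which is what LoadFloatData on the server side expects.
    data = struct.pack('12f', *([0.0] * 12))
    tensor = inference_pb2.Tensor(
        name='input', dtype='float32', data=data, shape=[1, 3, 2, 2])
    reply = stub.Inference(inference_pb2.TensorList(data=[tensor]))
    print(reply.status, list(names))
    # Release the model handle on the device side.
    stub.Destroy(inference_pb2.Empty())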
syntax = "proto3";
option java_multiple_files = true;
option java_package = "mmdeploy.snpe";
option java_outer_classname = "SNPEWrapper";
option objc_class_prefix = "SNPE";
package mmdeploy;
// The inference service definition.
service Inference {
rpc Echo(Empty) returns (Reply) {}
// Init Model with model file
rpc Init(Model) returns (Reply) {}
// Get output names
rpc OutputNames(Empty) returns (Names) {}
// Inference with inputs
rpc Inference(TensorList) returns (Reply) {}
// Destroy handle
rpc Destroy(Empty) returns (Reply) {}
}
message Model {
optional string name = 1;
// bin
bytes weights = 2;
// config
enum Device {
CPU = 0;
GPU = 1;
DSP = 2;
}
optional Device device = 3;
}
// https://stackoverflow.com/questions/31768665/can-i-define-a-grpc-call-with-a-null-request-or-response
message Empty {}
message Tensor {
// name
string name = 1;
// datatype
optional string dtype = 2;
// data
bytes data = 3;
// shape
repeated int32 shape = 4;
}
message TensorList {
repeated Tensor data = 1;
}
message Reply {
int32 status = 1;
string info = 2;
repeated Tensor data = 3;
}
message Names {
repeated string names = 1;
}
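# The Python modules above (inference_pb2.py / inference_pb2_grpc.py) are the
# stubs generated from this inference.proto. A sketch of regenerating them
# with grpcio-tools, assuming that package is installed and the snippet is run
# from the directory containing inference.proto; the C++ stubs are produced
# instead by the protoc custom command in the CMake file below.
from grpc_tools import protoc

protoc.main([
    'grpc_tools.protoc',
    '-I.',
    '--python_out=.',
    '--grpc_python_out=.',
    'inference.proto',
])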
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cmake build file for the C++ SNPE inference server.
# Assumes protobuf and gRPC have been installed using cmake.
# See cmake_externalproject/CMakeLists.txt for an all-in-one cmake build
# that automatically builds all the dependencies before building the server.
cmake_minimum_required(VERSION 3.5.1)
project(SNPEServer C CXX)
include(./common.cmake)
# Proto file
get_filename_component(hw_proto "../inference.proto" ABSOLUTE)
get_filename_component(hw_proto_path "${hw_proto}" PATH)
# Generated sources
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/inference.pb.cc")
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/inference.pb.h")
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/inference.grpc.pb.cc")
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/inference.grpc.pb.h")
add_custom_command(
OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
COMMAND ${_PROTOBUF_PROTOC}
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
-I "${hw_proto_path}"
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
"${hw_proto}"
DEPENDS "${hw_proto}")
# Include generated *.pb.h files
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
# hw_grpc_proto
add_library(hw_grpc_proto
${hw_grpc_srcs}
${hw_grpc_hdrs}
${hw_proto_srcs}
${hw_proto_hdrs})
target_link_libraries(hw_grpc_proto
${_REFLECTION}
${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF})
add_library(snpe SHARED IMPORTED)
if (NOT EXISTS "$ENV{SNPE_ROOT}/lib/aarch64-android-clang6.0/")
message(FATAL_ERROR "SNPE_ROOT directory does not exist: $ENV{SNPE_ROOT}/lib/aarch64-android-clang6.0/")
endif()
set_target_properties(snpe PROPERTIES
IMPORTED_LOCATION "$ENV{SNPE_ROOT}/lib/aarch64-android-clang6.0/libSNPE.so"
INTERFACE_INCLUDE_DIRECTORIES "$ENV{SNPE_ROOT}/include/zdl"
)
target_link_directories(
snpe
INTERFACE
)
add_executable(inference_server inference_server.cc service_impl.cpp)
target_link_libraries(inference_server
hw_grpc_proto
${_REFLECTION}
${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF}
snpe)
# Copyright 2018 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Common cmake settings for locating protobuf and gRPC, adapted from the
# gRPC C++ examples.
# Assumes protobuf and gRPC have been installed using cmake.
# See cmake_externalproject/CMakeLists.txt for an all-in-one cmake build
# that automatically builds all the dependencies first.
cmake_minimum_required(VERSION 3.5.1)
set (CMAKE_CXX_STANDARD 17)
if(MSVC)
add_definitions(-D_WIN32_WINNT=0x600)
endif()
find_package(Threads REQUIRED)
if(GRPC_AS_SUBMODULE)
# One way to build a project that uses gRPC is to just include the
# entire gRPC project tree via "add_subdirectory".
# This approach is very simple to use, but there are some potential
# disadvantages:
# * it includes gRPC's CMakeLists.txt directly into your build script,
# which can make gRPC's internal settings interfere with your
# own build.
# * depending on what's installed on your system, the contents of submodules
# in gRPC's third_party/* might need to be available (and there might be
# additional prerequisites required to build them). Consider using
# the gRPC_*_PROVIDER options to fine-tune the expected behavior.
#
# A more robust approach to add dependency on gRPC is using
# cmake's ExternalProject_Add (see cmake_externalproject/CMakeLists.txt).
# Include the gRPC's cmake build (normally grpc source code would live
# in a git submodule called "third_party/grpc", but this example lives in
# the same repository as gRPC sources, so we just look a few directories up)
add_subdirectory(../../.. ${CMAKE_CURRENT_BINARY_DIR}/grpc EXCLUDE_FROM_ALL)
message(STATUS "Using gRPC via add_subdirectory.")
# After using add_subdirectory, we can now use the grpc targets directly from
# this build.
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
set(_REFLECTION grpc++_reflection)
if(CMAKE_CROSSCOMPILING)
find_program(_PROTOBUF_PROTOC protoc)
else()
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
endif()
set(_GRPC_GRPCPP grpc++)
if(CMAKE_CROSSCOMPILING)
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
else()
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:grpc_cpp_plugin>)
endif()
elseif(GRPC_FETCHCONTENT)
# Another way is to use CMake's FetchContent module to clone gRPC at
# configure time. This makes gRPC's source code available to your project,
# similar to a git submodule.
message(STATUS "Using gRPC via add_subdirectory (FetchContent).")
include(FetchContent)
FetchContent_Declare(
grpc
GIT_REPOSITORY https://github.com/grpc/grpc.git
# when using gRPC, you will actually set this to an existing tag, such as
# v1.25.0, v1.26.0 etc..
# For the purpose of testing, we override the tag used to the commit
# that's currently under test.
GIT_TAG vGRPC_TAG_VERSION_OF_YOUR_CHOICE)
FetchContent_MakeAvailable(grpc)
# Since FetchContent uses add_subdirectory under the hood, we can use
# the grpc targets directly from this build.
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
set(_REFLECTION grpc++_reflection)
set(_PROTOBUF_PROTOC $<TARGET_FILE:protoc>)
set(_GRPC_GRPCPP grpc++)
if(CMAKE_CROSSCOMPILING)
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
else()
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:grpc_cpp_plugin>)
endif()
else()
# This branch assumes that gRPC and all its dependencies are already installed
# on this system, so they can be located by find_package().
# Find Protobuf installation
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
set(protobuf_MODULE_COMPATIBLE TRUE)
find_package(Protobuf CONFIG REQUIRED)
message(STATUS "Using protobuf ${Protobuf_VERSION}")
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
set(_REFLECTION gRPC::grpc++_reflection)
if(CMAKE_CROSSCOMPILING)
find_program(_PROTOBUF_PROTOC protoc)
else()
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
endif()
# Find gRPC installation
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
find_package(gRPC CONFIG REQUIRED)
message(STATUS "Using gRPC ${gRPC_VERSION}")
set(_GRPC_GRPCPP gRPC::grpc++)
if(CMAKE_CROSSCOMPILING)
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
else()
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
endif()
endif()
/*
*
* Copyright 2015 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Copyright (c) OpenMMLab. All rights reserved.
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <netinet/in.h>
#include <stdio.h>
#include <string.h>
#include <sys/types.h>
#include <iostream>
#include "service_impl.h"
#include "text_table.h"
void PrintIP() {
struct ifaddrs* ifAddrStruct = NULL;
void* tmpAddrPtr = NULL;
int retval = getifaddrs(&ifAddrStruct);
if (retval == -1) {
return;
}
helper::TextTable table("Device");
table.padding(1);
table.add("port").add("ip").eor();
while (ifAddrStruct != nullptr) {
if (ifAddrStruct->ifa_addr == nullptr) {
// skip interfaces without an address instead of stopping the whole listing
ifAddrStruct = ifAddrStruct->ifa_next;
continue;
}
if (ifAddrStruct->ifa_addr->sa_family == AF_INET) {
tmpAddrPtr = &((struct sockaddr_in*)ifAddrStruct->ifa_addr)->sin_addr;
char addressBuffer[INET_ADDRSTRLEN];
inet_ntop(AF_INET, tmpAddrPtr, addressBuffer, INET_ADDRSTRLEN);
table.add(std::string(ifAddrStruct->ifa_name)).add(std::string(addressBuffer)).eor();
} else if (ifAddrStruct->ifa_addr->sa_family == AF_INET6) {
tmpAddrPtr = &((struct sockaddr_in6*)ifAddrStruct->ifa_addr)->sin6_addr;
char addressBuffer[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, tmpAddrPtr, addressBuffer, INET6_ADDRSTRLEN);
table.add(std::string(ifAddrStruct->ifa_name)).add(std::string(addressBuffer)).eor();
}
ifAddrStruct = ifAddrStruct->ifa_next;
}
std::cout << table << std::endl << std::endl;
}
void RunServer(int port = 60000) {
// listen IPv4 and IPv6
char server_address[64] = {0};
sprintf(server_address, "[::]:%d", port);
InferenceServiceImpl service;
grpc::EnableDefaultHealthCheckService(true);
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
ServerBuilder builder;
// Listen on the given address without any authentication mechanism.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
// Allow messages up to 1 GiB (2 << 29 bytes)
builder.SetMaxMessageSize(2 << 29);
builder.SetMaxSendMessageSize(2 << 29);
// Register "service" as the instance through which we'll communicate with
// clients. In this case it corresponds to an *synchronous* service.
builder.RegisterService(&service);
// Finally assemble the server.
std::unique_ptr<Server> server(builder.BuildAndStart());
fprintf(stdout, "Server listening on %s\n", server_address);
// Wait for the server to shutdown. Note that some other thread must be
// responsible for shutting down the server for this call to ever return.
server->Wait();
}
int main(int argc, char** argv) {
int port = 60000;
if (argc > 1) {
port = std::stoi(argv[1]);
}
if (port <= 9999) {
fprintf(stdout, "Usage: %s [port]\n", argv[0]);
return 0;
}
PrintIP();
RunServer(port);
return 0;
}
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <sys/time.h>
#include <cstdio>
#include <memory>
#include <string>
class ScopeTimer {
public:
ScopeTimer(std::string _name, bool _print = true) : name(_name), print(_print) { begin = now(); }
~ScopeTimer() {
if (!print) {
return;
}
fprintf(stdout, "%s: %ldms\n", name.c_str(), (now() - begin));
}
long now() const {
struct timeval tv;
gettimeofday(&tv, NULL);
return tv.tv_sec * 1000 + (tv.tv_usec / 1000);
}
long cost() const { return now() - begin; }
private:
std::string name;
bool print;
long begin;
};
// Copyright (c) OpenMMLab. All rights reserved.
#include "service_impl.h"
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>
#include "scope_timer.h"
#include "text_table.h"
zdl::DlSystem::Runtime_t InferenceServiceImpl::CheckRuntime(zdl::DlSystem::Runtime_t runtime,
bool& staticQuantization) {
static zdl::DlSystem::Version_t Version = zdl::SNPE::SNPEFactory::getLibraryVersion();
fprintf(stdout, "SNPE Version: %s\n", Version.asString().c_str());
if ((runtime != zdl::DlSystem::Runtime_t::DSP) && staticQuantization) {
fprintf(stderr,
"ERROR: Cannot use static quantization with CPU/GPU runtimes. "
"It is only designed for DSP/AIP runtimes.\n"
"ERROR: Proceeding without static quantization on selected "
"runtime.\n");
staticQuantization = false;
}
if (!zdl::SNPE::SNPEFactory::isRuntimeAvailable(runtime)) {
fprintf(stderr, "Selected runtime not present. Falling back to CPU.\n");
runtime = zdl::DlSystem::Runtime_t::CPU;
}
return runtime;
}
void InferenceServiceImpl::Build(std::unique_ptr<zdl::DlContainer::IDlContainer>& container,
zdl::DlSystem::Runtime_t runtime,
zdl::DlSystem::RuntimeList runtimeList,
bool useUserSuppliedBuffers,
zdl::DlSystem::PlatformConfig platformConfig) {
zdl::SNPE::SNPEBuilder snpeBuilder(container.get());
if (runtimeList.empty()) {
runtimeList.add(runtime);
}
snpe = snpeBuilder.setOutputLayers({})
.setRuntimeProcessorOrder(runtimeList)
.setUseUserSuppliedBuffers(useUserSuppliedBuffers)
.setPlatformConfig(platformConfig)
.setExecutionPriorityHint(zdl::DlSystem::ExecutionPriorityHint_t::HIGH)
.setPerformanceProfile(zdl::DlSystem::PerformanceProfile_t::SUSTAINED_HIGH_PERFORMANCE)
.build();
return;
}
void InferenceServiceImpl::SaveDLC(const ::mmdeploy::Model* request, const std::string& filename) {
auto model = request->weights();
fprintf(stdout, "saving file to %s\n", filename.c_str());
std::ofstream fout;
fout.open(filename, std::ios::binary | std::ios::out);
fout.write(model.data(), model.size());
fout.flush();
fout.close();
}
void InferenceServiceImpl::LoadFloatData(const std::string& data, std::vector<float>& vec) {
size_t len = data.size();
assert(len % sizeof(float) == 0);
const char* ptr = data.data();
for (int i = 0; i < len; i += sizeof(float)) {
vec.push_back(*(float*)(ptr + i));
}
}
::grpc::Status InferenceServiceImpl::Echo(::grpc::ServerContext* context,
const ::mmdeploy::Empty* request,
::mmdeploy::Reply* response) {
response->set_info("echo");
return Status::OK;
}
// Logic and data behind the server's behavior.
::grpc::Status InferenceServiceImpl::Init(::grpc::ServerContext* context,
const ::mmdeploy::Model* request,
::mmdeploy::Reply* response) {
zdl::SNPE::SNPEFactory::initializeLogging(zdl::DlSystem::LogLevel_t::LOG_ERROR);
zdl::SNPE::SNPEFactory::setLogLevel(zdl::DlSystem::LogLevel_t::LOG_ERROR);
if (snpe != nullptr) {
snpe.reset();
}
if (container != nullptr) {
container.reset();
}
auto model = request->weights();
container =
zdl::DlContainer::IDlContainer::open(reinterpret_cast<uint8_t*>(model.data()), model.size());
if (container == nullptr) {
fprintf(stdout, "Stage Init: load dlc failed.\n");
response->set_status(-1);
response->set_info(zdl::DlSystem::getLastErrorString());
return Status::OK;
}
fprintf(stdout, "Stage Init: load dlc success.\n");
zdl::DlSystem::Runtime_t runtime = zdl::DlSystem::Runtime_t::GPU;
if (request->has_device()) {
switch (request->device()) {
case mmdeploy::Model_Device_GPU:
runtime = zdl::DlSystem::Runtime_t::GPU;
break;
case mmdeploy::Model_Device_DSP:
runtime = zdl::DlSystem::Runtime_t::DSP;
break;
default:
break;
}
}
if (runtime != zdl::DlSystem::Runtime_t::CPU) {
bool static_quant = false;
runtime = CheckRuntime(runtime, static_quant);
}
zdl::DlSystem::RuntimeList runtimeList;
runtimeList.add(zdl::DlSystem::Runtime_t::CPU);
runtimeList.add(runtime);
zdl::DlSystem::PlatformConfig platformConfig;
{
ScopeTimer timer("build snpe");
Build(container, runtime, runtimeList, false, platformConfig);
}
if (snpe == nullptr) {
response->set_status(-1);
response->set_info(zdl::DlSystem::getLastErrorString());
return Status::OK;
}
// setup logger
auto logger_opt = snpe->getDiagLogInterface();
if (!logger_opt) throw std::runtime_error("SNPE failed to obtain logging interface");
auto logger = *logger_opt;
auto opts = logger->getOptions();
static std::string OutputDir = "./output/";
opts.LogFileDirectory = OutputDir;
if (!logger->setOptions(opts)) {
std::cerr << "Failed to set options" << std::endl;
return Status::OK;
}
if (!logger->start()) {
std::cerr << "Failed to start logger" << std::endl;
return Status::OK;
}
const auto& inputTensorNamesRef = snpe->getInputTensorNames();
const auto& inputTensorNames = *inputTensorNamesRef;
inputTensors.resize(inputTensorNames.size());
for (int i = 0; i < inputTensorNames.size(); ++i) {
const char* pname = inputTensorNames.at(i);
const auto& shape_opt = snpe->getInputDimensions(pname);
const auto& shape = *shape_opt;
fprintf(stdout, "Stage Init: input tensor info:\n");
switch (shape.rank()) {
case 1:
fprintf(stdout, "name: %s, shape: [%ld]\n", pname, shape[0]);
break;
case 2:
fprintf(stdout, "name: %s, shape: [%ld,%ld]\n", pname, shape[0], shape[1]);
break;
case 3:
fprintf(stdout, "name: %s, shape: [%ld,%ld,%ld]\n", pname, shape[0], shape[1], shape[2]);
break;
case 4:
fprintf(stdout, "name: %s, shape: [%ld,%ld,%ld,%ld]\n", pname, shape[0], shape[1], shape[2],
shape[3]);
break;
}
inputTensors[i] = zdl::SNPE::SNPEFactory::getTensorFactory().createTensor(shape);
inputTensorMap.add(pname, inputTensors[i].get());
}
response->set_status(0);
response->set_info("Stage Init: success");
return Status::OK;
}
std::string InferenceServiceImpl::ContentStr(zdl::DlSystem::ITensor* pTensor) {
std::string str;
const size_t N = std::min(5UL, pTensor->getSize());
auto it = pTensor->cbegin();
for (int i = 0; i < N; ++i) {
str += std::to_string(*(it + i));
str += " ";
}
str += "..";
str += std::to_string(*(it + pTensor->getSize() - 1));
return str;
}
std::string InferenceServiceImpl::ShapeStr(zdl::DlSystem::ITensor* pTensor) {
std::string str;
str += "[";
auto shape = pTensor->getShape();
for (int i = 0; i < shape.rank(); ++i) {
str += std::to_string(shape[i]);
str += ",";
}
str += ']';
return str;
}
::grpc::Status InferenceServiceImpl::OutputNames(::grpc::ServerContext* context,
const ::mmdeploy::Empty* request,
::mmdeploy::Names* response) {
const auto& outputTensorNamesRef = snpe->getOutputTensorNames();
const auto& outputTensorNames = *outputTensorNamesRef;
for (int i = 0; i < outputTensorNames.size(); ++i) {
response->add_names(outputTensorNames.at(i));
}
return Status::OK;
}
::grpc::Status InferenceServiceImpl::Inference(::grpc::ServerContext* context,
const ::mmdeploy::TensorList* request,
::mmdeploy::Reply* response) {
// Get input names and number
const auto& inputTensorNamesRef = snpe->getInputTensorNames();
if (!inputTensorNamesRef) {
response->set_status(-1);
response->set_info(zdl::DlSystem::getLastErrorString());
return Status::OK;
}
const auto& inputTensorNames = *inputTensorNamesRef;
if (inputTensorNames.size() != request->data_size()) {
response->set_status(-1);
response->set_info("Stage Inference: input names count not match !");
return Status::OK;
}
helper::TextTable table("Inference");
table.padding(1);
table.add("type").add("name").add("shape").add("content").eor();
// Load input/output buffers with TensorMap
{
// ScopeTimer timer("convert input");
for (int i = 0; i < request->data_size(); ++i) {
auto tensor = request->data(i);
std::vector<float> float_input;
LoadFloatData(tensor.data(), float_input);
zdl::DlSystem::ITensor* ptensor = inputTensorMap.getTensor(tensor.name().c_str());
if (ptensor == nullptr) {
fprintf(stderr, "Stage Inference: name: %s not existed in input tensor map\n",
tensor.name().c_str());
response->set_status(-1);
response->set_info("cannot find name in input tensor map.");
return Status::OK;
}
if (float_input.size() != ptensor->getSize()) {
fprintf(stderr, "Stage Inference: input size not match, get %ld, expect %ld.\n",
float_input.size(), ptensor->getSize());
response->set_status(-1);
response->set_info(zdl::DlSystem::getLastErrorString());
return Status::OK;
}
std::copy(float_input.begin(), float_input.end(), ptensor->begin());
table.add("IN").add(tensor.name()).add(ShapeStr(ptensor)).add(ContentStr(ptensor)).eor();
}
}
// A tensor map for SNPE execution outputs
zdl::DlSystem::TensorMap outputTensorMap;
// Execute the multiple input tensorMap on the model with SNPE
bool success = false;
{
ScopeTimer timer("execute", false);
success = snpe->execute(inputTensorMap, outputTensorMap);
if (!success) {
response->set_status(-1);
response->set_info(zdl::DlSystem::getLastErrorString());
return Status::OK;
}
table.add("EXECUTE").add(std::to_string(timer.cost()) + "ms").eor();
}
{
// ScopeTimer timer("convert output");
auto out_names = outputTensorMap.getTensorNames();
for (size_t i = 0; i < out_names.size(); ++i) {
const char* name = out_names.at(i);
zdl::DlSystem::ITensor* ptensor = outputTensorMap.getTensor(name);
table.add("OUT").add(std::string(name)).add(ShapeStr(ptensor)).add(ContentStr(ptensor)).eor();
const size_t data_length = ptensor->getSize();
std::string result;
result.resize(sizeof(float) * data_length);
int j = 0;
for (auto it = ptensor->cbegin(); it != ptensor->cend(); ++it, j += sizeof(float)) {
float f = *it;
memcpy(&result[0] + j, reinterpret_cast<char*>(&f), sizeof(float));
}
auto shape = ptensor->getShape();
::mmdeploy::Tensor* pData = response->add_data();
pData->set_dtype("float32");
pData->set_name(name);
pData->set_data(result);
for (int j = 0; j < shape.rank(); ++j) {
pData->add_shape(shape[j]);
}
}
}
std::cout << table << std::endl << std::endl;
// build output status
response->set_status(0);
response->set_info("Stage Inference: success");
return Status::OK;
}
::grpc::Status InferenceServiceImpl::Destroy(::grpc::ServerContext* context,
const ::mmdeploy::Empty* request,
::mmdeploy::Reply* response) {
snpe.reset();
container.reset();
inputTensors.clear();
response->set_status(0);
zdl::SNPE::SNPEFactory::terminateLogging();
return Status::OK;
}
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef SERVICE_IMPL_H
#define SERVICE_IMPL_H
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <iostream>
#include <memory>
#include <string>
#include "DiagLog/IDiagLog.hpp"
#include "DlContainer/IDlContainer.hpp"
#include "DlSystem/DlEnums.hpp"
#include "DlSystem/DlError.hpp"
#include "DlSystem/ITensorFactory.hpp"
#include "DlSystem/IUserBuffer.hpp"
#include "DlSystem/PlatformConfig.hpp"
#include "DlSystem/RuntimeList.hpp"
#include "DlSystem/UserBufferMap.hpp"
#include "SNPE/SNPE.hpp"
#include "SNPE/SNPEBuilder.hpp"
#include "SNPE/SNPEFactory.hpp"
#include "inference.grpc.pb.h"
using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::Status;
using mmdeploy::Empty;
using mmdeploy::Inference;
using mmdeploy::Model;
using mmdeploy::Reply;
using mmdeploy::Tensor;
using mmdeploy::TensorList;
// Logic and data behind the server's behavior.
class InferenceServiceImpl final : public Inference::Service {
::grpc::Status Echo(::grpc::ServerContext* context, const ::mmdeploy::Empty* request,
::mmdeploy::Reply* response) override;
// Init Model with model file
::grpc::Status Init(::grpc::ServerContext* context, const ::mmdeploy::Model* request,
::mmdeploy::Reply* response) override;
// Get output names
::grpc::Status OutputNames(::grpc::ServerContext* context, const ::mmdeploy::Empty* request,
::mmdeploy::Names* response) override;
// Inference with inputs
::grpc::Status Inference(::grpc::ServerContext* context, const ::mmdeploy::TensorList* request,
::mmdeploy::Reply* response) override;
// Destroy handle
::grpc::Status Destroy(::grpc::ServerContext* context, const ::mmdeploy::Empty* request,
::mmdeploy::Reply* response) override;
void SaveDLC(const ::mmdeploy::Model* request, const std::string& name);
void LoadFloatData(const std::string& data, std::vector<float>& vec);
zdl::DlSystem::Runtime_t CheckRuntime(zdl::DlSystem::Runtime_t runtime, bool& staticQuantization);
void Build(std::unique_ptr<zdl::DlContainer::IDlContainer>& container,
zdl::DlSystem::Runtime_t runtime, zdl::DlSystem::RuntimeList runtimeList,
bool useUserSuppliedBuffers, zdl::DlSystem::PlatformConfig platformConfig);
std::string ShapeStr(zdl::DlSystem::ITensor* pTensor);
std::string ContentStr(zdl::DlSystem::ITensor* pTensor);
std::unique_ptr<zdl::SNPE::SNPE> snpe;
std::unique_ptr<zdl::DlContainer::IDlContainer> container;
std::vector<std::unique_ptr<zdl::DlSystem::ITensor>> inputTensors;
zdl::DlSystem::TensorMap inputTensorMap;
};
#endif
/**
* \file sdk/load-and-run/src/text_table.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <array>
#include <iomanip>
#include <ostream>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>
namespace helper {
class TextTable {
public:
enum Level { Summary, Detail };
enum class Align : int { Left, Right, Mid };
TextTable() = default;
explicit TextTable(const std::string& table_name) : m_name(table_name) {}
TextTable& horizontal(char c) {
m_row.params.horizontal = c;
return *this;
}
TextTable& vertical(char c) {
m_row.params.vertical = c;
return *this;
}
TextTable& corner(char c) {
m_row.params.corner = c;
return *this;
}
TextTable& align(Align v) {
m_row.params.align = v;
return *this;
}
TextTable& padding(size_t w) {
m_padding = w;
return *this;
}
TextTable& prefix(const std::string& str) {
m_prefix = str;
return *this;
}
template <typename T>
TextTable& add(const T& value) {
m_row.values.emplace_back(value);
if (m_cols_max_w.size() < m_row.values.size()) {
m_cols_max_w.emplace_back(m_row.values.back().length());
} else {
size_t i = m_row.values.size() - 1;
m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length());
}
return *this;
}
template <typename T, typename std::enable_if<std::is_floating_point<T>::value, bool>::type = 0>
TextTable& add(const T& value) {
std::stringstream ss;
ss << std::setiosflags(std::ios::fixed) << std::setprecision(2);
ss << value;
m_row.values.emplace_back(ss.str());
if (m_cols_max_w.size() < m_row.values.size()) {
m_cols_max_w.emplace_back(m_row.values.back().length());
} else {
size_t i = m_row.values.size() - 1;
m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length());
}
return *this;
}
template <typename T, typename std::enable_if<std::is_integral<T>::value, bool>::type = 0>
TextTable& add(const T& value) {
m_row.values.emplace_back(std::to_string(value));
return *this;
}
void eor() {
m_rows.emplace_back(m_row);
adjuster_last_row();
m_row.values.clear();
}
void reset() {
m_row = {};
m_cols_max_w.clear();
m_padding = 0;
m_rows.clear();
}
void show(std::ostream& os) {
if (m_rows.empty()) return;
auto& last_row = m_rows.front();
bool first = true;
for (auto& row : m_rows) {
auto& lrow = (last_row.values.size() * char_length(last_row.params.horizontal)) >
(row.values.size() * char_length(row.params.horizontal))
? last_row
: row;
// line before row
if (lrow.params.horizontal) {
if (not first) os << std::endl;
os << m_prefix;
if (lrow.params.corner) os << lrow.params.corner;
size_t skip_size = 0;
// table name
if (first) {
os << m_name;
skip_size = m_name.length();
}
for (size_t i = 0; i < lrow.values.size(); ++i) {
auto max_w = m_cols_max_w.at(i) + m_padding * 2;
if (max_w + char_length(lrow.params.corner) <= skip_size) {
skip_size = skip_size - max_w - char_length(lrow.params.corner);
continue;
}
size_t rest = max_w + char_length(lrow.params.corner) - skip_size;
skip_size = 0;
if (rest > char_length(lrow.params.corner)) {
os << std::string(rest - char_length(lrow.params.corner), lrow.params.horizontal);
rest = char_length(lrow.params.corner);
}
if (rest > 0 && lrow.params.corner) os << lrow.params.corner;
}
} else if (first) {
os << m_prefix << ' ' << m_name;
}
first = false;
os << std::endl << m_prefix;
if (row.params.vertical) os << row.params.vertical;
// row
for (size_t i = 0; i < row.values.size(); ++i) {
auto& str = row.values.at(i);
auto max_w = m_cols_max_w.at(i) + 2 * m_padding;
if (row.params.align == Align::Mid) {
mid(os, str, max_w);
} else if (row.params.align == Align::Left) {
os << std::setw(max_w) << std::left << str;
} else {
os << std::setw(max_w) << std::right << str;
}
if (row.params.vertical) os << row.params.vertical;
}
last_row = row;
}
if (last_row.params.horizontal) {
os << std::endl << m_prefix;
if (last_row.params.corner) os << last_row.params.corner;
for (size_t i = 0; i < last_row.values.size(); ++i) {
auto max_w = m_cols_max_w.at(i);
std::string tmp(max_w + m_padding * 2, last_row.params.horizontal);
os << tmp;
if (last_row.params.corner) os << last_row.params.corner;
}
}
}
private:
void adjuster_last_row() {
if (m_rows.empty()) return;
auto& row = m_rows.back();
if (row.params.horizontal == 0 or row.params.vertical == 0) {
row.params.corner = 0;
}
if (row.params.horizontal != 0 && row.params.vertical != 0 && row.params.corner == 0) {
row.params.corner = row.params.horizontal;
}
}
inline void mid(std::ostream& os, const std::string& str, size_t max_w) {
size_t l = (max_w - str.length()) / 2 + str.length();
size_t r = max_w - l;
os << std::setw(l) << std::right << str;
if (r > 0) os << std::setw(r) << ' ';
}
inline size_t char_length(char c) { return c ? 1 : 0; }
std::string m_name;
std::vector<size_t> m_cols_max_w;
size_t m_padding = 0;
std::string m_prefix = "";
struct Row {
std::vector<std::string> values;
struct Params {
Align align = Align::Left;
char horizontal = '-', vertical = '|', corner = '+';
} params;
};
std::vector<Row> m_rows;
Row m_row;
};
inline std::ostream& operator<<(std::ostream& stream, TextTable& table) {
table.show(stream);
return stream;
}
} // namespace helper
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import collect_env as collect_base_env
from mmengine.utils import get_git_hash
import mmdeploy
from mmdeploy.utils import get_codebase_version, get_root_logger
def collect_env():
"""Collect the information of the running environments."""
env_info = collect_base_env()
env_info['MMDeploy'] = f'{mmdeploy.__version__}+{get_git_hash()[:7]}'
return env_info
def check_backend():
from mmdeploy.backend.base import get_backend_manager
from mmdeploy.utils import Backend
exclude_backend_lists = [Backend.DEFAULT, Backend.PYTORCH, Backend.SDK]
backend_lists = [
backend for backend in Backend if backend not in exclude_backend_lists
]
for backend in backend_lists:
backend_mgr = get_backend_manager(backend.value)
backend_mgr.check_env(logger.info)
def check_codebase():
codebase_versions = get_codebase_version()
for k, v in codebase_versions.items():
logger.info(f'{k}:\t{v}')
if __name__ == '__main__':
logger = get_root_logger()
logger.info('\n')
logger.info('**********Environmental information**********')
for name, val in collect_env().items():
logger.info('{}: {}'.format(name, val))
logger.info('\n')
logger.info('**********Backend information**********')
check_backend()
logger.info('\n')
logger.info('**********Codebase information**********')
check_codebase()
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp
from functools import partial
import mmengine
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method
from mmdeploy.apis import (create_calib_input_data, extract_model,
get_predefined_partition_cfg, torch2onnx,
torch2torchscript, visualize_model)
from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.apis.utils import to_backend
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
get_ir_config, get_partition_config,
get_root_logger, load_config, target_wrapper)
def parse_args():
parser = argparse.ArgumentParser(description='Export model to backends.')
parser.add_argument('deploy_cfg', help='deploy config path')
parser.add_argument('model_cfg', help='model config path')
parser.add_argument('checkpoint', help='model checkpoint path')
parser.add_argument('img', help='image used to convert model')
parser.add_argument(
'--test-img',
default=None,
type=str,
nargs='+',
help='image used to test model')
parser.add_argument(
'--work-dir',
default=os.getcwd(),
help='the dir to save logs and models')
parser.add_argument(
'--calib-dataset-cfg',
help='dataset config path used to calibrate in int8 mode. If not \
specified, it will use "val" dataset in model config instead.',
default=None)
parser.add_argument(
'--device', help='device used for conversion', default='cpu')
parser.add_argument(
'--log-level',
help='set log level',
default='INFO',
choices=list(logging._nameToLevel.keys()))
parser.add_argument(
'--show', action='store_true', help='Show detection outputs')
parser.add_argument(
'--dump-info', action='store_true', help='Output information for SDK')
parser.add_argument(
'--quant-image-dir',
default=None,
help='Image directory for quantize model.')
parser.add_argument(
'--quant', action='store_true', help='Quantize model to low bit.')
parser.add_argument(
'--uri',
default='192.168.1.1:60000',
help='Remote ipv4:port or ipv6:port for inference on edge device.')
args = parser.parse_args()
return args
def create_process(name, target, args, kwargs, ret_value=None):
logger = get_root_logger()
logger.info(f'{name} start.')
log_level = logger.level
wrap_func = partial(target_wrapper, target, log_level, ret_value)
process = Process(target=wrap_func, args=args, kwargs=kwargs)
process.start()
process.join()
if ret_value is not None:
if ret_value.value != 0:
logger.error(f'{name} failed.')
exit(1)
else:
logger.info(f'{name} success.')
def torch2ir(ir_type: IR):
"""Return the conversion function from torch to the intermediate
representation.
Args:
ir_type (IR): The type of the intermediate representation.
"""
if ir_type == IR.ONNX:
return torch2onnx
elif ir_type == IR.TORCHSCRIPT:
return torch2torchscript
else:
raise KeyError(f'Unexpected IR type {ir_type}')
def main():
args = parse_args()
set_start_method('spawn', force=True)
logger = get_root_logger()
log_level = logging.getLevelName(args.log_level)
logger.setLevel(log_level)
pipeline_funcs = [
torch2onnx, torch2torchscript, extract_model, create_calib_input_data
]
PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)
deploy_cfg_path = args.deploy_cfg
model_cfg_path = args.model_cfg
checkpoint_path = args.checkpoint
quant = args.quant
quant_image_dir = args.quant_image_dir
# load deploy_cfg
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
# create work_dir if not
mmengine.mkdir_or_exist(osp.abspath(args.work_dir))
if args.dump_info:
export2SDK(
deploy_cfg,
model_cfg,
args.work_dir,
pth=checkpoint_path,
device=args.device)
ret_value = mp.Value('d', 0, lock=False)
# convert to IR
ir_config = get_ir_config(deploy_cfg)
ir_save_file = ir_config['save_file']
ir_type = IR.get(ir_config['type'])
torch2ir(ir_type)(
args.img,
args.work_dir,
ir_save_file,
deploy_cfg_path,
model_cfg_path,
checkpoint_path,
device=args.device)
# convert backend
ir_files = [osp.join(args.work_dir, ir_save_file)]
# partition model
partition_cfgs = get_partition_config(deploy_cfg)
if partition_cfgs is not None:
if 'partition_cfg' in partition_cfgs:
partition_cfgs = partition_cfgs.get('partition_cfg', None)
else:
assert 'type' in partition_cfgs
partition_cfgs = get_predefined_partition_cfg(
deploy_cfg, partition_cfgs['type'])
origin_ir_file = ir_files[0]
ir_files = []
for partition_cfg in partition_cfgs:
save_file = partition_cfg['save_file']
save_path = osp.join(args.work_dir, save_file)
start = partition_cfg['start']
end = partition_cfg['end']
dynamic_axes = partition_cfg.get('dynamic_axes', None)
extract_model(
origin_ir_file,
start,
end,
dynamic_axes=dynamic_axes,
save_file=save_path)
ir_files.append(save_path)
# calib data
calib_filename = get_calib_filename(deploy_cfg)
if calib_filename is not None:
calib_path = osp.join(args.work_dir, calib_filename)
create_calib_input_data(
calib_path,
deploy_cfg_path,
model_cfg_path,
checkpoint_path,
dataset_cfg=args.calib_dataset_cfg,
dataset_type='val',
device=args.device)
backend_files = ir_files
# convert backend
backend = get_backend(deploy_cfg)
# preprocess deploy_cfg
if backend == Backend.RKNN:
# TODO: Add this to task_processor in the future
import tempfile
from mmdeploy.utils import (get_common_config, get_normalization,
get_quantization_config,
get_rknn_quantization)
quantization_cfg = get_quantization_config(deploy_cfg)
common_params = get_common_config(deploy_cfg)
if get_rknn_quantization(deploy_cfg) is True:
transform = get_normalization(model_cfg)
common_params.update(
dict(
mean_values=[transform['mean']],
std_values=[transform['std']]))
dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name
with open(dataset_file, 'w') as f:
f.writelines([osp.abspath(args.img)])
if quantization_cfg.get('dataset', None) is None:
quantization_cfg['dataset'] = dataset_file
if backend == Backend.ASCEND:
# TODO: Add this to backend manager in the future
if args.dump_info:
from mmdeploy.backend.ascend import update_sdk_pipeline
update_sdk_pipeline(args.work_dir)
if backend == Backend.VACC:
# TODO: Add this to task_processor in the future
from onnx2vacc_quant_dataset import get_quant
from mmdeploy.utils import get_model_inputs
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
model_inputs = get_model_inputs(deploy_cfg)
for onnx_path, model_input in zip(ir_files, model_inputs):
quant_mode = model_input.get('qconfig', {}).get('dtype', 'fp16')
assert quant_mode in ['int8',
'fp16'], quant_mode + ' not support now'
shape_dict = model_input.get('shape', {})
if quant_mode == 'int8':
create_process(
'vacc quant dataset',
target=get_quant,
args=(deploy_cfg, model_cfg, shape_dict, checkpoint_path,
args.work_dir, args.device),
kwargs=dict(),
ret_value=ret_value)
# convert to backend
PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
if backend == Backend.TENSORRT:
PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
backend_files = to_backend(
backend,
ir_files,
work_dir=args.work_dir,
deploy_cfg=deploy_cfg,
log_level=log_level,
device=args.device,
uri=args.uri)
# ncnn quantization
if backend == Backend.NCNN and quant:
from onnx2ncnn_quant_table import get_table
from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
model_param_paths = backend_files[::2]
model_bin_paths = backend_files[1::2]
backend_files = []
for onnx_path, model_param_path, model_bin_path in zip(
ir_files, model_param_paths, model_bin_paths):
deploy_cfg, model_cfg = load_config(deploy_cfg_path,
model_cfg_path)
quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file( # noqa: E501
onnx_path, args.work_dir)
create_process(
'ncnn quant table',
target=get_table,
args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
quant_table, quant_image_dir, args.device),
kwargs=dict(),
ret_value=ret_value)
create_process(
'ncnn_int8',
target=ncnn2int8,
args=(model_param_path, model_bin_path, quant_table,
quant_param, quant_bin),
kwargs=dict(),
ret_value=ret_value)
backend_files += [quant_param, quant_bin]
if args.test_img is None:
args.test_img = args.img
extra = dict(
backend=backend,
output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
show_result=args.show)
if backend == Backend.SNPE:
extra['uri'] = args.uri
# get backend inference result, try render
create_process(
f'visualize {backend.value} model',
target=visualize_model,
args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
args.device),
kwargs=extra,
ret_value=ret_value)
# get pytorch model inference result, try visualize if possible
create_process(
'visualize pytorch model',
target=visualize_model,
args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
args.test_img, args.device),
kwargs=dict(
backend=Backend.PYTORCH,
output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
show_result=args.show),
ret_value=ret_value)
logger.info('All processes succeeded.')
if __name__ == '__main__':
main()
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
import argparse
import os
import os.path as osp
import pathlib
import shutil
import subprocess
from glob import glob
import mmengine
import yaml
from mmdeploy.backend.sdk.export_info import (get_preprocess,
get_transform_static)
from mmdeploy.utils import get_root_logger, load_config
print(pathlib.Path(__file__).resolve())
MMDEPLOY_PATH = pathlib.Path(__file__).parent.parent.parent.resolve()
ELENA_BIN = 'OpFuse'
logger = get_root_logger()
CODEBASE = [
'mmpretrain', 'mmdetection', 'mmpose', 'mmrotate', 'mmocr',
'mmsegmentation', 'mmagic'
]
DEPLOY_CFG = {
'Image Classification': 'configs/mmpretrain/classification_tensorrt_dynamic-224x224-224x224.py',
'Object Detection': 'configs/mmdet/detection/detection_tensorrt_static-800x1344.py',
'Instance Segmentation': 'configs/mmdet/instance-seg/instance-seg_tensorrt_static-800x1344.py',
'Semantic Segmentation': 'configs/mmseg/segmentation_tensorrt_static-512x512.py',
'Oriented Object Detection': 'configs/mmrotate/rotated-detection_tensorrt-fp16_dynamic-320x320-1024x1024.py',
'Text Recognition': 'configs/mmocr/text-recognition/text-recognition_tensorrt_static-32x32.py',
'Text Detection': 'configs/mmocr/text-detection/text-detection_tensorrt_static-512x512.py',
'Restorers': 'configs/mmagic/super-resolution/super-resolution_tensorrt_static-256x256.py'
} # yapf: disable
INFO = {
'cpu':
'''
using std::string;
void FuseFunc(void* stream, uint8_t* data_in, int src_h, int src_w, const char* format,
int resize_h, int resize_w, const char* interpolation, int crop_top, int crop_left,
int crop_h, int crop_w, float mean0, float mean1, float mean2, float std0, float std1,
float std2, int pad_top, int pad_left, int pad_bottom, int pad_right, int pad_h,
int pad_w, float pad_value, float* data_out, int dst_h, int dst_w) {
const char* interpolation_ = "nearest";
if (strcmp(interpolation, "bilinear") == 0) {
interpolation_ = "bilinear";
}
FuseKernel(resize_h, resize_w, crop_h, crop_w, crop_top, crop_left, mean0, mean1, mean2, std0, std1, std2,
pad_h, pad_w, pad_top, pad_left, pad_bottom, pad_right, pad_value, data_in, data_out,
src_h, src_w, format, interpolation_);
}
REGISTER_FUSE_KERNEL(#TAG#_cpu, "#TAG#_cpu",
FuseFunc);
''',
'cuda':
'''
void FuseFunc(void* stream, uint8_t* data_in, int src_h, int src_w, const char* format,
int resize_h, int resize_w, const char* interpolation, int crop_top, int crop_left,
int crop_h, int crop_w, float mean0, float mean1, float mean2, float std0, float std1,
float std2, int pad_top, int pad_left, int pad_bottom, int pad_right, int pad_h,
int pad_w, float pad_value, float* data_out, int dst_h, int dst_w) {
cudaStream_t stream_ = (cudaStream_t)stream;
const char* interpolation_ = "nearest";
if (strcmp(interpolation, "bilinear") == 0) {
interpolation_ = "bilinear";
}
FuseKernelCU(stream_, resize_h, resize_w, crop_h, crop_w, crop_top, crop_left, mean0, mean1, mean2, std0,
std1, std2, pad_h, pad_w, pad_top, pad_left, pad_bottom, pad_right, pad_value, data_in,
data_out, dst_h, dst_w, src_h, src_w, format, interpolation_);
}
REGISTER_FUSE_KERNEL(#TAG#_cuda, "#TAG#_cuda",
FuseFunc);
'''
}
def parse_args():
parser = argparse.ArgumentParser(description='Extract transform.')
parser.add_argument(
'root_path', help='parent path to the codebases (mmdetection, for example)')
args = parser.parse_args()
return args
def append_info(device, tag):
info = INFO[device]
info = info.replace('#TAG#', tag)
src_file = 'source.c' if device == 'cpu' else 'source.cu'
nsp = f'namespace {device}_{tag}' + ' {\n'
with open(src_file, 'r', encoding='utf-8') as f:
data = f.readlines()
for i, line in enumerate(data):
if '_Kernel' in line or '__device__' in line:
data.insert(i, nsp)
data.insert(i, '#include "elena_registry.h"\n')
break
for i, line in enumerate(data):
data[i] = line.replace('extern "C"', '')
data.append(info)
data.append('}')
with open(src_file, 'w', encoding='utf-8') as f:
for line in data:
f.write(line)
def generate_source_code(preprocess, transform_static, tag, args):
kernel_base_dir = osp.join(MMDEPLOY_PATH, 'csrc', 'mmdeploy', 'preprocess',
'elena')
cpu_work_dir = osp.join(kernel_base_dir, 'cpu_kernel')
cuda_work_dir = osp.join(kernel_base_dir, 'cuda_kernel')
dst_cpu_kernel_file = osp.join(cpu_work_dir, f'{tag}.cpp')
dst_cuda_kernel_file = osp.join(cuda_work_dir, f'{tag}.cu')
dst_cpu_elena_header_file = osp.join(cpu_work_dir, 'elena_int.h')
dst_cuda_elena_header_file = osp.join(cuda_work_dir, 'elena_int.h')
json_work_dir = osp.join(kernel_base_dir, 'json')
preprocess_json_path = osp.join(json_work_dir, f'{tag}_preprocess.json')
static_json_path = osp.join(json_work_dir, f'{tag}_static.json')
if osp.exists(preprocess_json_path):
return
mmengine.dump(preprocess, preprocess_json_path, sort_keys=False, indent=4)
mmengine.dump(
transform_static, static_json_path, sort_keys=False, indent=4)
gen_cpu_cmd = f'{ELENA_BIN} {static_json_path} cpu'
res = subprocess.run(gen_cpu_cmd, shell=True)
if res.returncode == 0:
append_info('cpu', tag)
shutil.copyfile('source.c', dst_cpu_kernel_file)
shutil.copyfile('elena_int.h', dst_cpu_elena_header_file)
os.remove('source.c')
gen_cuda_cmd = f'{ELENA_BIN} {static_json_path} cuda'
res = subprocess.run(gen_cuda_cmd, shell=True)
if res.returncode == 0:
append_info('cuda', tag)
shutil.copyfile('source.cu', dst_cuda_kernel_file)
shutil.copyfile('elena_int.h', dst_cuda_elena_header_file)
os.remove('source.cu')
os.remove('elena_int.h')
def extract_one_model(deploy_cfg_, model_cfg_, args):
deploy_cfg, model_cfg = load_config(deploy_cfg_, model_cfg_)
preprocess = get_preprocess(deploy_cfg, model_cfg, 'cuda')
preprocess['model_cfg'] = model_cfg_
transform_static, tag = get_transform_static(preprocess['transforms'])
if tag is not None:
generate_source_code(preprocess, transform_static, tag, args)
def extract_one_metafile(metafile, codebase, args):
with open(metafile, encoding='utf-8') as f:
yaml_info = yaml.load(f, Loader=yaml.FullLoader)
known_task = list(DEPLOY_CFG.keys())
for model in yaml_info['Models']:
try:
cfg = model['Config']
task_name = model['Results'][0]['Task']
if task_name not in known_task:
continue
deploy_cfg = osp.join(MMDEPLOY_PATH, DEPLOY_CFG[task_name])
model_cfg = osp.join(args.root_path, codebase, cfg)
extract_one_model(deploy_cfg, model_cfg, args)
except Exception:
pass
def main():
args = parse_args()
global ELENA_BIN
elena_path = osp.abspath(
os.path.join(MMDEPLOY_PATH, 'third_party', 'CVFusion', 'build',
'examples', 'MMDeploy', 'OpFuse'))
if osp.exists(elena_path):
ELENA_BIN = elena_path
for cb in CODEBASE:
if not os.path.exists(osp.join(args.root_path, cb)):
logger.warning(f'skip codebase {cb} because it does not exist.')
continue
metafile_pattern = osp.join(args.root_path, cb, 'configs', '**/*.yml')
metafiles = glob(metafile_pattern, recursive=True)
for metafile in metafiles:
extract_one_metafile(metafile, cb, args)
if __name__ == '__main__':
main()
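# Example invocation of this extraction tool (the script name below is
# illustrative; root_path is the directory holding the OpenMMLab codebases,
# e.g. mmdetection, cloned side by side):
#   python extract_transform.py /path/to/openmmlab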
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os.path as osp
import onnx
import onnx.helper
from mmdeploy.apis.onnx import extract_partition
from mmdeploy.utils import get_root_logger
def parse_args():
parser = argparse.ArgumentParser(
description='Extract model based on markers.')
parser.add_argument('input_model', help='Input ONNX model')
parser.add_argument('output_model', help='Output ONNX model')
parser.add_argument(
'--start',
help='Start markers, format: func:type, e.g. backbone:input')
parser.add_argument('--end', help='End markers')
parser.add_argument(
'--log-level',
help='set log level',
default='INFO',
choices=list(logging._nameToLevel.keys()))
args = parser.parse_args()
args.start = args.start.split(',') if args.start else []
args.end = args.end.split(',') if args.end else []
return args
def collect_available_marks(model):
marks = []
for node in model.graph.node:
if node.op_type == 'Mark':
for attr in node.attribute:
if attr.name == 'func':
func = str(onnx.helper.get_attribute_value(attr), 'utf-8')
if func not in marks:
marks.append(func)
return marks
def main():
args = parse_args()
logger = get_root_logger(log_level=args.log_level)
model = onnx.load(args.input_model)
marks = collect_available_marks(model)
logger.info('Available marks:\n {}'.format('\n '.join(marks)))
extracted_model = extract_partition(model, args.start, args.end)
if osp.splitext(args.output_model)[-1] != '.onnx':
args.output_model += '.onnx'
onnx.save(extracted_model, args.output_model)
if __name__ == '__main__':
main()
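# Example usage (the script name is illustrative; the markers must already
# be present in the ONNX graph exported by mmdeploy):
#   python extract_partition.py end2end.onnx backbone.onnx \
#       --start backbone:input --end backbone:output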
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import yaml
from mmengine import Config
from mmdeploy.utils import get_backend, get_task_type, load_config
def parse_args():
parser = argparse.ArgumentParser(
description='Export a markdown support table from a yaml config.')
parser.add_argument('yml_file', help='input yml config path')
parser.add_argument('output', help='output markdown file path')
parser.add_argument(
'--backends',
nargs='+',
help='backends you want to generate',
default=[
'onnxruntime', 'tensorrt', 'torchscript', 'pplnn', 'openvino',
'ncnn'
])
args = parser.parse_args()
return args
def main():
args = parse_args()
assert osp.exists(args.yml_file), f'File not exists: {args.yml_file}'
output_dir, _ = osp.split(args.output)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
header = ['model', 'task'] + args.backends
aligner = [':--'] * 2 + [':--:'] * len(args.backends)
def write_row_f(writer, row):
writer.write('|' + '|'.join(row) + '|\n')
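# With the default backends, the emitted markdown table starts like:
# |model|task|onnxruntime|tensorrt|torchscript|pplnn|openvino|ncnn|
# |:--|:--|:--:|:--:|:--:|:--:|:--:|:--:|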
print(f'Processing {args.yml_file}')
with open(args.yml_file, 'r') as reader, open(args.output, 'w') as writer:
config = yaml.load(reader, Loader=yaml.FullLoader)
config = Config(config)
write_row_f(writer, header)
write_row_f(writer, aligner)
repo_url = config.globals.repo_url
for i in range(len(config.models)):
name = config.models[i].name
model_configs = config.models[i].model_configs
pipelines = config.models[i].pipelines
config_url = osp.join(repo_url, model_configs[0])
config_url, _ = osp.split(config_url)
support_backends = {b: 'N' for b in args.backends}
deploy_config = [
pipelines[i].deploy_config for i in range(len(pipelines))
]
cfg = [
load_config(deploy_config[i])
for i in range(len(deploy_config))
]
task = [
get_task_type(cfg[i][0]).value
for i in range(len(deploy_config))
]
backend_type = [
get_backend(cfg[i][0]).value
for i in range(len(deploy_config))
]
for j in range(len(deploy_config)):
support_backends[backend_type[j]] = 'Y'
support_backends = [support_backends[b] for b in args.backends]
model_name = f'[{name}]({config_url})'
# all pipelines of one model share the same task type, so take the first
row = [model_name, task[0]] + support_backends
write_row_f(writer, row)
print(f'Save to {args.output}')
if __name__ == '__main__':
main()
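# Example usage (the script name and paths are illustrative):
#   python generate_md_table.py tests/regression/mmdet.yml mmdet.md \
#       --backends onnxruntime tensorrt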
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
from mmdeploy.apis.snpe import from_onnx
from mmdeploy.utils import get_root_logger
def parse_args():
parser = argparse.ArgumentParser(
description='Convert ONNX to snpe dlc format.')
parser.add_argument('onnx_path', help='ONNX model path')
parser.add_argument('output_prefix', help='output snpe dlc model path')
parser.add_argument(
'--log-level',
help='set log level',
default='INFO',
choices=list(logging._nameToLevel.keys()))
args = parser.parse_args()
return args
def main():
args = parse_args()
logger = get_root_logger(log_level=args.log_level)
onnx_path = args.onnx_path
output_prefix = args.output_prefix
logger.info(f'onnx2dlc: \n\tonnx_path: {onnx_path} ')
from_onnx(onnx_path, output_prefix)
logger.info('onnx2dlc success.')
if __name__ == '__main__':
main()
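# Example usage (the script name is illustrative; the SNPE converter writes
# the .dlc file next to output_prefix):
#   python onnx2dlc.py end2end.onnx end2end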
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
from mmdeploy.apis.ncnn import from_onnx
from mmdeploy.utils import get_root_logger
def parse_args():
parser = argparse.ArgumentParser(description='Convert ONNX to ncnn.')
parser.add_argument('onnx_path', help='ONNX model path')
parser.add_argument('output_prefix', help='output ncnn model path')
parser.add_argument(
'--log-level',
help='set log level',
default='INFO',
choices=list(logging._nameToLevel.keys()))
args = parser.parse_args()
return args
def main():
args = parse_args()
logger = get_root_logger(log_level=args.log_level)
onnx_path = args.onnx_path
output_prefix = args.output_prefix
logger.info(f'mmdeploy_onnx2ncnn: \n\tonnx_path: {onnx_path} ')
from_onnx(onnx_path, output_prefix)
logger.info('mmdeploy_onnx2ncnn success.')
if __name__ == '__main__':
main()
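# Example usage (the script name is illustrative; ncnn emits a .param/.bin
# pair based on output_prefix):
#   python onnx2ncnn.py end2end.onnx end2end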
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
from copy import deepcopy
from mmengine import Config
from torch.utils.data import DataLoader
from mmdeploy.apis.utils import build_task_processor
from mmdeploy.utils import get_root_logger, load_config
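# get_table builds a calibration dataloader either from --image-dir (via the
# QuantizationImageDataset helper) or from the model's own val/train
# dataloader, then hands the ONNX model to ppq to quantize it for the ncnn
# INT8 backend and exports both the quantized graph and the quant table.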
def get_table(onnx_path: str,
deploy_cfg: Config,
model_cfg: Config,
output_onnx_path: str,
output_quant_table_path: str,
image_dir: str = None,
device: str = 'cuda',
dataset_type: str = 'val'):
input_shape = None
# setup input_shape if existed in `onnx_config`
if 'onnx_config' in deploy_cfg and 'input_shape' in deploy_cfg.onnx_config:
input_shape = deploy_cfg.onnx_config.input_shape
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
calib_dataloader = deepcopy(model_cfg[f'{dataset_type}_dataloader'])
calib_dataloader['batch_size'] = 1
# build calibration dataloader. If img dir not specified, use val dataset.
if image_dir is not None:
from quant_image_dataset import QuantizationImageDataset
dataset = QuantizationImageDataset(
path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
def collate(data_batch):
return data_batch[0]
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate)
else:
dataset = task_processor.build_dataset(calib_dataloader['dataset'])
calib_dataloader['dataset'] = dataset
dataloader = task_processor.build_dataloader(calib_dataloader)
data_preprocessor = task_processor.build_data_preprocessor()
# get an available input shape from the first calibration sample
for _, input_data in enumerate(dataloader):
input_data = data_preprocessor(input_data)
input_tensor = input_data['inputs']
input_shape = input_tensor.shape
break
collate_fn = lambda x: data_preprocessor(x)['inputs'].to( # noqa: E731
device)
from ppq import QuantizationSettingFactory, TargetPlatform
from ppq.api import export_ppq_graph, quantize_onnx_model
# settings for ncnn quantization
quant_setting = QuantizationSettingFactory.default_setting()
quant_setting.equalization = False
quant_setting.dispatcher = 'conservative'
# quantize the model
quantized = quantize_onnx_model(
onnx_import_file=onnx_path,
calib_dataloader=dataloader,
calib_steps=max(8, min(512, len(dataset))),
input_shape=input_shape,
setting=quant_setting,
collate_fn=collate_fn,
platform=TargetPlatform.NCNN_INT8,
device=device,
verbose=1)
# export quantized graph and quant table
export_ppq_graph(
graph=quantized,
platform=TargetPlatform.NCNN_INT8,
graph_save_to=output_onnx_path,
config_save_to=output_quant_table_path)
return
def parse_args():
parser = argparse.ArgumentParser(
description='Generate ncnn quant table from ONNX.')
parser.add_argument('--onnx', help='ONNX model path')
parser.add_argument('--deploy-cfg', help='Input deploy config path')
parser.add_argument('--model-cfg', help='Input model config path')
parser.add_argument('--out-onnx', help='Output onnx path')
parser.add_argument('--out-table', help='Output quant table path')
parser.add_argument(
'--image-dir',
type=str,
default=None,
help='Calibration Image Directory.')
parser.add_argument(
'--log-level',
help='set log level',
default='INFO',
choices=list(logging._nameToLevel.keys()))
args = parser.parse_args()
return args
def main():
args = parse_args()
logger = get_root_logger(log_level=args.log_level)
onnx_path = args.onnx
deploy_cfg, model_cfg = load_config(args.deploy_cfg, args.model_cfg)
quant_table_path = args.out_table
quant_onnx_path = args.out_onnx
image_dir = args.image_dir
get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
quant_table_path, image_dir)
logger.info('onnx2ncnn_quant_table success.')
if __name__ == '__main__':
main()
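# Example usage (the script name and config files are illustrative):
#   python onnx2ncnn_quant_table.py --onnx end2end.onnx \
#       --deploy-cfg detection_ncnn-int8_static.py \
#       --model-cfg yolov3_d53_mstrain-608_273e_coco.py \
#       --out-onnx quant.onnx --out-table quant.table \
#       --image-dir /path/to/calibration/images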
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import collections.abc
import logging
from mmdeploy.apis.pplnn import from_onnx
from mmdeploy.utils import get_root_logger
def parse_args():
parser = argparse.ArgumentParser(description='Convert ONNX to PPLNN.')
parser.add_argument('onnx_path', help='ONNX model path')
parser.add_argument(
'output_prefix', help='output PPLNN algorithm prefix in json format')
parser.add_argument(
'--device',
help='the device of the model during conversion',
default='cuda:0')
parser.add_argument(
'--opt-shapes',
help='Optimization shapes for PPLNN. The shapes must be able '
'to be evaluated by python, e.g., `[1, 3, 224, 224]`',
default='[1, 3, 224, 224]')
parser.add_argument(
'--log-level',
help='set log level',
default='INFO',
choices=list(logging._nameToLevel.keys()))
args = parser.parse_args()
return args
def main():
args = parse_args()
logger = get_root_logger(log_level=args.log_level)
onnx_path = args.onnx_path
output_prefix = args.output_prefix
device = args.device
input_shapes = eval(args.opt_shapes)
assert isinstance(
input_shapes, collections.abc.Sequence), \
'The opt-shape must be a sequence.'
assert isinstance(input_shapes[0], int) or (isinstance(
input_shapes[0], collections.abc.Sequence)), \
'The opt-shape must be a sequence of int or a sequence of sequences.'
if isinstance(input_shapes[0], int):
input_shapes = [input_shapes]
logger.info(f'onnx2pplnn: \n\tonnx_path: {onnx_path} '
f'\n\toutput_prefix: {output_prefix}'
f'\n\topt_shapes: {input_shapes}')
from_onnx(onnx_path, output_prefix, device, input_shapes)
logger.info('onnx2pplnn success.')
if __name__ == '__main__':
main()
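# Example usage (the script name is illustrative):
#   python onnx2pplnn.py end2end.onnx end2end --device cuda:0 \
#       --opt-shapes '[1, 3, 224, 224]'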
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
from mmdeploy.backend.tensorrt import from_onnx
from mmdeploy.backend.tensorrt.utils import get_trt_log_level
from mmdeploy.utils import (get_common_config, get_model_inputs,
get_root_logger, load_config)
def parse_args():
parser = argparse.ArgumentParser(description='Convert ONNX to TensorRT.')
parser.add_argument('deploy_cfg', help='deploy config path')
parser.add_argument('onnx_path', help='ONNX model path')
parser.add_argument('output_prefix', help='output TensorRT engine prefix')
parser.add_argument('--device-id', type=int, help='the CUDA device id', default=0)
parser.add_argument(
'--calib-file',
help='the calibration data used to calibrate the engine to int8',
default=None)
parser.add_argument(
'--log-level',
help='set log level',
default='INFO',
choices=list(logging._nameToLevel.keys()))
args = parser.parse_args()
return args
def main():
args = parse_args()
logger = get_root_logger(log_level=args.log_level)
deploy_cfg_path = args.deploy_cfg
deploy_cfg = load_config(deploy_cfg_path)[0]
onnx_path = args.onnx_path
output_prefix = args.output_prefix
device_id = args.device_id
calib_file = args.calib_file
model_id = 0
common_params = get_common_config(deploy_cfg)
model_params = get_model_inputs(deploy_cfg)[model_id]
final_params = common_params
final_params.update(model_params)
int8_param = final_params.get('int8_param', dict())
if calib_file is not None:
int8_param['calib_file'] = calib_file
# do not support partition model calibration for now
int8_param['model_type'] = 'end2end'
logger.info(f'onnx2tensorrt: \n\tonnx_path: {onnx_path} '
f'\n\tdeploy_cfg: {deploy_cfg_path}')
from_onnx(
onnx_path,
output_prefix,
input_shapes=final_params['input_shapes'],
log_level=get_trt_log_level(),
fp16_mode=final_params.get('fp16_mode', False),
int8_mode=final_params.get('int8_mode', False),
int8_param=int8_param,
max_workspace_size=final_params.get('max_workspace_size', 0),
device_id=device_id)
logger.info('onnx2tensorrt success.')
if __name__ == '__main__':
main()
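# Example usage (the script name and deploy config are illustrative):
#   python onnx2tensorrt.py detection_tensorrt-fp16_static.py \
#       end2end.onnx end2end --device-id 0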