Commit 7ae93d70 authored by limm's avatar limm
Browse files

Add part of the tests code

parent abaad570
Pipeline #2815 canceled with stages
# Copyright (c) OpenMMLab. All rights reserved.
from .module import C2, C, func
__all__ = ['func', 'C', 'C2']
# Copyright (c) OpenMMLab. All rights reserved.
def func():
    """Return the constant 1; stub target for function-rewriter tests."""
    return 1
class C:
    """Stub class whose ``method`` is the target of method-rewriter tests."""

    def method(self):
        # Constant sentinel; rewriters replace it with other values.
        return 1
class C2(C):
    # Derived stub: inherits ``method`` from C so tests can verify that
    # rewriting a base-class method also affects subclasses.
    pass
# Copyright (c) OpenMMLab. All rights reserved.
import torch
try:
from torch.testing import assert_close as torch_assert_close
except Exception:
from torch.testing import assert_allclose as torch_assert_close
from mmdeploy.core import FUNCTION_REWRITER, RewriterContext
from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter
from mmdeploy.core.rewriters.rewriter_utils import collect_env
from mmdeploy.utils.constants import IR, Backend
def test_function_rewriter():
    """End-to-end check of FUNCTION_REWRITER for functions and methods.

    Covers: backend-specific activation, restoration on context exit,
    access to the original function via the rewriter context, and manual
    cleanup of the global registry.
    """
    x = torch.tensor([1, 2, 3, 4, 5])
    y = torch.tensor([2, 4, 6, 8, 10])

    # One rewriter body registered for two targets (torch.mul and torch.add),
    # active only when the 'tensorrt' backend is selected.
    @FUNCTION_REWRITER.register_rewriter(
        func_name='torch.mul', backend='tensorrt')
    @FUNCTION_REWRITER.register_rewriter(
        func_name='torch.add', backend='tensorrt')
    def sub_func(x, y):
        ctx = FUNCTION_REWRITER.get_context('torch.add')
        # The context exposes the deploy config and the original function.
        assert hasattr(ctx, 'cfg')
        assert hasattr(ctx, 'origin_func')
        return x - y

    cfg = dict()
    with RewriterContext(cfg, backend='tensorrt'):
        result = torch.add(x, y)
        # replace add with sub
        torch_assert_close(result, x - y)
        result = torch.mul(x, y)
        # replace add with sub
        torch_assert_close(result, x - y)

    result = torch.add(x, y)
    # recovery origin function
    torch_assert_close(result, x + y)

    with RewriterContext(cfg):
        result = torch.add(x, y)
        # replace should not happen with wrong backend
        torch_assert_close(result, x + y)

    # test different config
    @FUNCTION_REWRITER.register_rewriter(
        func_name='torch.Tensor.add', backend='default')
    def mul_func_class(x, y):
        return x * y

    # 'default' rewriters apply to every backend, including tensorrt.
    with RewriterContext(cfg, backend='tensorrt'):
        result = x.add(y)
        # replace add with multi
        torch_assert_close(result, x * y)

    result = x.add(y)
    # recovery origin function
    torch_assert_close(result, x + y)

    with RewriterContext(cfg):
        result = x.add(y)
        # replace add with multi
        torch_assert_close(result, x * y)

    # test origin_func
    @FUNCTION_REWRITER.register_rewriter(
        func_name='torch.add', backend='default')
    def origin_add_func(x, y, **kwargs):
        ctx = FUNCTION_REWRITER.get_context('torch.add')
        # ctx.origin_func is the unrewritten torch.add.
        return ctx.origin_func(x, y, **kwargs) + 1

    with RewriterContext(cfg):
        result = torch.add(x, y)
        # replace with origin + 1
        torch_assert_close(result, x + y + 1)

    # remove torch.add
    del FUNCTION_REWRITER._origin_functions[-1]
    torch_assert_close(torch.add(x, y), x + y)

    # Unregister every rewriter added above so later tests start from a
    # clean global registry.
    FUNCTION_REWRITER._registry.remove_record(sub_func)
    FUNCTION_REWRITER._registry.remove_record(mul_func_class)
    FUNCTION_REWRITER._registry.remove_record(origin_add_func)
def test_rewrite_empty_function():
    """A rewriter targeting a non-existent function must be skipped."""
    rewriter = FunctionRewriter()

    @rewriter.register_rewriter(func_name='torch.abcdefghijklmn')
    def func(x, y):
        return x + y

    rewriter.enter()
    # Nothing was hooked: the target path cannot be resolved.
    assert not rewriter._origin_functions
    rewriter.exit()
class TestHomonymicRewriter:
    """Two import paths may alias the same function object.

    ``package.C.method`` and ``package.module.C.method`` resolve to the same
    underlying method; these tests verify that the most specific matching
    rewriter wins and that the original is restored on exit.
    """

    def test_rewrite_homonymic_methods(self):
        import package
        path1 = 'package.C.method'
        path2 = 'package.module.C.method'
        c = package.C()
        function_rewriter = FunctionRewriter()
        assert c.method() == 1

        # Default-backend rewriter on the first alias ...
        @function_rewriter.register_rewriter(func_name=path1)
        def func_2(self):
            return 2

        # ... and an NCNN-specific rewriter on the second alias.
        @function_rewriter.register_rewriter(
            func_name=path2, backend=Backend.NCNN.value)
        def func_3(self):
            return 3

        # Under NCNN the backend-specific rewriter takes precedence.
        function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
        assert c.method() == 3
        function_rewriter.exit()
        assert c.method() == 1

        function_rewriter2 = FunctionRewriter()

        # Swap which alias carries the backend-specific registration.
        @function_rewriter2.register_rewriter(
            func_name=path1, backend=Backend.NCNN.value)
        def func_4(self):
            return 4

        @function_rewriter2.register_rewriter(func_name=path2)
        def func_5(self):
            return 5

        function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
        assert c.method() == 4
        function_rewriter2.exit()
        assert c.method() == 1
def test_rewrite_derived_methods():
    """Rewriting a base method must also affect subclasses.

    ``package.C2`` inherits ``method`` from ``package.C``; a rewriter on the
    base applies to both unless a subclass-specific rewriter overrides it.
    """
    import package
    path1 = 'package.C.method'
    path2 = 'package.C2.method'
    base_obj = package.C()
    derived_obj = package.C2()
    assert base_obj.method() == 1
    assert derived_obj.method() == 1

    function_rewriter = FunctionRewriter()

    # Default-backend rewriter on the base class ...
    @function_rewriter.register_rewriter(func_name=path1)
    def func_2(self):
        return 2

    # ... and an NCNN-only rewriter on the derived class.
    @function_rewriter.register_rewriter(
        func_name=path2, backend=Backend.NCNN.value)
    def func_3(self):
        return 3

    # Default backend: base rewriter applies to base and derived.
    function_rewriter.enter(env=collect_env(Backend.DEFAULT, ir=IR.DEFAULT))
    assert base_obj.method() == 2
    assert derived_obj.method() == 2
    function_rewriter.exit()

    # NCNN backend: the derived-class rewriter overrides for C2 only.
    function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT))
    assert base_obj.method() == 2
    assert derived_obj.method() == 3
    function_rewriter.exit()
    assert base_obj.method() == 1
    assert derived_obj.method() == 1

    # Check if the recovery is correct
    function_rewriter.enter(env=collect_env(Backend.DEFAULT, ir=IR.DEFAULT))
    assert base_obj.method() == 2
    assert derived_obj.method() == 2
    function_rewriter.exit()
    assert base_obj.method() == 1
    assert derived_obj.method() == 1
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import onnx
import torch
from mmdeploy.core import RewriterContext, mark
from mmdeploy.core.optimizers import attribute_to_dict
from mmdeploy.utils.constants import IR, Backend
# Shared path for ONNX export round trips in this module.
# NOTE(review): keeping only ``.name`` after the NamedTemporaryFile object is
# released works on POSIX but may fail on Windows (delete-on-close) — confirm
# the supported test platforms.
output_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
def test_mark():
    """Exported ONNX graphs must contain 'Mark' nodes around @mark'd calls.

    Checks node order, domain, and the attribute payload of each Mark node,
    then verifies the model still traces under a TorchScript rewriter context.
    """

    @mark('add', inputs=['a', 'b'], outputs='c')
    def add(x, y):
        return torch.add(x, y)

    class TestModel(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return add(x, y)

    model = TestModel().eval()

    # dummy input
    x = torch.rand(2, 3, 4)
    y = torch.rand(2, 3, 4)

    torch.onnx.export(model, (x, y), output_file)
    onnx_model = onnx.load(output_file)

    # Expected layout: Mark(input a), Mark(input b), Add, Mark(output c).
    nodes = onnx_model.graph.node
    assert nodes[0].op_type == 'Mark'
    assert nodes[0].domain == 'mmdeploy'
    # dtype=1 is the ONNX code for float32.
    assert attribute_to_dict(nodes[0].attribute) == dict(
        dtype=1,
        func='add',
        func_id=0,
        id=0,
        type='input',
        name='a',
        shape=[2, 3, 4])

    assert nodes[1].op_type == 'Mark'
    assert nodes[1].domain == 'mmdeploy'
    assert attribute_to_dict(nodes[1].attribute) == dict(
        dtype=1,
        func='add',
        func_id=0,
        id=1,
        type='input',
        name='b',
        shape=[2, 3, 4])

    assert nodes[2].op_type == 'Add'

    assert nodes[3].op_type == 'Mark'
    assert nodes[3].domain == 'mmdeploy'
    assert attribute_to_dict(nodes[3].attribute) == dict(
        dtype=1,
        func='add',
        func_id=0,
        id=0,
        type='output',
        name='c',
        shape=[2, 3, 4])

    # Marks must not break TorchScript tracing either.
    with RewriterContext(
            cfg=None, backend=Backend.TORCHSCRIPT.value,
            ir=IR.TORCHSCRIPT), torch.no_grad(), torch.jit.optimized_execution(
                True):
        torch.jit.trace(model, (x, y))
# Copyright (c) OpenMMLab. All rights reserved.
import torch
try:
from torch.testing import assert_close as torch_assert_close
except Exception:
from torch.testing import assert_allclose as torch_assert_close
from mmdeploy.core import MODULE_REWRITER, patch_model
def test_module_rewriter():
    """patch_model must swap registered modules only for matching backends."""
    from torchvision.models.resnet import resnet50

    # Wrapper registered for the tensorrt backend only; it doubles the
    # wrapped Bottleneck's output so the substitution is observable.
    @MODULE_REWRITER.register_rewrite_module(
        module_type='torchvision.models.resnet.Bottleneck', backend='tensorrt')
    class BottleneckWrapper(torch.nn.Module):

        def __init__(self, module, cfg, **kwargs):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs) * 2

    x = torch.rand(1, 64, 32, 32)
    model = resnet50().eval()
    bottle_neck = model.layer1[0]
    result = bottle_neck(x)

    # rewrite module
    cfg = dict()
    rewritten_model = patch_model(model, cfg=cfg, backend='tensorrt')
    rewritten_bottle_nect = rewritten_model.layer1[0]
    rewritten_result = rewritten_bottle_nect(x)
    torch_assert_close(rewritten_result, result * 2)

    # wrong backend should not be rewritten
    model = resnet50().eval()
    bottle_neck = model.layer1[0]
    result = bottle_neck(x)
    rewritten_model = patch_model(model, cfg=cfg)
    rewritten_bottle_nect = rewritten_model.layer1[0]
    rewritten_result = rewritten_bottle_nect(x)
    torch_assert_close(rewritten_result, result)
def test_pass_redundant_args_to_model():
    """patch_model must drop kwargs the wrapper __init__ does not accept."""
    from torchvision.models.resnet import resnet50

    @MODULE_REWRITER.register_rewrite_module(
        module_type='torchvision.models.resnet.Bottleneck')
    class BottleneckWrapper(torch.nn.Module):
        """Wrapper that doubles the wrapped module's output."""

        def __init__(self, module, cfg):
            # Note: no **kwargs — extra arguments must be filtered out
            # by patch_model before reaching this constructor.
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            return 2 * self.module(*args, **kwargs)

    net = resnet50().eval()
    # 'redundant_args' is not a BottleneckWrapper.__init__ parameter and
    # should be discarded rather than raising TypeError.
    patched = patch_model(net, cfg={}, redundant_args=12345)
    assert patched is not None
# Copyright (c) OpenMMLab. All rights reserved.
import mmdeploy
import mmdeploy.core.rewriters.rewriter_utils as rewriter_utils
from mmdeploy.core.rewriters.rewriter_utils import (BackendChecker,
RewriterRegistry,
collect_env)
from mmdeploy.utils.constants import IR, Backend
def test_collect_env():
    """collect_env must record backend, IR, extras and mmdeploy version."""
    env = collect_env(Backend.ONNXRUNTIME, IR.ONNX, version='1.0')
    expected = {
        'backend': Backend.ONNXRUNTIME,
        'ir': IR.ONNX,
        'version': '1.0',
        # The running mmdeploy version is always folded in.
        'mmdeploy': mmdeploy.__version__,
    }
    for key, value in expected.items():
        assert env[key] == value
class TestChecker:
    """Unit tests for the individual rewriter environment checkers."""

    # Shared environment: ONNX Runtime backend, ONNX IR.
    env = collect_env(Backend.ONNXRUNTIME, IR.ONNX)

    def test_backend_checker(self):
        # Matching backend accepts; any other backend rejects.
        assert rewriter_utils.BackendChecker(
            Backend.ONNXRUNTIME).check(self.env) is True
        assert rewriter_utils.BackendChecker(
            Backend.TENSORRT).check(self.env) is False

    def test_ir_checker(self):
        # Matching IR accepts; any other IR rejects.
        assert rewriter_utils.IRChecker(IR.ONNX).check(self.env) is True
        assert rewriter_utils.IRChecker(
            IR.TORCHSCRIPT).check(self.env) is False

    def test_lib_version_checker(self):
        # The installed mmdeploy version lies inside [min, max] here ...
        assert rewriter_utils.LibVersionChecker(
            'mmdeploy', mmdeploy.__version__,
            mmdeploy.__version__).check(self.env) is True
        # ... but cannot satisfy max_version='0.0.0'.
        assert rewriter_utils.LibVersionChecker(
            'mmdeploy', max_version='0.0.0').check(self.env) is False
def test_register_object():
    """register_object must store the callable, checkers and metadata."""
    registry = RewriterRegistry()
    checker = rewriter_utils.BackendChecker(Backend.ONNXRUNTIME)

    @registry.register_object(
        'add',
        backend=Backend.DEFAULT.value,
        ir=IR.DEFAULT,
        extra_checkers=checker)
    def add(a, b):
        return a + b

    records = registry._rewrite_records
    assert records is not None
    assert records['add'] is not None
    record = records['add'][0]

    # A single extra checker is normalized into a one-element list.
    checkers = record['_checkers']
    assert isinstance(checkers, list)
    assert isinstance(checkers[0], BackendChecker)

    # The registered callable is stored unchanged and still callable.
    registered = record['_object']
    assert registered is not None
    assert registered(123, 456) == 123 + 456
def test_get_records():
    """get_records must select by (backend, IR) and fall back to default."""
    registry = RewriterRegistry()

    @registry.register_object(
        'get_num', backend=Backend.ONNXRUNTIME.value, ir=IR.ONNX)
    def get_num_1():
        return 1

    @registry.register_object(
        'get_num', backend=Backend.ONNXRUNTIME.value, ir=IR.TORCHSCRIPT)
    def get_num_2():
        return 2

    @registry.register_object(
        'get_num', backend=Backend.TENSORRT.value, ir=IR.ONNX)
    def get_num_3():
        return 3

    @registry.register_object(
        'get_num', backend=Backend.TENSORRT.value, ir=IR.TORCHSCRIPT)
    def get_num_4():
        return 4

    @registry.register_object(
        'get_num', backend=Backend.DEFAULT.value, ir=IR.DEFAULT)
    def get_num_5():
        return 5

    # (backend, ir) -> value of the record that must be selected.
    # NCNN has no exact registration, so it must fall back to the default.
    cases = [
        (Backend.ONNXRUNTIME, IR.ONNX, 1),
        (Backend.ONNXRUNTIME, IR.TORCHSCRIPT, 2),
        (Backend.TENSORRT, IR.ONNX, 3),
        (Backend.TENSORRT, IR.TORCHSCRIPT, 4),
        (Backend.NCNN, IR.ONNX, 5),
    ]
    for backend, ir, expected in cases:
        records = dict(registry.get_records(collect_env(backend, ir)))
        assert records['get_num']['_object']() == expected
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import onnx
import pytest
import torch
from torch.autograd import Function
import mmdeploy
from mmdeploy.core import SYMBOLIC_REWRITER, RewriterContext
from mmdeploy.core.rewriters.symbolic_rewriter import SymbolicRewriter
# Shared path for ONNX export round trips in this module.
# NOTE(review): keeping only ``.name`` after the NamedTemporaryFile object is
# released works on POSIX but may fail on Windows (delete-on-close) — confirm
# the supported test platforms.
output_file = tempfile.NamedTemporaryFile(suffix='.onnx').name
@pytest.fixture(autouse=True, scope='module')
def create_custom_module():
    """Expose a custom autograd Function as ``mmdeploy.TestFunc``.

    The symbolic rewriter resolves targets by import path, so TestFunc must
    live in an importable module for the duration of this test module. The
    attribute is removed again during teardown.
    """

    class TestFunc(Function):

        @staticmethod
        def symbolic(g, x, val):
            # Built-in symbolic: emits the custom 'symbolic_old' op, so tests
            # can tell whether a rewriter replaced it.
            return g.op('mmdeploy::symbolic_old', x, val_i=val)

        @staticmethod
        def forward(ctx, x, val):
            return x + val

    # Put TestFunc in an importable module so the rewriter can find it;
    # any module would do.
    mmdeploy.TestFunc = TestFunc
    yield
    del mmdeploy.TestFunc
def test_symbolic_rewriter():
    """Symbolic rewriters must select by backend and rewrite pytorch ops."""
    test_func = mmdeploy.TestFunc.apply

    # Same symbolic registered for both the default and the ncnn backend.
    @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc', backend='ncnn')
    @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc')
    def symbolic_testfunc_default(g, x, val):
        ctx = SYMBOLIC_REWRITER.get_context('mmdeploy.TestFunc')
        assert hasattr(ctx, 'cfg')
        return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val)

    # tensorrt gets its own symbolic, which must win over the default one.
    @SYMBOLIC_REWRITER.register_symbolic(
        'mmdeploy.TestFunc', backend='tensorrt')
    def symbolic_testfunc_tensorrt(g, x, val):
        return g.op('mmdeploy::symbolic_testfunc_tensorrt', x, val_i=val)

    # A native pytorch op (torch.cummax) can also be given a symbolic.
    @SYMBOLIC_REWRITER.register_symbolic(
        'cummax', is_pytorch=True, arg_descriptors=['v', 'i'])
    def symbolic_cummax(g, input, dim):
        return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2)

    class TestModel(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.cummax(test_func(x, 5), dim=1)

    model = TestModel().eval()

    # dummy input
    x = torch.rand(2, 3, 4)

    # default
    cfg = dict()
    with RewriterContext(cfg=cfg, opset=11):
        torch.onnx.export(model, x, output_file, opset_version=11)
    onnx_model = onnx.load(output_file)
    nodes = onnx_model.graph.node
    assert nodes[0].op_type == 'symbolic_testfunc_default'
    assert nodes[0].domain == 'mmdeploy'
    assert nodes[1].op_type == 'cummax_default'
    assert nodes[1].domain == 'mmdeploy'

    # ncnn: same symbolic as default was registered, so same graph.
    with RewriterContext(cfg=cfg, backend='ncnn', opset=11):
        torch.onnx.export(model, x, output_file, opset_version=11)
    onnx_model = onnx.load(output_file)
    nodes = onnx_model.graph.node
    assert nodes[0].op_type == 'symbolic_testfunc_default'
    assert nodes[0].domain == 'mmdeploy'
    assert nodes[1].op_type == 'cummax_default'
    assert nodes[1].domain == 'mmdeploy'

    # tensorrt: the backend-specific symbolic takes precedence.
    with RewriterContext(cfg=cfg, backend='tensorrt', opset=11):
        torch.onnx.export(model, x, output_file, opset_version=11)
    onnx_model = onnx.load(output_file)
    nodes = onnx_model.graph.node
    assert nodes[0].op_type == 'symbolic_testfunc_tensorrt'
    assert nodes[0].domain == 'mmdeploy'
    assert nodes[1].op_type == 'cummax_default'
    assert nodes[1].domain == 'mmdeploy'
def test_unregister():
    """Symbolics must only be active inside a RewriterContext."""
    test_func = mmdeploy.TestFunc.apply

    @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc')
    def symbolic_testfunc_default(g, x, val):
        return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val)

    @SYMBOLIC_REWRITER.register_symbolic(
        'cummax', is_pytorch=True, arg_descriptors=['v', 'i'])
    def symbolic_cummax(g, input, dim):
        return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2)

    class TestModel(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.cummax(x, dim=1)

    class TestModel2(torch.nn.Module):

        def __init__(self):
            super().__init__()

        def forward(self, x):
            return test_func(x, 5)

    model = TestModel().eval()
    x = torch.rand(2, 3, 4)

    # Inside the context the cummax symbolic is installed.
    with RewriterContext(cfg={}, opset=11):
        torch.onnx.export(model, x, output_file, opset_version=11)
    onnx_model = onnx.load(output_file)
    nodes = onnx_model.graph.node
    assert nodes[0].op_type == 'cummax_default'
    assert nodes[0].domain == 'mmdeploy'

    # Outside the context torch.cummax has no ONNX export and must raise.
    with pytest.raises((ValueError, RuntimeError)):
        torch.onnx.export(model, x, output_file, opset_version=11)

    model = TestModel2().eval()

    # Inside the context the registered symbolic replaces TestFunc's own.
    with RewriterContext(cfg={}, opset=11):
        torch.onnx.export(model, x, output_file, opset_version=11)
    onnx_model = onnx.load(output_file)
    nodes = onnx_model.graph.node
    assert nodes[0].op_type == 'symbolic_testfunc_default'
    assert nodes[0].domain == 'mmdeploy'

    # Outside the context TestFunc falls back to its built-in symbolic.
    torch.onnx.export(model, x, output_file, opset_version=11)
    onnx_model = onnx.load(output_file)
    nodes = onnx_model.graph.node
    assert nodes[0].op_type == 'symbolic_old'
    assert nodes[0].domain == 'mmdeploy'
def test_register_empty_symbolic():
    """A symbolic registered for a missing function must not be collected."""
    rewriter = SymbolicRewriter()

    @rewriter.register_symbolic('mmdeploy.EmptyFunction')
    def symbolic_testfunc_default(g, x, val):
        return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val)

    rewriter.enter()
    # 'mmdeploy.EmptyFunction' does not resolve, so nothing was installed.
    assert not rewriter._extra_symbolic
    rewriter.exit()
# Copyright (c) OpenMMLab. All rights reserved.
project(tests)

# Entry point; per-subsystem test sources are gathered below.
set(TC_SRCS test_main.cpp)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/archive ARCHIVE_TC)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/core CORE_TC)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/preprocess TRANSFORM_TC)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/net NET_TC)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/model MODEL_TC)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/graph GRAPH_TC)

# One device test file per enabled target device (cpu, cuda, ...).
set(DEVICE_TC)
foreach (DEVICE IN LISTS MMDEPLOY_TARGET_DEVICES)
  list(APPEND DEVICE_TC
       ${CMAKE_CURRENT_SOURCE_DIR}/device/test_${DEVICE}_device.cpp)
endforeach ()

# C API tests: pick the task list from the enabled codebases.
set(CAPI_TC)
if ("all" IN_LIST MMDEPLOY_CODEBASES)
  set(TASK_LIST
      "classifier;detector;segmentor;text_detector;text_recognizer;restorer;model"
  )
  set(CODEBASES "mmcls;mmdet;mmseg;mmedit;mmocr")
else ()
  # 'model' is always tested; each codebase adds its task(s).
  set(TASK_LIST "model")
  set(CODEBASES "${MMDEPLOY_CODEBASES}")
  if ("mmcls" IN_LIST MMDEPLOY_CODEBASES)
    list(APPEND TASK_LIST "classifier")
  endif ()
  if ("mmdet" IN_LIST MMDEPLOY_CODEBASES)
    list(APPEND TASK_LIST "detector")
  endif ()
  if ("mmseg" IN_LIST MMDEPLOY_CODEBASES)
    list(APPEND TASK_LIST "segmentor")
  endif ()
  if ("mmedit" IN_LIST MMDEPLOY_CODEBASES)
    list(APPEND TASK_LIST "restorer")
  endif ()
  if ("mmocr" IN_LIST MMDEPLOY_CODEBASES)
    list(APPEND TASK_LIST "text_detector")
    list(APPEND TASK_LIST "text_recognizer")
  endif ()
endif ()
foreach (TASK ${TASK_LIST})
  list(APPEND CAPI_TC ${CMAKE_CURRENT_SOURCE_DIR}/capi/test_${TASK}.cpp)
endforeach ()

# generate the header file
configure_file(config/test_define.h.in
               ${CMAKE_CURRENT_SOURCE_DIR}/test_define.h)

set(TC_SRCS
    ${TC_SRCS}
    ${ARCHIVE_TC}
    ${CORE_TC}
    ${TRANSFORM_TC}
    ${MODEL_TC}
    ${NET_TC}
    ${DEVICE_TC}
    ${CAPI_TC}
    ${GRAPH_TC})

add_executable(mmdeploy_tests ${TC_SRCS})
target_include_directories(mmdeploy_tests
                           PRIVATE ${CMAKE_SOURCE_DIR}/third_party/catch2)
target_include_directories(mmdeploy_tests PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
# Hidden visibility matches the production static build; not applicable to
# shared-library or MSVC builds.
if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC))
  target_compile_options(mmdeploy_tests PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
endif ()
mmdeploy_load_static(mmdeploy_tests MMDeployStaticModules)
mmdeploy_load_dynamic(mmdeploy_tests MMDeployDynamicModules)
target_link_libraries(mmdeploy_tests PRIVATE
                      MMDeployLibs
                      mmdeploy_transform
                      mmdeploy_operation
                      mmdeploy_opencv_utils)
// Copyright (c) OpenMMLab. All rights reserved.
#include <deque>
#include <iostream>
#include <list>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "catch.hpp"
#include "mmdeploy/archive/json_archive.h"
// Sequence/set containers that serialize as JSON arrays.
using ArrayLikeTypes = std::tuple<std::vector<int>, std::deque<int>, std::array<int, 15>,
                                  std::list<int>, std::set<int>, std::unordered_set<int>,
                                  std::multiset<int>, std::unordered_multiset<int> >;

TEMPLATE_LIST_TEST_CASE("test array-like", "[archive]", ArrayLikeTypes) {
  // Round-trip 15 values through the JSON archive and require equality.
  TestType original{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4};
  nlohmann::json json;
  mmdeploy::JsonOutputArchive out(json);
  out(original);
  mmdeploy::JsonInputArchive in(json);
  TestType restored{};
  in(restored);
  std::cout << json << std::endl;
  REQUIRE(restored == original);
}
// Associative containers that serialize as JSON objects/arrays of pairs.
using MapLikeTypes = std::tuple<
    // std::map<int, float>
    std::map<int, float>, std::unordered_map<int, float>, std::multimap<int, float>,
    std::unordered_multimap<int, float> >;

TEMPLATE_LIST_TEST_CASE("test map-like", "[archive]", MapLikeTypes) {
  // The initializer repeats keys 1 and 3: unique-key maps keep only the
  // first occurrence, multimaps keep all — the round trip must match
  // whatever the container actually stored.
  TestType v{{1, 123.456f}, {1, 222.222f}, {2, 111.222f}, {3, 223.332f}, {3, 1.22e10f}};
  nlohmann::json json;
  mmdeploy::JsonOutputArchive oa(json);
  oa(v);
  mmdeploy::JsonInputArchive ia(json);
  TestType u;
  ia(u);
  std::cout << json << std::endl;
  REQUIRE(u == v);
}
// Simple aggregate exercising MMDEPLOY_ARCHIVE_MEMBERS serialization.
struct A {
  std::vector<int> vec;
  std::string str;
  friend bool operator==(const A& a, const A& b) { return a.vec == b.vec && a.str == b.str; }
  MMDEPLOY_ARCHIVE_MEMBERS(vec, str);
};
TEST_CASE("test struct", "[archive]") {
  // Serialize an aggregate via MMDEPLOY_ARCHIVE_MEMBERS and read it back.
  A src{{1, 2, 3, 4, 5}, "hello"};
  nlohmann::json json;
  mmdeploy::JsonOutputArchive out(json);
  out(src);
  A dst;
  mmdeploy::JsonInputArchive in(json);
  in(dst);
  REQUIRE(src == dst);
}
// Copyright (c) OpenMMLab. All rights reserved.
// clang-format off
#include "catch.hpp"
// clang-format on
#include <array>
#include <deque>
#include <iostream>
#include <list>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "mmdeploy/archive/value_archive.h"
#include "mmdeploy/core/utils/formatter.h"
// clang-format off
// Sequence/set containers that serialize as value arrays.
using ArrayLikeTypes =
    std::tuple<
        std::vector<int>,
        std::deque<int>,
        std::array<int, 15>,
        std::list<int>,
        std::set<int>,
        std::unordered_set<int>,
        std::multiset<int>,
        std::unordered_multiset<int>
        >;
// clang-format on

TEMPLATE_LIST_TEST_CASE("test array-like for value", "[value]", ArrayLikeTypes) {
  // Round-trip through the Value archive and require equality.
  TestType v{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4};
  mmdeploy::Value value;
  mmdeploy::ValueOutputArchive oa(value);
  oa(v);
  mmdeploy::ValueInputArchive ia(value);
  TestType u{};
  ia(u);
  REQUIRE(u == v);
}
TEST_CASE("test native array for value archive", "[value1]") {
  // C arrays are archived element-wise; verify a full round trip.
  const int src[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  int dst[10] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  mmdeploy::Value value;
  mmdeploy::ValueOutputArchive out(value);
  out(src);
  mmdeploy::ValueInputArchive in(value);
  in(dst);
  REQUIRE(std::vector<int>(src, src + 10) == std::vector<int>(dst, dst + 10));
}
// clang-format off
// Associative containers that serialize through the Value archive.
using MapLikeTypes =
    std::tuple<
        std::map<int, float>,
        std::unordered_map<int, float>,
        std::multimap<int, float>,
        std::unordered_multimap<int, float>
        // std::map<int, float>
        >;
// clang-format on

TEMPLATE_LIST_TEST_CASE("test map-like for value archive", "[value]", MapLikeTypes) {
  // Duplicate keys 1 and 3: unique-key maps keep the first occurrence,
  // multimaps keep all; the round trip must match the stored contents.
  TestType v{{1, 123.456f}, {1, 222.222f}, {2, 111.222f}, {3, 223.332f}, {3, 1.22e10f}};
  mmdeploy::Value value;
  mmdeploy::ValueOutputArchive oa(value);
  oa(v);
  mmdeploy::ValueInputArchive ia(value);
  TestType u{};
  ia(u);
  REQUIRE(u == v);
}
struct OuterObject {
int x;
float y;
struct InnerObject {
std::string f;
bool g;
friend bool operator==(const InnerObject& a, const InnerObject& b) {
return a.f == b.f && a.g == b.g;
}
MMDEPLOY_ARCHIVE_MEMBERS(f, g);
};
InnerObject inner;
struct Stl {
std::vector<std::string> s_vec;
std::map<std::string, int> si_map;
friend bool operator==(const Stl& a, const Stl& b) {
return a.s_vec == b.s_vec && a.si_map == b.si_map;
}
MMDEPLOY_ARCHIVE_MEMBERS(s_vec);
};
Stl stl;
friend bool operator==(const OuterObject& a, const OuterObject& b) {
return a.x == b.x && a.y == b.y && a.inner == b.inner;
}
friend bool operator!=(const OuterObject& a, const OuterObject& b) { return !(a == b); }
MMDEPLOY_ARCHIVE_MEMBERS(x, y, inner, stl);
};
TEST_CASE("test schema", "[value]") {
  // clang-format off
  OuterObject obj {
    1,
    2,
    {"3", false},
    {
      {"hello", "world", "mmdeploy"},
      {{"1", 1}, {"er", 2}, {"three", 3}}
    }
  };
  // clang-format on
  mmdeploy::Value value;
  mmdeploy::ValueOutputArchive oa(value);
  oa(obj);

  // A std::string converts directly into a string-typed Value.
  std::string ff;
  mmdeploy::Value v(ff);
  REQUIRE(v.is_string());

  // Archived struct becomes an object keyed by member name.
  REQUIRE(value.is_object());
  auto& x = value["x"];
  REQUIRE(x.is_number_integer());
  REQUIRE(x.get<int>() == 1);
  auto& y = value["y"];
  REQUIRE(y.is_number_float());
  REQUIRE(y.get<float>() == 2);
  // Nested struct members appear as nested objects.
  auto& inner = value["inner"];
  REQUIRE(inner.is_object());
  auto& f = inner["f"];
  REQUIRE(f.type() == mmdeploy::ValueType::kString);
  REQUIRE(f.is_string());
  REQUIRE(f.get<std::string>() == "3");
  auto& g = inner["g"];
  REQUIRE(g.type() == mmdeploy::ValueType::kBool);
  REQUIRE(g.get<bool>() == false);

  // Deserializing must reproduce the original object.
  mmdeploy::ValueInputArchive ia(value);
  OuterObject u{};
  REQUIRE(obj != u);
  ia(u);
  REQUIRE(obj == u);
}
// Copyright (c) OpenMMLab. All rights reserved.
// clang-format off
#include "catch.hpp"
// clang-format on
#include "mmdeploy/apis/c/mmdeploy/classifier.h"
#include "mmdeploy/core/logger.h"
#include "opencv2/opencv.hpp"
#include "test_resource.h"
using namespace std;
TEST_CASE("test classifier's c api", "[.classifier][resource]") {
  // Run the classifier C API over every (backend, device, model) combination
  // found in the test resource tree.
  auto test = [](const std::string& device_name, const std::string& model_path,
                 const std::vector<std::string>& img_list) {
    mmdeploy_classifier_t classifier{nullptr};
    auto ret =
        mmdeploy_classifier_create_by_path(model_path.c_str(), device_name.c_str(), 0, &classifier);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    vector<cv::Mat> cv_mats;
    vector<mmdeploy_mat_t> mats;
    for (auto& img_path : img_list) {
      cv::Mat mat = cv::imread(img_path);
      REQUIRE(!mat.empty());
      // Keep the cv::Mat alive: mmdeploy_mat_t only borrows its pixel data.
      cv_mats.push_back(mat);
      mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MMDEPLOY_PIXEL_FORMAT_BGR,
                      MMDEPLOY_DATA_TYPE_UINT8});
    }
    mmdeploy_classification_t* results{nullptr};
    int* result_count{nullptr};
    ret = mmdeploy_classifier_apply(classifier, mats.data(), (int)mats.size(), &results,
                                    &result_count);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    auto result_ptr = results;
    MMDEPLOY_INFO("model_path: {}", model_path);
    for (auto i = 0; i < (int)mats.size(); ++i) {
      MMDEPLOY_INFO("the {}-th classification result: ", i);
      // Fix: use the per-image count result_count[i]. The previous code read
      // *result_count (image 0's count) for every image, which desynchronizes
      // result_ptr whenever images have different label counts.
      for (int j = 0; j < result_count[i]; ++j, ++result_ptr) {
        MMDEPLOY_INFO("\t label: {}, score: {}", result_ptr->label_id, result_ptr->score);
      }
    }
    mmdeploy_classifier_release_result(results, result_count, (int)mats.size());
    mmdeploy_classifier_destroy(classifier);
  };

  // Fix: bind by reference — the test-resource registry is a process-wide
  // singleton (other tests in this suite use `auto&` as well).
  auto& gResources = MMDeployTestResources::Get();
  auto img_lists = gResources.LocateImageResources(fs::path{"mmcls"} / "images");
  REQUIRE(!img_lists.empty());

  for (auto& backend : gResources.backends()) {
    DYNAMIC_SECTION("loop backend: " << backend) {
      auto model_list = gResources.LocateModelResources(fs::path{"mmcls/"} / backend);
      REQUIRE(!model_list.empty());
      for (auto& model_path : model_list) {
        for (auto& device_name : gResources.device_names(backend)) {
          test(device_name, model_path, img_lists);
        }
      }
    }
  }
}
// Copyright (c) OpenMMLab. All rights reserved.
// clang-format off
#include "catch.hpp"
// clang-format on
#include "mmdeploy/apis/c/mmdeploy/detector.h"
#include "mmdeploy/core/logger.h"
#include "mmdeploy/core/utils/formatter.h"
#include "opencv2/opencv.hpp"
#include "test_resource.h"
using namespace std;
TEST_CASE("test detector's c api", "[.detector][resource]") {
  MMDEPLOY_INFO("test detector");
  // Runs the detector C API over every (backend, device, model) combination
  // found in the test resource tree.
  auto test = [](const string &device, const string &model_path, const vector<string> &img_list) {
    mmdeploy_detector_t detector{nullptr};
    auto ret = mmdeploy_detector_create_by_path(model_path.c_str(), device.c_str(), 0, &detector);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    vector<cv::Mat> cv_mats;
    vector<mmdeploy_mat_t> mats;
    for (auto &img_path : img_list) {
      cv::Mat mat = cv::imread(img_path);
      REQUIRE(!mat.empty());
      // Keep the cv::Mat alive: mmdeploy_mat_t only borrows its pixel data.
      cv_mats.push_back(mat);
      mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MMDEPLOY_PIXEL_FORMAT_BGR,
                      MMDEPLOY_DATA_TYPE_UINT8});
    }
    mmdeploy_detection_t *results{nullptr};
    int *result_count{nullptr};
    ret = mmdeploy_detector_apply(detector, mats.data(), (int)mats.size(), &results, &result_count);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    // results is a flat array; result_count[i] gives image i's share.
    auto result_ptr = results;
    for (auto i = 0; i < mats.size(); ++i) {
      MMDEPLOY_INFO("the '{}-th' image has '{}' objects", i, result_count[i]);
      for (auto j = 0; j < result_count[i]; ++j, ++result_ptr) {
        auto &bbox = result_ptr->bbox;
        MMDEPLOY_INFO(" >> bbox[{}, {}, {}, {}], label_id {}, score {}", bbox.left, bbox.top,
                      bbox.right, bbox.bottom, result_ptr->label_id, result_ptr->score);
      }
    }
    mmdeploy_detector_release_result(results, result_count, (int)mats.size());
    mmdeploy_detector_destroy(detector);
  };

  MMDEPLOY_INFO("get test resources");
  auto &gResources = MMDeployTestResources::Get();
  MMDEPLOY_INFO("locate image resources");
  auto img_lists = gResources.LocateImageResources(fs::path{"mmdet"} / "images");
  MMDEPLOY_INFO("{}", img_lists.size());
  REQUIRE(!img_lists.empty());

  for (auto &backend : gResources.backends()) {
    MMDEPLOY_INFO("backend: {}", backend);
    DYNAMIC_SECTION("loop backend: " << backend) {
      auto model_list = gResources.LocateModelResources(fs::path{"mmdet"} / backend);
      REQUIRE(!model_list.empty());
      for (auto &model_path : model_list) {
        MMDEPLOY_INFO("model: {}", model_path);
        for (auto &device_name : gResources.device_names(backend)) {
          test(device_name, model_path, img_lists);
        }
      }
    }
  }
}
#if 0
// Disabled sketch of the asynchronous detector C API (executor/sender based).
// Kept for reference; it uses older mm_* type names and is not compiled.
TEST_CASE("test detector's c api", "[detector]") {
  mm_model_t model{};
  // pretend the model is loaded
  mm_handle_t handle{};
  mmdeploy_async_detector_create(model, "cuda", 0, &handle);
  std::vector<mm_mat_t> imgs;
  std::vector<mmdeploy_sender_t> sndrs;
  for (const auto &img : imgs) {
    mmdeploy_value_t value = mmdeploy_async_detector_create_input(&img, 1);
    mmdeploy_sender_t input = mmdeploy_executor_just(value);
    mmdeploy_sender_t detect = mmdeploy_async_detector_apply(handle, input);
    mmdeploy_sender_t started = mmdeploy_executor_ensure_started(detect);
    sndrs.push_back(started);
  }
  for (int i = 0; i < imgs.size(); ++i) {
    mmdeploy_value_t output = mmdeploy_executor_sync_wait(sndrs[i]);
    mm_detect_t *dets{};
    int *count{};
    mmdeploy_async_detector_get_result(output, &dets, &count);
    mmdeploy_detector_release_result(dets, count, 1);
  }
  mmdeploy_async_detector_destroy(handle);
}
#endif
// Copyright (c) OpenMMLab. All rights reserved.
// clang-format off
#include "catch.hpp"
// clang-format on
#include "mmdeploy/apis/c/mmdeploy/model.h"
#include "test_resource.h"
TEST_CASE("test model c capi", "[.model][resource]") {
  auto &gResource = MMDeployTestResources::Get();
  // Pick any one available model from the resource tree.
  // NOTE(review): the `break` only leaves the backend loop, so a later
  // codebase can overwrite model_path — harmless here since any model works,
  // but confirm this is intentional.
  std::string model_path;
  for (auto const &codebase : gResource.codebases()) {
    for (auto const &backend : gResource.backends()) {
      if (auto _model_list = gResource.LocateModelResources(fs::path{codebase} / backend);
          !_model_list.empty()) {
        model_path = _model_list.front();
        break;
      }
    }
  }
  REQUIRE(!model_path.empty());

  // Loading from a valid path must succeed.
  mmdeploy_model_t model{};
  REQUIRE(mmdeploy_model_create_by_path(model_path.c_str(), &model) == MMDEPLOY_SUCCESS);
  mmdeploy_model_destroy(model);
  model = nullptr;

  // Creating from a null buffer must fail cleanly.
  REQUIRE(mmdeploy_model_create(nullptr, 0, &model) == MMDEPLOY_E_FAIL);
  mmdeploy_model_destroy(model);
}
// Copyright (c) OpenMMLab. All rights reserved.
// clang-format off
#include "catch.hpp"
// clang-format on
#include "mmdeploy/apis/c/mmdeploy/restorer.h"
#include "opencv2/opencv.hpp"
#include "test_resource.h"
using namespace std;
TEST_CASE("test restorer's c api", "[.restorer][resource]") {
  // Runs the restorer C API over every (backend, device, model) combination
  // and writes the restored images to disk for inspection.
  auto test = [](const string &device, const string &backend, const string &model_path,
                 const vector<string> &img_list) {
    mmdeploy_restorer_t restorer{nullptr};
    auto ret = mmdeploy_restorer_create_by_path(model_path.c_str(), device.c_str(), 0, &restorer);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    vector<cv::Mat> cv_mats;
    vector<mmdeploy_mat_t> mats;
    for (auto &img_path : img_list) {
      cv::Mat mat = cv::imread(img_path);
      REQUIRE(!mat.empty());
      // Keep the cv::Mat alive: mmdeploy_mat_t only borrows its pixel data.
      cv_mats.push_back(mat);
      mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MMDEPLOY_PIXEL_FORMAT_BGR,
                      MMDEPLOY_DATA_TYPE_UINT8});
    }
    mmdeploy_mat_t *res{};
    ret = mmdeploy_restorer_apply(restorer, mats.data(), (int)mats.size(), &res);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    for (auto i = 0; i < cv_mats.size(); ++i) {
      // Restorer output is RGB; convert back to BGR before writing.
      cv::Mat out(res[i].height, res[i].width, CV_8UC3, res[i].data);
      cv::cvtColor(out, out, cv::COLOR_RGB2BGR);
      cv::imwrite("restorer_" + backend + "_" + to_string(i) + ".bmp", out);
    }
    mmdeploy_restorer_release_result(res, (int)mats.size());
    mmdeploy_restorer_destroy(restorer);
  };

  // NOTE(review): this copies the resource registry by value, while other
  // tests in this suite bind `auto&` to the singleton — confirm Get()'s
  // return type makes the copy safe, or bind by reference for consistency.
  auto gResources = MMDeployTestResources::Get();
  auto img_lists = gResources.LocateImageResources(fs::path{"mmedit"} / "images");
  REQUIRE(!img_lists.empty());

  for (auto &backend : gResources.backends()) {
    DYNAMIC_SECTION("loop backend: " << backend) {
      auto model_list = gResources.LocateModelResources(fs::path{"mmedit"} / backend);
      REQUIRE(!model_list.empty());
      for (auto &model_path : model_list) {
        for (auto &device_name : gResources.device_names(backend)) {
          test(device_name, backend, model_path, img_lists);
        }
      }
    }
  }
}
// Copyright (c) OpenMMLab. All rights reserved.
// clang-format off
#include "catch.hpp"
// clang-format on
#include "mmdeploy/apis/c/mmdeploy/segmentor.h"
#include "opencv2/opencv.hpp"
#include "test_resource.h"
using namespace std;
TEST_CASE("test segmentor's c api", "[.segmentor][resource]") {
  // Runs semantic segmentation over every test image and writes each predicted
  // mask to disk, once per (backend, model, device) combination.
  auto test = [](const string &device, const string &backend, const string &model_path,
                 const vector<string> &img_list) {
    mmdeploy_segmentor_t segmentor{nullptr};
    auto ret = mmdeploy_segmentor_create_by_path(model_path.c_str(), device.c_str(), 0, &segmentor);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    // Keep the cv::Mat objects alive: mmdeploy_mat_t only borrows pixel data.
    vector<cv::Mat> cv_mats;
    vector<mmdeploy_mat_t> mats;
    for (auto &img_path : img_list) {
      cv::Mat mat = cv::imread(img_path);
      REQUIRE(!mat.empty());
      cv_mats.push_back(mat);
      mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MMDEPLOY_PIXEL_FORMAT_BGR,
                      MMDEPLOY_DATA_TYPE_UINT8});
    }
    mmdeploy_segmentation_t *results{nullptr};
    ret = mmdeploy_segmentor_apply(segmentor, mats.data(), (int)mats.size(), &results);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    REQUIRE(results != nullptr);
    auto result_ptr = results;
    // size_t index avoids a signed/unsigned comparison warning.
    for (size_t i = 0; i < mats.size(); ++i, ++result_ptr) {
      cv::Mat mask(result_ptr->height, result_ptr->width, CV_32SC1, result_ptr->mask);
      // Multiply class ids by 10 so the label values are visible when viewed.
      cv::imwrite("mask_" + backend + "_" + to_string(i) + ".png", mask * 10);
    }
    mmdeploy_segmentor_release_result(results, (int)mats.size());
    mmdeploy_segmentor_destroy(segmentor);
  };
  // Bind by reference (as the other test files do): avoid copying the
  // process-wide test-resource registry.
  auto &gResources = MMDeployTestResources::Get();
  auto img_lists = gResources.LocateImageResources(fs::path{"mmseg"} / "images");
  REQUIRE(!img_lists.empty());
  for (auto &backend : gResources.backends()) {
    DYNAMIC_SECTION("loop backend: " << backend) {
      auto model_list = gResources.LocateModelResources(fs::path{"mmseg"} / backend);
      REQUIRE(!model_list.empty());
      for (auto &model_path : model_list) {
        for (auto &device_name : gResources.device_names(backend)) {
          test(device_name, backend, model_path, img_lists);
        }
      }
    }
  }
}
// Copyright (c) OpenMMLab. All rights reserved.
// clang-format off
#include "catch.hpp"
// clang-format on
#include "mmdeploy/apis/c/mmdeploy/text_detector.h"
#include "mmdeploy/core/logger.h"
#include "opencv2/opencv.hpp"
#include "test_resource.h"
using namespace std;
TEST_CASE("test text detector's c api", "[.text-detector][resource]") {
  // Detects text boxes in every test image and logs each box's score and quad
  // coordinates, once per (backend, model, device) combination.
  auto test = [](const string& device, const string& model_path, const vector<string>& img_list) {
    mmdeploy_text_detector_t detector{nullptr};
    auto ret =
        mmdeploy_text_detector_create_by_path(model_path.c_str(), device.c_str(), 0, &detector);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    // Keep the cv::Mat objects alive: mmdeploy_mat_t only borrows pixel data.
    vector<cv::Mat> cv_mats;
    vector<mmdeploy_mat_t> mats;
    for (auto& img_path : img_list) {
      cv::Mat mat = cv::imread(img_path);
      REQUIRE(!mat.empty());
      cv_mats.push_back(mat);
      mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MMDEPLOY_PIXEL_FORMAT_BGR,
                      MMDEPLOY_DATA_TYPE_UINT8});
    }
    mmdeploy_text_detection_t* results{nullptr};
    int* result_count{nullptr};  // per-image detection counts, parallel to `mats`
    ret = mmdeploy_text_detector_apply(detector, mats.data(), (int)mats.size(), &results,
                                       &result_count);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    // `results` is flat: image i owns the next result_count[i] entries.
    auto result_ptr = results;
    for (size_t i = 0; i < mats.size(); ++i) {
      MMDEPLOY_INFO("the {}-th image has '{}' objects", i, result_count[i]);
      for (auto j = 0; j < result_count[i]; ++j, ++result_ptr) {
        MMDEPLOY_INFO(">> bbox[{}].score: {}, coordinate: ", i, result_ptr->score);
        for (auto& pt : result_ptr->bbox) {
          MMDEPLOY_INFO(">> >> ({}, {})", pt.x, pt.y);
        }
      }
    }
    mmdeploy_text_detector_release_result(results, result_count, (int)mats.size());
    mmdeploy_text_detector_destroy(detector);
  };
  auto& gResources = MMDeployTestResources::Get();
  auto img_list = gResources.LocateImageResources(fs::path{"mmocr"} / "images");
  REQUIRE(!img_list.empty());
  for (auto& backend : gResources.backends()) {
    DYNAMIC_SECTION("loop backend: " << backend) {
      // BUGFIX: was the string literal "backend", which never matches a real
      // backend directory, so no models could ever be located.
      auto model_list = gResources.LocateModelResources(fs::path{"mmocr"} / "textdet" / backend);
      REQUIRE(!model_list.empty());
      for (auto& model_path : model_list) {
        for (auto& device_name : gResources.device_names(backend)) {
          test(device_name, model_path, img_list);
        }
      }
    }
  }
}
// Copyright (c) OpenMMLab. All rights reserved.
// clang-format off
#include "catch.hpp"
// clang-format on
#include "mmdeploy/apis/c/mmdeploy/text_recognizer.h"
#include "mmdeploy/core/logger.h"
#include "mmdeploy/core/utils/formatter.h"
#include "opencv2/opencv.hpp"
#include "test_resource.h"
using namespace std;
TEST_CASE("test text recognizer's c api", "[.text-recognizer][resource]") {
  // Recognizes text on each full test image (no detector boxes supplied) and
  // logs the text with per-character scores, once per (backend, model, device).
  auto test = [](const string& device, const string& model_path, const vector<string>& img_list) {
    mmdeploy_text_recognizer_t recognizer{nullptr};
    auto ret =
        mmdeploy_text_recognizer_create_by_path(model_path.c_str(), device.c_str(), 0, &recognizer);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    // Keep the cv::Mat objects alive: mmdeploy_mat_t only borrows pixel data.
    vector<cv::Mat> cv_mats;
    vector<mmdeploy_mat_t> mats;
    for (auto& img_path : img_list) {
      cv::Mat mat = cv::imread(img_path);
      REQUIRE(!mat.empty());
      cv_mats.push_back(mat);
      mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MMDEPLOY_PIXEL_FORMAT_BGR,
                      MMDEPLOY_DATA_TYPE_UINT8});
    }
    mmdeploy_text_recognition_t* results{};
    // nullptr bboxes/counts: recognize each whole image as one text region.
    ret = mmdeploy_text_recognizer_apply_bbox(recognizer, mats.data(), (int)mats.size(), nullptr,
                                              nullptr, &results);
    REQUIRE(ret == MMDEPLOY_SUCCESS);
    for (size_t i = 0; i < mats.size(); ++i) {
      std::vector<float> score(results[i].score, results[i].score + results[i].length);
      MMDEPLOY_INFO("image {}, text = {}, score = {}", i, results[i].text, score);
    }
    mmdeploy_text_recognizer_release_result(results, (int)mats.size());
    mmdeploy_text_recognizer_destroy(recognizer);
  };
  auto& gResources = MMDeployTestResources::Get();
  auto img_list = gResources.LocateImageResources(fs::path{"mmocr"} / "images");
  REQUIRE(!img_list.empty());
  for (auto& backend : gResources.backends()) {
    DYNAMIC_SECTION("loop backend: " << backend) {
      // BUGFIX: was the string literal "backend", which never matches a real
      // backend directory, so no models could ever be located.
      auto model_list = gResources.LocateModelResources(fs::path{"mmocr"} / "textreg" / backend);
      REQUIRE(!model_list.empty());
      for (auto& model_path : model_list) {
        for (auto& device_name : gResources.device_names(backend)) {
          test(device_name, model_path, img_list);
        }
      }
    }
  }
}
TEST_CASE("test text detector-recognizer combo", "[.text-detector-recognizer]") {
  // End-to-end OCR: detect text boxes, then recognize the text inside each
  // detected box, once per (backend, device) using the first model of each kind.
  auto test = [](const std::string& device, const string& det_model_path,
                 const string& reg_model_path, std::vector<string>& img_list) {
    mmdeploy_text_detector_t detector{};
    REQUIRE(mmdeploy_text_detector_create_by_path(det_model_path.c_str(), device.c_str(), 0,
                                                  &detector) == MMDEPLOY_SUCCESS);
    mmdeploy_text_recognizer_t recognizer{};
    REQUIRE(mmdeploy_text_recognizer_create_by_path(reg_model_path.c_str(), device.c_str(), 0,
                                                    &recognizer) == MMDEPLOY_SUCCESS);
    // Keep the cv::Mat objects alive: mmdeploy_mat_t only borrows pixel data.
    vector<cv::Mat> cv_mats;
    vector<mmdeploy_mat_t> mats;
    for (const auto& img_path : img_list) {
      cv::Mat mat = cv::imread(img_path);
      REQUIRE(!mat.empty());
      cv_mats.push_back(mat);
      mats.push_back({mat.data, mat.rows, mat.cols, mat.channels(), MMDEPLOY_PIXEL_FORMAT_BGR,
                      MMDEPLOY_DATA_TYPE_UINT8});
    }
    mmdeploy_text_detection_t* bboxes{};
    int* bbox_count{};  // per-image box counts, parallel to `mats`
    REQUIRE(mmdeploy_text_detector_apply(detector, mats.data(), (int)mats.size(), &bboxes,
                                         &bbox_count) == MMDEPLOY_SUCCESS);
    mmdeploy_text_recognition_t* texts{};
    REQUIRE(mmdeploy_text_recognizer_apply_bbox(recognizer, mats.data(), (int)mats.size(), bboxes,
                                                bbox_count, &texts) == MMDEPLOY_SUCCESS);
    // `texts` is flat across images; `offset` tracks where image i's boxes start.
    int offset = 0;
    for (size_t i = 0; i < mats.size(); ++i) {
      for (int j = 0; j < bbox_count[i]; ++j) {
        auto& text = texts[offset + j];
        std::vector<float> score(text.score, text.score + text.length);
        MMDEPLOY_INFO("image {}, text = {}, score = {}", i, text.text, score);
      }
      offset += bbox_count[i];
    }
    // Recognizer results are released by total count; detector results by
    // IMAGE count (see the release call in the detector test above).
    mmdeploy_text_recognizer_release_result(texts, offset);
    // BUGFIX: was `offset` (total box count) — the API takes the image count.
    mmdeploy_text_detector_release_result(bboxes, bbox_count, (int)mats.size());
    mmdeploy_text_recognizer_destroy(recognizer);
    mmdeploy_text_detector_destroy(detector);
  };
  auto& gResources = MMDeployTestResources::Get();
  auto img_list = gResources.LocateImageResources(fs::path{"mmocr"} / "images");
  REQUIRE(!img_list.empty());
  for (auto& backend : gResources.backends()) {
    DYNAMIC_SECTION("loop backend: " << backend) {
      auto det_model_list =
          gResources.LocateModelResources(fs::path{"mmocr"} / "textdet" / backend);
      auto reg_model_list =
          gResources.LocateModelResources(fs::path{"mmocr"} / "textreg" / backend);
      REQUIRE(!det_model_list.empty());
      REQUIRE(!reg_model_list.empty());
      auto det_model_path = det_model_list.front();
      auto reg_model_path = reg_model_list.front();
      for (auto& device_name : gResources.device_names(backend)) {
        test(device_name, det_model_path, reg_model_path, img_list);
      }
    }
  }
}
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_TEST_DEFINE_H
#define MMDEPLOY_TEST_DEFINE_H
// Build-configuration constants for the test suite. The @...@ placeholders
// are presumably substituted at configure time (CMake configure_file-style)
// with the semicolon-separated backends, devices and codebases the build
// enables — TODO confirm against the CMakeLists that configures this header.
static constexpr const char *kBackends = "@MMDEPLOY_TARGET_BACKENDS@";
static constexpr const char *kDevices = "@MMDEPLOY_TARGET_DEVICES@";
static constexpr const char *kCodebases = "@CODEBASES@";
#endif  // MMDEPLOY_TEST_DEFINE_H
This diff is collapsed.
// Copyright (c) OpenMMLab. All rights reserved.
#include <array>
#include <iostream>
#include <numeric>
#include "catch.hpp"
#include "mmdeploy/core/logger.h"
#include "mmdeploy/core/mat.h"
#include "test_resource.h"
using namespace mmdeploy;
using namespace framework;
using namespace std;
TEST_CASE("default mat constructor", "[mat]") {
  // Bind by reference: avoid copying the process-wide test-resource registry.
  auto &gResource = MMDeployTestResources::Get();
  const Device kHost{"cpu"};
  SECTION("default constructor") {
    // A default Mat is empty: zero geometry, no data, invalid device.
    Mat mat;
    REQUIRE(mat.pixel_format() == PixelFormat::kGRAYSCALE);
    REQUIRE(mat.type() == DataType::kINT8);
    REQUIRE(mat.height() == 0);
    REQUIRE(mat.width() == 0);
    REQUIRE(mat.channel() == 0);
    REQUIRE(mat.size() == 0);
    REQUIRE(mat.byte_size() == 0);
    REQUIRE(mat.data<void>() == nullptr);
    REQUIRE(mat.device().platform_id() == -1);
  }
  SECTION("construct with device") {
    // BUGFIX: the array extents (7 and 5) exceeded the initializer counts
    // (6 and 4), so value-initialized trailing elements silently entered the
    // sweep and the success-count check was computed against inflated sizes.
    std::array<PixelFormat, 6> pixel_formats{PixelFormat::kBGR, PixelFormat::kRGB,
                                             PixelFormat::kGRAYSCALE, PixelFormat::kNV12,
                                             PixelFormat::kNV21, PixelFormat::kBGRA};
    std::array<DataType, 4> data_types{DataType::kFLOAT, DataType::kHALF, DataType::kINT8,
                                       DataType::kINT32};
    // Every (format, type) pair must produce a non-empty allocation.
    int success = 0;
    for (auto format : pixel_formats) {
      for (auto data_type : data_types) {
        Mat mat{100, 200, format, data_type, kHost};
        success += (mat.byte_size() > 0);
      }
    }
    REQUIRE(success == pixel_formats.size() * data_types.size());
    // Out-of-range enum values must be rejected on every available device.
    for (auto &device_name : gResource.device_names()) {
      Device device{device_name.c_str()};
      REQUIRE_THROWS(Mat{100, 200, PixelFormat(0xff), DataType::kINT8, device});
      REQUIRE_THROWS(Mat{100, 200, PixelFormat::kGRAYSCALE, DataType(0xff), device});
    }
  }
  SECTION("construct with data") {
    constexpr int kRows = 100;
    constexpr int kCols = 200;
    vector<uint8_t> data(kRows * kCols, 0);
    SECTION("void* data") {
      Mat mat{kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kINT8, data.data(), kHost};
      REQUIRE(mat.byte_size() > 0);
    }
    SECTION("shared_ptr") {
      // No-op deleter: `data` is owned by the enclosing vector.
      std::shared_ptr<void> data_ptr(data.data(), [&](void *p) {});
      Mat mat{kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kINT8, data_ptr, kHost};
      REQUIRE(mat.byte_size() > 0);
    }
  }
}
TEST_CASE("mat constructor in difference devices", "[mat]") {
  auto gResource = MMDeployTestResources::Get();
  constexpr int kRows = 10;
  constexpr int kCols = 10;
  constexpr int kSize = kRows * kCols;
  // Sequential payload 1..kSize so any element mismatch is easy to spot.
  vector<uint8_t> payload(kSize);
  std::iota(payload.begin(), payload.end(), 1);
  for (auto &device_name : gResource.device_names()) {
    Device device{device_name.c_str()};
    Mat device_mat{kRows, kCols, PixelFormat::kGRAYSCALE, DataType::kINT8, device};
    Stream stream = Stream::GetDefault(device);
    // Round-trip the payload: host -> device ...
    REQUIRE(stream.Copy(payload.data(), device_mat.buffer(), device_mat.buffer().GetSize()));
    REQUIRE(stream.Wait());
    // ... then device -> host into a fresh buffer.
    vector<uint8_t> readback(device_mat.size());
    REQUIRE(stream.Copy(device_mat.buffer(), readback.data(), device_mat.byte_size()));
    REQUIRE(stream.Wait());
    // Every element must survive the round trip unchanged.
    int matched = 0;
    for (size_t idx = 0; idx < readback.size(); ++idx) {
      matched += (readback[idx] == payload[idx]);
    }
    REQUIRE(matched == device_mat.size());
  }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment