"docs/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "4c5c5ad20c2023f292b37336bb7ce4a15c665cad"
Unverified commit 9befc398, authored by lml131, committed by GitHub

Lml/jit decorator (#673)

* add jit decorator

* add parrots_jit.py

* modify test_parrots_jit.py

* modify for lint

* fix isort

* skip test_parrots_jit.py when build without pytorch

* try ci

* rm log

* fix double quote

* modify for comments and use partial_shape instead of full_shape

* fix for lint

* small modify for parrots 0.9.0rc0

* define skip_no_elena directly
parent e30dc4f8
@@ -47,7 +47,7 @@ jobs:
         run: rm -rf .eggs && pip install -e .
       - name: Run unittests and generate coverage report
         run: |
-          pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py
+          pytest tests/ --ignore=tests/test_runner --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py
   build_without_ops:
     runs-on: ubuntu-latest
...
mmcv/utils/__init__.py
@@ -33,6 +33,7 @@ else:
         DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd,
         _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
         _ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
+    from .parrots_jit import jit, skip_no_elena, skip_no_parrots
     from .registry import Registry, build_from_cfg
     __all__ = [
         'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
@@ -48,5 +49,6 @@ else:
         '_InstanceNorm', '_MaxPoolNd', 'get_build_config', 'BuildExtension',
         'CppExtension', 'CUDAExtension', 'DataLoader', 'PoolDataLoader',
         'TORCH_VERSION', 'deprecated_api_warning', 'digit_version',
-        'get_git_hash', 'import_modules_from_strings'
+        'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
+        'skip_no_parrots'
     ]
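With these exports in place, the decorator and the pytest helpers become reachable from the top-level package. A minimal sketch of the intended call site, assuming an mmcv build that already contains this change:

import torch
import mmcv

@mmcv.jit  # compiles under parrots; a plain Python call on stock PyTorch
def squared_error(a, b):
    return torch.mean((a - b) * (a - b))

print(squared_error(torch.rand(3, 4), torch.rand(3, 4)))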
parrots_jit.py
import pytest
import torch

TORCH_VERSION = torch.__version__

if TORCH_VERSION == 'parrots':
    from parrots.jit import pat as jit
else:

    def jit(func=None,
            check_input=None,
            full_shape=True,
            derivate=False,
            coderize=False,
            optimize=False):
        # Fallback for plain PyTorch: act as a no-op decorator so code written
        # with @mmcv.jit still runs when parrots is not available.

        def wrapper(func):

            def wrapper_inner(*args, **kargs):
                return func(*args, **kargs)

            return wrapper_inner

        if func is None:
            # Called as @jit(...) with arguments: return the decorator.
            return wrapper
        else:
            # Called as @jit directly on a function: return it unchanged.
            return func


if TORCH_VERSION == 'parrots':
    from parrots.utils.tester import skip_no_elena
else:

    def skip_no_elena(func):
        # Without parrots there is no elena backend to skip on; just forward.

        def wrapper(*args, **kargs):
            return func(*args, **kargs)

        return wrapper


def is_using_parrots():
    return TORCH_VERSION == 'parrots'


skip_no_parrots = pytest.mark.skipif(
    not is_using_parrots(), reason='test case under parrots environment')
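On stock PyTorch the fallbacks above are deliberate no-ops: the bare form hands the function back unchanged, and the parameterised form returns a wrapper that only forwards its arguments. A small sketch of that behaviour, assuming a non-parrots torch and an mmcv build that includes this module:

import torch
import mmcv

def square(x):
    return x * x

# Bare form: jit(func) returns the function itself.
assert mmcv.jit(square) is square

# Parameterised form: jit(optimize=True) returns the `wrapper` decorator,
# whose inner function simply forwards *args/**kwargs.
wrapped = mmcv.jit(optimize=True)(square)
x = torch.rand(3, 4)
assert torch.equal(wrapped(x), square(x))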
tests/test_utils/test_parrots_jit.py
import pytest
import torch

import mmcv


class TestJit(object):

    def test_add_dict(self):

        @mmcv.jit
        def add_dict(oper):
            rets = oper['x'] + oper['y']
            return {'result': rets}

        def add_dict_pyfunc(oper):
            rets = oper['x'] + oper['y']
            return {'result': rets}

        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        oper = {'x': a, 'y': b}

        rets_t = add_dict(oper)
        rets = add_dict_pyfunc(oper)
        assert 'result' in rets
        assert (rets_t['result'] == rets['result']).all()

    def test_add_list(self):

        @mmcv.jit
        def add_list(oper, x, y):
            rets = {}
            for idx, pair in enumerate(oper):
                rets[f'k{idx}'] = pair['x'] + pair['y']
            rets[f'k{len(oper)}'] = x + y
            return rets

        def add_list_pyfunc(oper, x, y):
            rets = {}
            for idx, pair in enumerate(oper):
                rets[f'k{idx}'] = pair['x'] + pair['y']
            rets[f'k{len(oper)}'] = x + y
            return rets

        pair_num = 3
        oper = []
        for _ in range(pair_num):
            oper.append({'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))})
        a = torch.rand((3, 4))
        b = torch.rand((3, 4))

        rets = add_list_pyfunc(oper, x=a, y=b)
        rets_t = add_list(oper, x=a, y=b)
        for idx in range(pair_num + 1):
            assert f'k{idx}' in rets_t
            assert (rets[f'k{idx}'] == rets_t[f'k{idx}']).all()

    @mmcv.skip_no_parrots
    def test_jit_cache(self):

        @mmcv.jit
        def func(oper):
            if oper['const'] > 1:
                return oper['x'] * 2 + oper['y']
            else:
                return oper['x'] * 2 - oper['y']

        def pyfunc(oper):
            if oper['const'] > 1:
                return oper['x'] * 2 + oper['y']
            else:
                return oper['x'] * 2 - oper['y']

        assert len(func._cache._cache) == 0

        oper = {'const': 2, 'x': torch.rand((3, 4)), 'y': torch.rand((3, 4))}

        rets_plus = pyfunc(oper)
        rets_plus_t = func(oper)
        assert (rets_plus == rets_plus_t).all()
        assert len(func._cache._cache) == 1

        oper['const'] = 0.5
        rets_minus = pyfunc(oper)
        rets_minus_t = func(oper)
        assert (rets_minus == rets_minus_t).all()
        assert len(func._cache._cache) == 2

        # (2x + y) + (2x - y) == 4x, so dividing by 4 should recover oper['x']
        rets_a = (rets_minus_t + rets_plus_t) / 4
        assert torch.allclose(oper['x'], rets_a)

    @mmcv.skip_no_parrots
    def test_jit_shape(self):

        @mmcv.jit
        def func(a):
            return a + 1

        assert len(func._cache._cache) == 0

        a = torch.ones((3, 4))
        r = func(a)
        assert r.shape == (3, 4)
        assert (r == 2).all()
        assert len(func._cache._cache) == 1

        a = torch.ones((2, 3, 4))
        r = func(a)
        assert r.shape == (2, 3, 4)
        assert (r == 2).all()
        assert len(func._cache._cache) == 2

    @mmcv.skip_no_parrots
    def test_jit_kwargs(self):

        @mmcv.jit
        def func(a, b):
            return torch.mean((a - b) * (a - b))

        assert len(func._cache._cache) == 0
        x = torch.rand((16, 32))
        y = torch.rand((16, 32))
        func(x, y)
        assert len(func._cache._cache) == 1
        func(x, b=y)
        assert len(func._cache._cache) == 1
        func(b=y, a=x)
        assert len(func._cache._cache) == 1

    def test_jit_derivate(self):

        @mmcv.jit(derivate=True)
        def func(x, y):
            return (x + 2) * (y - 2)

        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        a.requires_grad = True

        c = func(a, b)
        assert c.requires_grad
        d = torch.empty_like(c)
        d.fill_(1.0)
        c.backward(d)
        assert torch.allclose(a.grad, (b - 2))
        assert b.grad is None

        a.grad = None
        c = func(a, b)
        assert c.requires_grad
        d = torch.empty_like(c)
        d.fill_(2.7)
        c.backward(d)
        assert torch.allclose(a.grad, 2.7 * (b - 2))
        assert b.grad is None

    def test_jit_optimize(self):

        @mmcv.jit(optimize=True)
        def func(a, b):
            return torch.mean((a - b) * (a - b))

        def pyfunc(a, b):
            return torch.mean((a - b) * (a - b))

        a = torch.rand((16, 32))
        b = torch.rand((16, 32))

        c = func(a, b)
        d = pyfunc(a, b)
        assert torch.allclose(c, d)

    @mmcv.skip_no_elena
    def test_jit_coderize(self):
        if not torch.cuda.is_available():
            return

        @mmcv.jit(coderize=True)
        def func(a, b):
            return (a + b) * (a - b)

        def pyfunc(a, b):
            return (a + b) * (a - b)

        a = torch.rand((16, 32), device='cuda')
        b = torch.rand((16, 32), device='cuda')

        c = func(a, b)
        d = pyfunc(a, b)
        assert torch.allclose(c, d)

    def test_jit_value_dependent(self):

        @mmcv.jit
        def func(a, b):
            torch.nonzero(a)
            return torch.mean((a - b) * (a - b))

        def pyfunc(a, b):
            torch.nonzero(a)
            return torch.mean((a - b) * (a - b))

        a = torch.rand((16, 32))
        b = torch.rand((16, 32))

        c = func(a, b)
        d = pyfunc(a, b)
        assert torch.allclose(c, d)

    @mmcv.skip_no_parrots
    def test_jit_check_input(self):

        def func(x):
            y = torch.rand_like(x)
            return x + y

        a = torch.ones((3, 4))
        with pytest.raises(AssertionError):
            func = mmcv.jit(func, check_input=(a, ))

    @mmcv.skip_no_parrots
    def test_jit_partial_shape(self):

        @mmcv.jit(full_shape=False)
        def func(a, b):
            return torch.mean((a - b) * (a - b))

        def pyfunc(a, b):
            return torch.mean((a - b) * (a - b))

        a = torch.rand((3, 4))
        b = torch.rand((3, 4))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 1

        a = torch.rand((6, 5))
        b = torch.rand((6, 5))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 1

        a = torch.rand((3, 4, 5))
        b = torch.rand((3, 4, 5))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 2

        a = torch.rand((1, 9, 8))
        b = torch.rand((1, 9, 8))
        assert torch.allclose(func(a, b), pyfunc(a, b))
        assert len(func._cache._cache) == 2

    def test_instance_method(self):

        class T(object):

            def __init__(self, shape):
                self._c = torch.rand(shape)

            @mmcv.jit
            def test_method(self, x, y):
                return (x * self._c) + y

        shape = (16, 32)
        t = T(shape)
        a = torch.rand(shape)
        b = torch.rand(shape)
        res = (a * t._c) + b
        jit_res = t.test_method(a, b)
        assert torch.allclose(res, jit_res)

        t = T(shape)
        res = (a * t._c) + b
        jit_res = t.test_method(a, b)
        assert torch.allclose(res, jit_res)