Unverified Commit 8547b21c authored by Yuge Zhang, committed by GitHub

Merge pull request #4760 from microsoft/dev-oneshot

[DO NOT SQUASH] One-shot as strategy
parents 58d205d3 2355bacb
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import BaseStrategy
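
# One-shot strategies depend on PyTorch and PyTorch Lightning. If they cannot
# be imported, fall back to placeholders that re-raise the original
# ImportError only when the strategy is actually run.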
try:
from nni.retiarii.oneshot.pytorch.strategy import ( # pylint: disable=unused-import
DARTS, GumbelDARTS, Proxyless, ENAS, RandomOneShot
)
except ImportError as import_err:
_import_err = import_err
class ImportFailedStrategy(BaseStrategy):
def run(self, base_model, applied_mutators):
raise _import_err
    # otherwise the typing checker will point to the wrong location
globals()['DARTS'] = ImportFailedStrategy
globals()['GumbelDARTS'] = ImportFailedStrategy
globals()['Proxyless'] = ImportFailedStrategy
globals()['ENAS'] = ImportFailedStrategy
globals()['RandomOneShot'] = ImportFailedStrategy
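
# Typical usage (a sketch; assumes torch and pytorch-lightning are installed):
#
#     from nni.retiarii import strategy
#     strat = strategy.DARTS()
#
# If the import above failed, calling run() on the placeholder re-raises the
# original ImportError.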
import argparse

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import pytest
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, RandomSampler

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice

class DepthwiseSeparableConv(nn.Module):
@@ -26,119 +24,262 @@ class DepthwiseSeparableConv(nn.Module):
return self.pointwise(self.depthwise(x))
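
# SimpleNet exercises the basic mutation primitives: LayerChoice picks one of
# several candidate modules, InputChoice selects among candidate tensors, and
# ValueChoice (enabled by default here) searches over a hyperparameter value.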
@model_wrapper
class SimpleNet(nn.Module):
    def __init__(self, value_choice=True):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = LayerChoice([
            nn.Conv2d(32, 64, 3, 1),
            DepthwiseSeparableConv(32, 64)
        ])
        self.dropout1 = LayerChoice([
            nn.Dropout(.25),
            nn.Dropout(.5),
            nn.Dropout(.75)
        ])
        self.dropout2 = nn.Dropout(0.5)
        if value_choice:
            hidden = nn.ValueChoice([32, 64, 128])
        else:
            hidden = 64
        self.fc1 = nn.Linear(9216, hidden)
        self.fc2 = nn.Linear(hidden, 10)
        self.rpfc = nn.Linear(10, 10)
        self.input_ch = InputChoice(2, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(self.conv2(x), 2)
        x = torch.flatten(self.dropout1(x), 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        x1 = self.rpfc(x)
        x = self.input_ch([x, x1])
        output = F.log_softmax(x, dim=1)
        return output

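# The same ValueChoice instance can be shared across layers: embed_dim below
# parameterizes linear1, the attention module, and linear2 simultaneously.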
@model_wrapper
class MultiHeadAttentionNet(nn.Module):
def __init__(self, head_count):
super().__init__()
embed_dim = ValueChoice(candidates=[32, 64])
self.linear1 = nn.Linear(128, embed_dim)
self.mhatt = nn.MultiheadAttention(embed_dim, head_count)
self.linear2 = nn.Linear(embed_dim, 1)
def forward(self, batch):
query, key, value = batch
q, k, v = self.linear1(query), self.linear1(key), self.linear1(value)
output, _ = self.mhatt(q, k, v, need_weights=False)
y = self.linear2(output)
return F.relu(y)
@model_wrapper
class ValueChoiceConvNet(nn.Module):
def __init__(self):
super().__init__()
ch1 = ValueChoice([16, 32])
kernel = ValueChoice([3, 5])
self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
self.batch_norm = nn.BatchNorm2d(ch1)
self.conv2 = nn.Conv2d(ch1, 64, 3)
self.dropout1 = LayerChoice([
nn.Dropout(.25),
nn.Dropout(.5),
nn.Dropout(.75)
])
self.fc = nn.Linear(64, 10)
def forward(self, x):
x = self.conv1(x)
x = self.batch_norm(x)
x = F.relu(x)
x = F.max_pool2d(self.conv2(x), 2)
x = torch.mean(x, (2, 3))
x = self.fc(x)
return F.log_softmax(x, dim=1)
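
# nn.Repeat stacks a block a variable number of times; the (1, 4) below means
# the repeat depth itself is searched between 1 and 4.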
@model_wrapper
class RepeatNet(nn.Module):
def __init__(self):
super().__init__()
ch1 = ValueChoice([16, 32])
kernel = ValueChoice([3, 5])
self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
self.batch_norm = nn.BatchNorm2d(ch1)
self.conv2 = nn.Conv2d(ch1, 64, 3, padding=1)
self.dropout1 = LayerChoice([
nn.Dropout(.25),
nn.Dropout(.5),
nn.Dropout(.75)
])
self.fc = nn.Linear(64, 10)
self.rpfc = nn.Repeat(nn.Linear(10, 10), (1, 4))
def forward(self, x):
x = self.conv1(x)
x = self.batch_norm(x)
x = F.relu(x)
x = F.max_pool2d(self.conv2(x), 2)
x = torch.mean(x, (2, 3))
x = self.fc(x)
x = self.rpfc(x)
return F.log_softmax(x, dim=1)
@basic_unit
class MyOp(nn.Module):
def __init__(self, some_ch):
super().__init__()
self.some_ch = some_ch
self.batch_norm = nn.BatchNorm2d(some_ch)
def forward(self, x):
return self.batch_norm(x)
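
# @basic_unit treats MyOp as an opaque primitive, so one-shot strategies
# cannot mutate what is inside it; this is why the 'custom_op' model is
# expected to be unsupported in the test matrix below.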
@model_wrapper
class CustomOpValueChoiceNet(nn.Module):
def __init__(self):
super().__init__()
ch1 = ValueChoice([16, 32])
kernel = ValueChoice([3, 5])
self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
self.batch_norm = MyOp(ch1)
self.conv2 = nn.Conv2d(ch1, 64, 3, padding=1)
self.dropout1 = LayerChoice([
nn.Dropout(.25),
nn.Dropout(.5),
nn.Dropout(.75)
])
self.fc = nn.Linear(64, 10)
def forward(self, x):
x = self.conv1(x)
x = self.batch_norm(x)
x = F.relu(x)
x = F.max_pool2d(self.conv2(x), 2)
x = torch.mean(x, (2, 3))
x = self.fc(x)
return F.log_softmax(x, dim=1)
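
# Build a (model, evaluator) pair; each MNIST variant trains for one epoch on
# a small random subsample to keep the tests fast.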
def _mnist_net(type_):
if type_ == 'simple':
base_model = SimpleNet(False)
elif type_ == 'simple_value_choice':
base_model = SimpleNet()
elif type_ == 'value_choice':
base_model = ValueChoiceConvNet()
elif type_ == 'repeat':
base_model = RepeatNet()
elif type_ == 'custom_op':
base_model = CustomOpValueChoiceNet()
else:
raise ValueError(f'Unsupported type: {type_}')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)
return base_model, evaluator
def _multihead_attention_net():
base_model = MultiHeadAttentionNet(1)
class AttentionRandDataset(Dataset):
def __init__(self, data_shape, gt_shape, len) -> None:
super().__init__()
self.datashape = data_shape
self.gtshape = gt_shape
self.len = len
def __getitem__(self, index):
q = torch.rand(self.datashape)
k = torch.rand(self.datashape)
v = torch.rand(self.datashape)
gt = torch.rand(self.gtshape)
return (q, k, v), gt
def __len__(self):
return self.len
train_set = AttentionRandDataset((1, 128), (1, 1), 1000)
val_set = AttentionRandDataset((1, 128), (1, 1), 500)
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, max_epochs=1)
return base_model, evaluator
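
# Run one strategy against every model variant; combinations the strategy
# cannot handle are expected to fail fast with a TypeError.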
def _test_strategy(strategy_, support_value_choice=True):
to_test = [
        # (model, evaluator), support_or_not
(_mnist_net('simple'), True),
(_mnist_net('simple_value_choice'), support_value_choice),
(_mnist_net('value_choice'), support_value_choice),
(_mnist_net('repeat'), False), # no strategy supports repeat currently
(_mnist_net('custom_op'), False), # this is definitely a NO
(_multihead_attention_net(), support_value_choice),
]
for (base_model, evaluator), support_or_not in to_test:
print('Testing:', type(strategy_).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy_)
        config = RetiariiExeConfig()
        config.execution_engine = 'oneshot'
if support_or_not:
experiment.run(config)
assert isinstance(experiment.export_top_models()[0], dict)
else:
with pytest.raises(TypeError, match='not supported'):
experiment.run(config)

@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_darts():
    _test_strategy(strategy.DARTS())

@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_proxyless():
    _test_strategy(strategy.Proxyless(), False)

@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_enas():
    _test_strategy(strategy.ENAS())

@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_random():
    _test_strategy(strategy.RandomOneShot())

@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_gumbel_darts():
_test_strategy(strategy.GumbelDARTS())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--exp', type=str, default='all', metavar='E',
                        help='experiment to run, default = all')
args = parser.parse_args()
if args.exp == 'all':
@@ -146,6 +287,6 @@ if __name__ == '__main__':
test_proxyless()
test_enas()
test_random()
        test_gumbel_darts()
else:
globals()[f'test_{args.exp}']()
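
# To run a single experiment from the command line (assuming this file is
# saved as test_oneshot.py), e.g. only DARTS:
#
#     python test_oneshot.py --exp darts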
import pytest
import numpy as np
import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax
)
from nni.retiarii.oneshot.pytorch.supermodule.sampling import (
MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput
)
from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedConv2d, NATIVE_MIXED_OPERATIONS
from nni.retiarii.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput
from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W
from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import *
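
# S (Slicable) and W (MaybeWeighted) allow slice bounds to be symbolic:
# an expression like W(3) * 2 + 1 is resolved lazily, and a dict such as
# W({1: 0.5, 3: 0.5}) yields a weighted mixture over the candidate widths.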
def test_slice():
weight = np.ones((3, 7, 24, 23))
assert S(weight)[:, 1:3, :, 9:13].shape == (3, 2, 24, 4)
assert S(weight)[:, 1:W(3)*2+1, :, 9:13].shape == (3, 6, 24, 4)
assert S(weight)[:, 1:W(3)*2+1].shape == (3, 6, 24, 23)
# no effect
assert S(weight)[:] is weight
# list
assert S(weight)[[slice(1), slice(2, 3)]].shape == (2, 7, 24, 23)
assert S(weight)[[slice(1), slice(2, W(2) + 1)], W(2):].shape == (2, 5, 24, 23)
# weighted
weight = S(weight)[:W({1: 0.5, 2: 0.3, 3: 0.2})]
weight = weight[:, 0, 0, 0]
assert weight[0] == 1 and weight[1] == 0.5 and weight[2] == 0.2
weight = np.ones((3, 6, 6))
value = W({1: 0.5, 3: 0.5})
weight = S(weight)[:, 3 - value:3 + value, 3 - value:3 + value]
for i in range(0, 6):
for j in range(0, 6):
if 2 <= i <= 3 and 2 <= j <= 3:
assert weight[0, i, j] == 1
else:
assert weight[1, i, j] == 0.5
# weighted + list
value = W({1: 0.5, 3: 0.5})
weight = np.ones((8, 4))
weight = S(weight)[[slice(value), slice(4, value + 4)]]
assert weight.sum(1).tolist() == [4, 2, 2, 0, 4, 2, 2, 0]
with pytest.raises(ValueError, match='one distinct'):
        # has to be exactly the same instance; equality is not enough
weight = S(weight)[:W({1: 0.5}), : W({1: 0.5})]
def test_valuechoice_utils():
chosen = {"exp": 3, "add": 1}
vc0 = ValueChoice([3, 4, 6], label='exp') * 2 + ValueChoice([0, 1], label='add')
assert evaluate_value_choice_with_dict(vc0, chosen) == 7
vc = vc0 + ValueChoice([3, 4, 6], label='exp')
assert evaluate_value_choice_with_dict(vc, chosen) == 10
assert list(dedup_inner_choices([vc0, vc]).keys()) == ['exp', 'add']
assert traverse_all_options(vc) == [9, 10, 12, 13, 18, 19]
weights = dict(traverse_all_options(vc, weights={'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}))
ans = dict([(9, 0.2), (10, 0.3), (12, 0.12), (13, 0.18), (18, 0.08), (19, 0.12)])
assert len(weights) == len(ans)
for value, weight in ans.items():
assert abs(weight - weights[value]) < 1e-6
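
# MixedConv2d.mutate wraps a Conv2d whose constructor arguments are
# ValueChoices; under the path-sampling policy, resample() fixes each choice
# from a memo keyed by the choice label.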
def test_pathsampling_valuechoice():
orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='123'), kernel_size=3)
conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
conv.resample(memo={'123': 5})
assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 5
conv.resample(memo={'123': 7})
assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 7
assert conv.export({})['123'] in [3, 5, 7]
def test_differentiable_valuechoice():
orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='456'), kernel_size=ValueChoice(
[3, 5, 7], label='123'), padding=ValueChoice([3, 5, 7], label='123') // 2)
conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
assert conv(torch.zeros((1, 3, 7, 7))).size(2) == 7
assert set(conv.export({}).keys()) == {'123', '456'}
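
# Helpers shared by the tests below: look up the mixed-operation counterpart
# registered for the native op type, mutate it, and run one forward pass.
# The sampling policy needs resample() with a memo of chosen values; the
# differentiable policy mixes all candidates in a single pass.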
def _mixed_operation_sampling_sanity_check(operation, memo, *input):
for native_op in NATIVE_MIXED_OPERATIONS:
if native_op.bound_type == type(operation):
mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
break
mutate_op.resample(memo=memo)
return mutate_op(*input)
def _mixed_operation_differentiable_sanity_check(operation, *input):
for native_op in NATIVE_MIXED_OPERATIONS:
if native_op.bound_type == type(operation):
mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
break
return mutate_op(*input)
def test_mixed_linear():
linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]))
_mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
_mixed_operation_sampling_sanity_check(linear, {'shared': 9}, torch.randn(2, 9))
_mixed_operation_differentiable_sanity_check(linear, torch.randn(2, 9))
linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=False)
_mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
with pytest.raises(TypeError):
linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=ValueChoice([False, True]))
_mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
def test_mixed_conv2d():
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out') * 2, 1)
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'out': 4}, torch.randn(2, 3, 9, 9)).size(1) == 8
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
# stride
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out'), 1, stride=ValueChoice([1, 2], label='stride'))
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 2}, torch.randn(2, 3, 10, 10)).size(2) == 5
assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 1}, torch.randn(2, 3, 10, 10)).size(2) == 10
# groups, dw conv
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])
# make sure kernel is sliced correctly
conv = Conv2d(1, 1, ValueChoice([1, 3], label='k'), bias=False)
conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
with torch.no_grad():
conv.weight.zero_()
# only center is 1, must pick center to pass this test
conv.weight[0, 0, 1, 1] = 1
conv.resample({'k': 1})
assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9
def test_mixed_batchnorm2d():
bn = BatchNorm2d(ValueChoice([32, 64], label='dim'))
assert _mixed_operation_sampling_sanity_check(bn, {'dim': 32}, torch.randn(2, 32, 3, 3)).size(1) == 32
assert _mixed_operation_sampling_sanity_check(bn, {'dim': 64}, torch.randn(2, 64, 3, 3)).size(1) == 64
_mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))
def test_mixed_mhattn():
mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4},
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8},
torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))[0].size(-1) == 8
_mixed_operation_differentiable_sanity_check(mhattn, torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))
mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), ValueChoice([2, 3, 4], label='heads'))
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 2},
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
with pytest.raises(AssertionError, match='divisible'):
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 3},
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4, kdim=ValueChoice([5, 7], label='kdim'))
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7},
torch.randn(7, 2, 4), torch.randn(7, 2, 7), torch.randn(7, 2, 4))[0].size(-1) == 4
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 5},
torch.randn(7, 2, 8), torch.randn(7, 2, 5), torch.randn(7, 2, 8))[0].size(-1) == 8
mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4, vdim=ValueChoice([5, 8], label='vdim'))
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'vdim': 8},
torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 8))[0].size(-1) == 4
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'vdim': 5},
torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 5))[0].size(-1) == 8
_mixed_operation_differentiable_sanity_check(mhattn, torch.randn(5, 3, 8), torch.randn(5, 3, 8), torch.randn(5, 3, 8))
@pytest.mark.skipif(torch.__version__.startswith('1.7'), reason='batch_first is not supported for legacy PyTorch')
def test_mixed_mhattn_batch_first():
# batch_first is not supported for legacy pytorch versions
# mark 1.7 because 1.7 is used on legacy pipeline
mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 2, kdim=(ValueChoice([3, 7], label='kdim')), vdim=ValueChoice([5, 8], label='vdim'),
bias=False, add_bias_kv=True, batch_first=True)
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7, 'vdim': 8},
torch.randn(2, 7, 4), torch.randn(2, 7, 7), torch.randn(2, 7, 8))[0].size(-1) == 4
assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 3, 'vdim': 5},
torch.randn(2, 7, 8), torch.randn(2, 7, 3), torch.randn(2, 7, 5))[0].size(-1) == 8
_mixed_operation_differentiable_sanity_check(mhattn, torch.randn(1, 7, 8), torch.randn(1, 7, 7), torch.randn(1, 7, 8))
def test_pathsampling_layer_input():
op = PathSamplingLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], label='ccc')
with pytest.raises(RuntimeError, match='sample'):
op(torch.randn(4, 2))
op.resample({})
assert op(torch.randn(4, 2)).size(-1) == 3
assert op.search_space_spec()['ccc'].values == ['a', 'b']
assert op.export({})['ccc'] in ['a', 'b']
input = PathSamplingInput(5, 2, 'concat', 'ddd')
sample = input.resample({})
assert 'ddd' in sample
assert len(sample['ddd']) == 2
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 4
assert len(input.export({})['ddd']) == 2
def test_differentiable_layer_input():
op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
assert op(torch.randn(4, 2)).size(-1) == 3
assert op.export({})['eee'] in ['a', 'b']
assert len(list(op.parameters())) == 3
input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
assert len(input.export({})['ddd']) == 2
def test_proxyless_layer_input():
op = ProxylessMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
assert op.resample({})['eee'] in ['a', 'b']
assert op(torch.randn(4, 2)).size(-1) == 3
assert op.export({})['eee'] in ['a', 'b']
assert len(list(op.parameters())) == 3
input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input.resample({})['ddd'] in list(range(5))
assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
assert input.export({})['ddd'] in list(range(5))