Commit b210695f authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Merge branch 'master' into dev-pruner-dataparallel

parents c7d58033 fdfff50d
......@@ -23,10 +23,13 @@ trial:
memoryMB: 8196
#The docker image to run nni job on pai
image: msranni/nni:latest
nniManagerNFSMountPath: /home/user/mnt
containerNFSMountPath: /mnt/data/user
paiStoragePlugin: team_wise
paiConfig:
#The username to login pai
userName: username
#The password to login pai
passWord: password
#The token to login pai
token: token
#The host of restful server of pai
host: 10.10.10.10
\ No newline at end of file
authorName: default
experimentName: example_sklearn
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 100
#choice: local, remote, pai
trainingServicePlatform: paiYarn
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner,MetisTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName: TPE
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
trial:
command: python3 main.py
codeDir: .
gpuNum: 0
cpuNum: 1
memoryMB: 8196
#The docker image to run nni job on pai
image: msranni/nni:latest
paiYarnConfig:
#The username to login pai
userName: username
#The password to login pai
passWord: password
#The host of restful server of pai
host: 10.10.10.10
\ No newline at end of file
......@@ -23,10 +23,13 @@ trial:
memoryMB: 8196
#The docker image to run nni job on pai
image: msranni/nni:latest
nniManagerNFSMountPath: /home/user/mnt
containerNFSMountPath: /mnt/data/user
paiStoragePlugin: team_wise
paiConfig:
#The username to login pai
userName: username
#The password to login pai
passWord: password
#The token to login pai
token: token
#The host of restful server of pai
host: 10.10.10.10
\ No newline at end of file
authorName: default
experimentName: example_sklearn
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 100
#choice: local, remote, pai
trainingServicePlatform: paiYarn
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName: TPE
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
trial:
command: python3 main.py
codeDir: .
gpuNum: 0
cpuNum: 1
memoryMB: 8196
#The docker image to run nni job on pai
image: msranni/nni:latest
paiYarnConfig:
#The username to login pai
userName: username
#The password to login pai
passWord: password
#The host of restful server of pai
host: 10.10.10.10
\ No newline at end of file
......@@ -4,13 +4,11 @@
'use strict';
import * as fs from 'fs';
import * as path from 'path';
import { Writable } from 'stream';
import { WritableStreamBuffer } from 'stream-buffers';
import { format } from 'util';
import * as component from '../common/component';
import { getExperimentStartupInfo, isReadonly } from './experimentStartupInfo';
import { getLogDir } from './utils';
const FATAL: number = 1;
const ERROR: number = 2;
......@@ -55,23 +53,21 @@ class BufferSerialEmitter {
@component.Singleton
class Logger {
private DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
private level: number = INFO;
private bufferSerialEmitter: BufferSerialEmitter;
private writable: Writable;
private bufferSerialEmitter?: BufferSerialEmitter;
private writable?: Writable;
private readonly: boolean = false;
constructor(fileName?: string) {
let logFile: string | undefined = fileName;
if (logFile === undefined) {
logFile = this.DEFAULT_LOGFILE;
const logFile: string | undefined = fileName;
if (logFile) {
this.writable = fs.createWriteStream(logFile, {
flags: 'a+',
encoding: 'utf8',
autoClose: true
});
this.bufferSerialEmitter = new BufferSerialEmitter(this.writable);
}
this.writable = fs.createWriteStream(logFile, {
flags: 'a+',
encoding: 'utf8',
autoClose: true
});
this.bufferSerialEmitter = new BufferSerialEmitter(this.writable);
const logLevelName: string = getExperimentStartupInfo()
.getLogLevel();
......@@ -84,7 +80,9 @@ class Logger {
}
public close(): void {
this.writable.destroy();
if (this.writable) {
this.writable.destroy();
}
}
public trace(...param: any[]): void {
......@@ -128,12 +126,15 @@ class Logger {
*/
private log(level: string, param: any[]): void {
if (!this.readonly) {
const buffer: WritableStreamBuffer = new WritableStreamBuffer();
buffer.write(`[${(new Date()).toLocaleString()}] ${level} `);
buffer.write(format(param));
buffer.write('\n');
buffer.end();
this.bufferSerialEmitter.feed(buffer.getContents());
const logContent = `[${(new Date()).toLocaleString()}] ${level} ${format(param)}\n`;
if (this.writable && this.bufferSerialEmitter) {
const buffer: WritableStreamBuffer = new WritableStreamBuffer();
buffer.write(logContent);
buffer.end();
this.bufferSerialEmitter.feed(buffer.getContents());
} else {
console.log(logContent);
}
}
}
}
......
......@@ -6,6 +6,7 @@
import { Container, Scope } from 'typescript-ioc';
import * as fs from 'fs';
import * as path from 'path';
import * as component from './common/component';
import { Database, DataStore } from './common/datastore';
import { setExperimentStartupInfo } from './common/experimentStartupInfo';
......@@ -34,7 +35,7 @@ function initStartupInfo(
setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly);
}
async function initContainer(platformMode: string, logFileName?: string): Promise<void> {
async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
if (platformMode === 'local') {
Container.bind(TrainingService)
.to(LocalTrainingService)
......@@ -71,6 +72,12 @@ async function initContainer(platformMode: string, logFileName?: string): Promis
Container.bind(DataStore)
.to(NNIDataStore)
.scope(Scope.Singleton);
const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
if (foreground) {
logFileName = undefined;
} else if (logFileName === undefined) {
logFileName = DEFAULT_LOGFILE;
}
Container.bind(Logger).provider({
get: (): Logger => new Logger(logFileName)
});
......@@ -81,7 +88,7 @@ async function initContainer(platformMode: string, logFileName?: string): Promis
function usage(): void {
console.info('usage: node main.js --port <port> --mode \
<local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id>');
<local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
}
const strPort: string = parseArg(['--port', '-p']);
......@@ -90,6 +97,14 @@ if (!strPort || strPort.length === 0) {
process.exit(1);
}
const foregroundArg: string = parseArg(['--foreground', '-f']);
if (!('true' || 'false').includes(foregroundArg.toLowerCase())) {
console.log(`FATAL: foreground property should only be true or false`);
usage();
process.exit(1);
}
const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : false;
const port: number = parseInt(strPort, 10);
const mode: string = parseArg(['--mode', '-m']);
......@@ -138,7 +153,7 @@ initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly);
mkDirP(getLogDir())
.then(async () => {
try {
await initContainer(mode);
await initContainer(foreground, mode);
const restServer: NNIRestServer = component.get(NNIRestServer);
await restServer.start();
const log: Logger = getLogger();
......@@ -162,6 +177,15 @@ function getStopSignal(): any {
}
}
function getCtrlCSignal(): any {
return 'SIGINT';
}
process.on(getCtrlCSignal(), async () => {
const log: Logger = getLogger();
log.info(`Get SIGINT signal!`);
});
process.on(getStopSignal(), async () => {
const log: Logger = getLogger();
let hasError: boolean = false;
......
......@@ -13,6 +13,7 @@
"azure-storage": "^2.10.2",
"chai-as-promised": "^7.1.1",
"child-process-promise": "^2.2.1",
"deepmerge": "^4.2.2",
"express": "^4.16.3",
"express-joi-validator": "^2.0.0",
"js-base64": "^2.4.9",
......
......@@ -38,6 +38,7 @@ export namespace ValidationSchemas {
authFile: joi.string(),
nniManagerNFSMountPath: joi.string().min(1),
containerNFSMountPath: joi.string().min(1),
paiConfigPath: joi.string(),
paiStoragePlugin: joi.string().min(1),
nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
portList: joi.array().items(joi.object({
......
......@@ -31,10 +31,11 @@ export class NNIPAIK8STrialConfig extends TrialConfig {
public readonly nniManagerNFSMountPath: string;
public readonly containerNFSMountPath: string;
public readonly paiStoragePlugin: string;
public readonly paiConfigPath?: string;
constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
image: string, nniManagerNFSMountPath: string, containerNFSMountPath: string,
paiStoragePlugin: string, virtualCluster?: string) {
paiStoragePlugin: string, virtualCluster?: string, paiConfigPath?: string) {
super(command, codeDir, gpuNum);
this.cpuNum = cpuNum;
this.memoryMB = memoryMB;
......@@ -43,5 +44,6 @@ export class NNIPAIK8STrialConfig extends TrialConfig {
this.nniManagerNFSMountPath = nniManagerNFSMountPath;
this.containerNFSMountPath = containerNFSMountPath;
this.paiStoragePlugin = paiStoragePlugin;
this.paiConfigPath = paiConfigPath;
}
}
......@@ -44,6 +44,7 @@ import { PAIClusterConfig, PAITrialJobDetail } from '../paiConfig';
import { PAIJobRestServer } from '../paiJobRestServer';
const yaml = require('js-yaml');
const deepmerge = require('deepmerge');
/**
* Training Service implementation for OpenPAI (Open Platform for AI)
......@@ -59,6 +60,10 @@ class PAIK8STrainingService extends PAITrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
this.paiJobRestServer = new PAIJobRestServer(component.get(PAIK8STrainingService));
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
......@@ -185,7 +190,19 @@ class PAIK8STrainingService extends PAITrainingService {
}
}
return yaml.safeDump(paiJobConfig);
if (this.paiTrialConfig.paiConfigPath) {
try {
const additionalPAIConfig = yaml.safeLoad(fs.readFileSync(this.paiTrialConfig.paiConfigPath, 'utf8'));
//deepmerge(x, y), if an element at the same key is present for both x and y, the value from y will appear in the result.
//refer: https://github.com/TehShrike/deepmerge
const overwriteMerge = (destinationArray: any, sourceArray: any, options: any) => sourceArray;
return yaml.safeDump(deepmerge(additionalPAIConfig, paiJobConfig, { arrayMerge: overwriteMerge }));
} catch (error) {
this.log.error(`Error occurs during loading and merge ${this.paiTrialConfig.paiConfigPath} : ${error}`);
}
} else {
return yaml.safeDump(paiJobConfig);
}
}
protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
......@@ -254,7 +271,7 @@ class PAIK8STrainingService extends PAITrainingService {
this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
const paiJobConfig = this.generateJobConfigInYamlFormat(trialJobId, nniPaiTrialCommand);
this.log.debug(paiJobConfig);
// Step 3. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const submitJobRequest: request.Options = {
......
......@@ -1112,6 +1112,11 @@ deepmerge@^2.1.1:
version "2.2.1"
resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-2.2.1.tgz#5d3ff22a01c00f645405a2fbc17d0778a1801170"
deepmerge@^4.2.2:
version "4.2.2"
resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955"
integrity sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg==
default-require-extensions@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/default-require-extensions/-/default-require-extensions-2.0.0.tgz#f5f8fbb18a7d6d50b21f641f649ebb522cfe24f7"
......
......@@ -113,13 +113,13 @@ class AGP_Pruner(Pruner):
if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
return None
mask = {'weight': torch.ones(weight.shape).type_as(weight)}
mask = {'weight': kwargs['weight_mask'] if 'weight_mask' in kwargs else torch.ones(weight.shape).type_as(weight)}
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate new mask, we should update weigth first
w_abs = weight.abs()
w_abs = weight.abs() * mask['weight']
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
if_calculated.copy_(torch.tensor(1)) # pylint: disable=not-callable
......
......@@ -31,11 +31,11 @@ def test():
# [1,1,1,1,1,1,1,1,1,1],
# [1,1,1,1,1,1,1,1,1,1]]
assessor = MedianstopAssessor(FLAGS.start_step, FLAGS.optimize_mode)
for i in range(4):
assessor = MedianstopAssessor(FLAGS.optimize_mode, FLAGS.start_step)
for i in range(len(lcs)):
#lc = []
to_complete = True
for k in range(10):
for k in range(len(lcs[0])):
#d = random.randint(i*100+0, i*100+100)
#lc.append(d)
ret = assessor.assess_trial(i, lcs[i][:k+1])
......
......@@ -67,6 +67,13 @@ class ClassicMutator(Mutator):
else:
# get chosen arch from tuner
self._chosen_arch = nni.get_next_parameter()
if self._chosen_arch is None:
if trial_env_vars.NNI_PLATFORM == "unittest":
# happens if NNI_PLATFORM is intentionally set, e.g., in UT
logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
self._chosen_arch = self._standalone_generate_chosen()
else:
raise RuntimeError("Chosen architecture is None. This may be a platform error.")
self.reset()
def _sample_layer_choice(self, mutable, idx, value, search_space_item):
......@@ -162,6 +169,8 @@ class ClassicMutator(Mutator):
elif val["_type"] == INPUT_CHOICE:
choices = val["_value"]["candidates"]
n_chosen = val["_value"]["n_chosen"]
if n_chosen is None:
n_chosen = len(choices)
chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
else:
raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
......
......@@ -63,18 +63,23 @@ class DartsMutator(Mutator):
edges_max[mutable.key] = max_val
result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
for mutable in self.mutables:
if isinstance(mutable, InputChoice) and mutable.n_chosen is not None:
weights = []
for src_key in mutable.choose_from:
if src_key not in edges_max:
_logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
weights.append(edges_max.get(src_key, 0.))
weights = torch.tensor(weights) # pylint: disable=not-callable
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
selected_multihot = []
for i, src_key in enumerate(mutable.choose_from):
if i not in topk_edge_indices and src_key in result:
result[src_key] = torch.zeros_like(result[src_key]) # clear this choice to optimize calc graph
selected_multihot.append(i in topk_edge_indices)
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
if isinstance(mutable, InputChoice):
if mutable.n_chosen is not None:
weights = []
for src_key in mutable.choose_from:
if src_key not in edges_max:
_logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
weights.append(edges_max.get(src_key, 0.))
weights = torch.tensor(weights) # pylint: disable=not-callable
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
selected_multihot = []
for i, src_key in enumerate(mutable.choose_from):
if i not in topk_edge_indices and src_key in result:
# If an edge is never selected, there is no need to calculate any op on this edge.
# This is to eliminate redundant calculation.
result[src_key] = torch.zeros_like(result[src_key])
selected_multihot.append(i in topk_edge_indices)
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
else:
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
return result
......@@ -52,24 +52,24 @@ def _encode_tensor(data):
return data
def apply_fixed_architecture(model, fixed_arc_path):
def apply_fixed_architecture(model, fixed_arc):
"""
Load architecture from `fixed_arc_path` and apply to model.
Load architecture from `fixed_arc` and apply to model.
Parameters
----------
model : torch.nn.Module
Model with mutables.
fixed_arc_path : str
Path to the JSON that stores the architecture.
fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
Returns
-------
FixedArchitecture
"""
if isinstance(fixed_arc_path, str):
with open(fixed_arc_path, "r") as f:
if isinstance(fixed_arc, str):
with open(fixed_arc) as f:
fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc)
architecture = FixedArchitecture(model, fixed_arc)
......
......@@ -17,6 +17,14 @@ def global_mutable_counting():
return _counter
def _reset_global_mutable_counting():
"""
Reset the global mutable counting to count from 1. Useful when defining multiple models with default keys.
"""
global _counter
_counter = 0
def to_device(obj, device):
if torch.is_tensor(obj):
return obj.to(device)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutable_scope import SpaceWithMutableScope
from .naive import NaiveSearchSpace
from .nested import NestedSpace
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class Cell(MutableScope):
def __init__(self, cell_name, prev_labels, channels):
super().__init__(cell_name)
self.input_choice = InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
key=cell_name + "_input")
self.op_choice = LayerChoice([
nn.Conv2d(channels, channels, 3, padding=1),
nn.Conv2d(channels, channels, 5, padding=2),
nn.MaxPool2d(3, stride=1, padding=1),
nn.AvgPool2d(3, stride=1, padding=1),
nn.Identity()
], key=cell_name + "_op")
def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers)
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask
class Node(MutableScope):
def __init__(self, node_name, prev_node_names, channels):
super().__init__(node_name)
self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
self.cell_y = Cell(node_name + "_y", prev_node_names, channels)
def forward(self, prev_layers):
out_x, mask_x = self.cell_x(prev_layers)
out_y, mask_y = self.cell_y(prev_layers)
return out_x + out_y, mask_x | mask_y
class Layer(nn.Module):
def __init__(self, num_nodes, channels):
super().__init__()
self.num_nodes = num_nodes
self.nodes = nn.ModuleList()
node_labels = [InputChoice.NO_KEY, InputChoice.NO_KEY]
for i in range(num_nodes):
node_labels.append("node_{}".format(i))
self.nodes.append(Node(node_labels[-1], node_labels[:-1], channels))
self.final_conv_w = nn.Parameter(torch.zeros(channels, self.num_nodes + 2, channels, 1, 1),
requires_grad=True)
self.bn = nn.BatchNorm2d(channels, affine=False)
def forward(self, pprev, prev):
prev_nodes_out = [pprev, prev]
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
for i in range(self.num_nodes):
node_out, mask = self.nodes[i](prev_nodes_out)
nodes_used_mask[:mask.size(0)] |= mask.to(prev.device)
# NOTE: which device should we put mask on?
prev_nodes_out.append(node_out)
unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
unused_nodes = F.relu(unused_nodes)
conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]
conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1)
out = F.conv2d(unused_nodes, conv_weight)
return prev, self.bn(out)
class SpaceWithMutableScope(nn.Module):
def __init__(self, test_case, num_layers=4, num_nodes=5, channels=16, in_channels=3, num_classes=10):
super().__init__()
self.test_case = test_case
self.num_layers = num_layers
self.stem = nn.Sequential(
nn.Conv2d(in_channels, channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(channels)
)
self.layers = nn.ModuleList()
for _ in range(self.num_layers + 2):
self.layers.append(Layer(num_nodes, channels))
self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(channels, num_classes)
def forward(self, x):
prev = cur = self.stem(x)
for layer in self.layers:
prev, cur = layer(prev, cur)
cur = self.gap(F.relu(cur)).view(x.size(0), -1)
return self.dense(cur)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
class NaiveSearchSpace(nn.Module):
def __init__(self, test_case):
super().__init__()
self.test_case = test_case
self.conv1 = LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)],
return_mask=True)
self.conv3 = nn.Conv2d(16, 16, 1)
self.skipconnect = InputChoice(n_candidates=1)
self.skipconnect2 = InputChoice(n_candidates=2, return_mask=True)
self.bn = nn.BatchNorm2d(16)
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 10)
def forward(self, x):
bs = x.size(0)
x = self.pool(F.relu(self.conv1(x)))
x0, mask = self.conv2(x)
self.test_case.assertEqual(mask.size(), torch.Size([2]))
x1 = F.relu(self.conv3(x0))
_, mask = self.skipconnect2([x0, x1])
x0 = self.skipconnect([x0])
if x0 is not None:
x1 += x0
x = self.pool(self.bn(x1))
self.test_case.assertEqual(mask.size(), torch.Size([2]))
x = self.gap(x).view(bs, -1)
x = self.fc(x)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment