Commit b77b98a8 authored by pangjm's avatar pangjm
Browse files

update resnet

parent d6b3cbd5
...@@ -219,24 +219,24 @@ class ResNet(nn.Module): ...@@ -219,24 +219,24 @@ class ResNet(nn.Module):
super(ResNet, self).__init__() super(ResNet, self).__init__()
if depth not in self.arch_settings: if depth not in self.arch_settings:
raise KeyError('invalid depth {} for resnet'.format(depth)) raise KeyError('invalid depth {} for resnet'.format(depth))
self.depth = depth, self.depth = depth
self.num_stages = num_stages, self.num_stages = num_stages
self.strides = strides,
self.dilations = dilations,
assert num_stages >= 1 and num_stages <= 4 assert num_stages >= 1 and num_stages <= 4
self.block, self.stage_blocks = self.arch_settings[depth] self.strides = strides
self.stage_blocks = self.stage_blocks[:num_stages] self.dilations = dilations
assert len(strides) == len(dilations) == num_stages assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
self.out_indices = out_indices self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style self.style = style
self.frozen_stages = frozen_stages self.frozen_stages = frozen_stages
self.bn_eval = bn_eval self.bn_eval = bn_eval
self.bn_frozen = bn_frozen self.bn_frozen = bn_frozen
self.with_cp = with_cp self.with_cp = with_cp
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = 64 self.inplanes = 64
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False) 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment