content_loss.lua 2.66 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
require 'torch'
require 'nn'
local content = {}

--- Build a truncated VGG network used as a content feature extractor.
-- The returned container resizes its input to 224x224 with bilinear
-- up-sampling, applies nn.VGG_postprocess (presumably provided by
-- 'util/VGG_preprocess' -- confirm), and then runs copies of the Caffe model's
-- layers up to and including the one whose `name` equals `content_layer`.
-- @param content_layer name of the last layer to keep; compared against each
--   loaded layer's `name` field
-- @return nn.Sequential holding the resize, post-process and copied layers
function content.defineVGG(content_layer)
  local contentFunc = nn.Sequential()
  require 'loadcaffe'
  require 'util/VGG_preprocess'
  -- Keep the loaded model local: a bare `cnn = ...` would create a global
  -- that is visible to (and can be clobbered by) every other module.
  local cnn = loadcaffe.load('../models/vgg.prototxt', '../models/vgg.caffemodel', 'cudnn')
  contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
  contentFunc:add(nn.VGG_postprocess())
  local found = false
  for i = 1, #cnn do
    -- Clone so the extractor does not share storage with the loaded model,
    -- which is released below.
    local layer = cnn:get(i):clone()
    contentFunc:add(layer)
    if layer.name == content_layer then
      print("Setting up content layer: ", layer.name)
      found = true
      break
    end
  end
  if not found then
    -- Every layer was copied; make the probable misconfiguration visible
    -- instead of silently computing the loss on the final layer.
    print("Warning: content layer not found, using the whole network: ", content_layer)
  end
  -- Drop the only reference to the full model so it can be collected now.
  cnn = nil
  collectgarbage()
  print(contentFunc)
  return contentFunc
end

--- Build a truncated AlexNet network used as a content feature extractor.
-- The returned container resizes its input to 224x224 with bilinear
-- up-sampling, applies nn.VGG_postprocess (presumably provided by
-- 'util/VGG_preprocess' -- confirm), and then runs copies of the Caffe model's
-- layers up to and including the one whose `name` equals `content_layer`.
-- @param content_layer name of the last layer to keep; compared against each
--   loaded layer's `name` field
-- @return nn.Sequential holding the resize, post-process and copied layers
function content.defineAlexNet(content_layer)
  local contentFunc = nn.Sequential()
  require 'loadcaffe'
  require 'util/VGG_preprocess'
  -- Keep the loaded model local: a bare `cnn = ...` would create a global
  -- that is visible to (and can be clobbered by) every other module.
  local cnn = loadcaffe.load('../models/alexnet.prototxt', '../models/alexnet.caffemodel', 'cudnn')
  contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
  contentFunc:add(nn.VGG_postprocess())
  local found = false
  for i = 1, #cnn do
    -- Clone so the extractor does not share storage with the loaded model,
    -- which is released below.
    local layer = cnn:get(i):clone()
    contentFunc:add(layer)
    if layer.name == content_layer then
      print("Setting up content layer: ", layer.name)
      found = true
      break
    end
  end
  if not found then
    -- Every layer was copied; make the probable misconfiguration visible
    -- instead of silently computing the loss on the final layer.
    print("Warning: content layer not found, using the whole network: ", content_layer)
  end
  -- Drop the only reference to the full model so it can be collected now.
  cnn = nil
  collectgarbage()
  print(contentFunc)
  return contentFunc
end



--- Create the feature network required by the chosen content loss.
-- @param content_loss loss selector: 'pixel', 'none' or 'vgg'
-- @param layer_name layer name forwarded to content.defineVGG (read only
--   when content_loss is 'vgg')
-- @return the result of content.defineVGG for 'vgg'; nil for 'pixel', 'none'
--   and for any unrecognised selector (a message is printed in that case)
function content.defineContent(content_loss, layer_name)
  if content_loss == 'vgg' then
    return content.defineVGG(layer_name)
  end

  -- 'pixel' and 'none' work without a feature network.
  local needs_no_network = (content_loss == 'pixel') or (content_loss == 'none')
  if not needs_no_network then
    print("unsupported content loss")
  end
  return nil
end


--- Compute the weighted content loss and its gradient w.r.t. `fake_target`.
-- @param criterionContent criterion driven through :forward / :backward
--   (unused when loss_type is 'none')
-- @param real_source reference tensor the generated tensor is compared to
-- @param fake_target generated tensor; the returned gradient is taken with
--   respect to it
-- @param contentFunc feature network (only used when loss_type is 'vgg')
-- @param loss_type 'none', 'pixel' or 'vgg'
-- @param weight scalar applied to both the loss and the gradient
-- @return errCont weighted loss value (0.0 for 'none')
-- @return gradient of the weighted loss with respect to fake_target
-- Raises an error for any other loss_type.
function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight)
  if loss_type == 'none' then
    local errCont = 0.0
    -- Allocate the zero gradient from fake_target itself so it has the same
    -- tensor type (e.g. CUDA) and size; torch.zeros would use the default
    -- tensor type and could not be combined with a gradient of another type.
    local df_d_content = fake_target.new():resizeAs(fake_target):zero()
    return errCont, df_d_content
  elseif loss_type == 'pixel' then
    local errCont = criterionContent:forward(fake_target, real_source) * weight
    local df_do_content = criterionContent:backward(fake_target, real_source) * weight
    return errCont, df_do_content
  elseif loss_type == 'vgg' then
    -- Forward the reference first and clone its features, since the next
    -- forward call reuses the network's output buffer.
    local f_real = contentFunc:forward(real_source):clone()
    -- fake_target must be the LAST input forwarded: updateGradInput relies on
    -- the activations cached by the most recent forward pass, so forwarding
    -- real_source afterwards would back-propagate through the wrong state.
    local f_fake = contentFunc:forward(fake_target):clone()
    local errCont = criterionContent:forward(f_fake, f_real) * weight
    local df_do_tmp = criterionContent:backward(f_fake, f_real) * weight
    -- Only the input gradient is needed; weight is already folded into
    -- df_do_tmp.
    local df_do_content = contentFunc:updateGradInput(fake_target, df_do_tmp)
    return errCont, df_do_content
  else
    error("unsupported content loss")
  end
end


return content