convert_t7.lua 2.46 KB
Newer Older
dengjf's avatar
dengjf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
require('table')
require('torch')
require('os')

-- Deep-copy a table.
--
-- Nested tables are copied recursively and each copy receives the same
-- metatable object as its source (metatables themselves are not copied).
-- Any non-table value (numbers, strings, userdata, functions, nil) is
-- returned as-is, i.e. by reference. Keys are reused, not copied.
--
-- A table reachable through several paths -- including a table that refers
-- back to itself -- is copied exactly once, so cyclic structures no longer
-- recurse without bound and shared references stay shared in the copy.
--
-- @param t  value to copy
-- @return   the copy when `t` is a table, otherwise `t` itself
function clone(t)
    -- Maps every source table already visited to its copy.
    local copies = {}

    local function copy(value)
        if type(value) ~= "table" then
            return value
        end
        local known = copies[value]
        if known ~= nil then
            return known
        end
        local target = {}
        -- Register before descending so back-references resolve to `target`.
        copies[value] = target
        for k, v in pairs(value) do
            target[k] = copy(v)
        end
        -- Attach the metatable last so __newindex cannot intercept the fill.
        setmetatable(target, getmetatable(value))
        return target
    end

    return copy(t)
end


-- Return a new table holding a deep copy of `lhs` with every value of `rhs`
-- appended after it via table.insert.
--
-- Neither argument is modified. Values taken from `rhs` are appended by
-- reference (they are not deep-copied). `rhs` is walked with pairs(), so its
-- keys are discarded; NOTE(review): pairs() does not guarantee an order, so
-- the append order is only predictable in practice for plain array-like
-- `rhs` -- confirm callers only pass such tables.
--
-- @param lhs  table to copy as the base of the result
-- @param rhs  table whose values are appended
-- @return     the merged table
function tableMerge(lhs, rhs)
    -- Must be local: a bare assignment here would create/overwrite a global
    -- named `output` on every call.
    local output = clone(lhs)
    for _, v in pairs(rhs) do
        table.insert(output, v)
    end
    return output
end


-- Report whether `val` equals (==) any value stored in `val_list`.
-- Keys of `val_list` are ignored; an empty table yields false.
--
-- @param val       value to look for
-- @param val_list  table whose values are searched
-- @return          true if a match exists, false otherwise
function isInTable(val, val_list)
    local found = false
    for _, candidate in pairs(val_list) do
        if val == candidate then
            found = true
            break
        end
    end
    return found
end


-- Convert a container model into a plain nested list describing its layers.
--
-- Walks `model.modules` in index order and returns a list with one entry
-- {typeName, param} per exported layer, where `typeName` is torch.type(layer)
-- and `param` depends on the layer kind:
--   * nn.Sequential / nn.ConcatTable : the nested list produced by recursing
--     into that container;
--   * (cudnn|nn).SpatialConvolution, nn.LstmLayer, nn.BiRnnJoin : the first
--     value returned by layer:parameters();
--   * (cudnn|nn).SpatialBatchNormalization : a copy of layer:parameters()
--     with layer.running_mean and layer.running_var appended;
--   * (cudnn|nn).SpatialMaxPooling, (cudnn|nn).ReLU : an empty table.
-- Layer types in `ignoreList` are not exported; only a 'pass <type>' line is
-- printed for them. Any other type prints 'Unknown class <type>' and
-- terminates the process with a non-zero exit status.
--
-- @param model  object with a `modules` field (assumed to be a sequence of
--               layers, as in nn containers -- TODO confirm for custom ones)
-- @return       list of {typeName, param} entries
function modelToList(model)
    -- Layer types that are skipped rather than exported.
    local ignoreList = {
        'nn.Copy',
        'nn.AddConstant',
        'nn.MulConstant',
        'nn.View',
        'nn.Transpose',
        'nn.SplitTable',
        'nn.SharedParallelTable',
        'nn.JoinTable',
    }
    local state = {}
    -- ipairs guarantees 1..n order; the position of each entry in `state`
    -- mirrors the layer order, which pairs() does not promise to preserve.
    for _, layer in ipairs(model.modules) do
        local typeName = torch.type(layer)
        if not isInTable(typeName, ignoreList) then
            local param
            if typeName == 'nn.Sequential' or typeName == 'nn.ConcatTable' then
                param = modelToList(layer)
            elseif typeName == 'cudnn.SpatialConvolution' or typeName == 'nn.SpatialConvolution' then
                param = layer:parameters()
            elseif typeName == 'cudnn.SpatialBatchNormalization' or typeName == 'nn.SpatialBatchNormalization' then
                param = layer:parameters()
                -- Running statistics are plain fields, not part of
                -- parameters(), so they are appended explicitly.
                local bn_vars = {layer.running_mean, layer.running_var}
                param = tableMerge(param, bn_vars)
            elseif typeName == 'nn.LstmLayer' then
                param = layer:parameters()
            elseif typeName == 'nn.BiRnnJoin' then
                param = layer:parameters()
            elseif typeName == 'cudnn.SpatialMaxPooling' or typeName == 'nn.SpatialMaxPooling' then
                param = {}
            elseif typeName == 'cudnn.ReLU' or typeName == 'nn.ReLU' then
                param = {}
            else
                print(string.format('Unknown class %s', typeName))
                -- Abort with a failure status so calling scripts can tell
                -- that the conversion did not complete.
                os.exit(1)
            end
            table.insert(state, {typeName, param})
        else
            print(string.format('pass %s', typeName))
        end
    end
    return state
end


-- Build the layer list for `model` with modelToList and write it to
-- `output_path` using torch.save.
--
-- @param model        model passed through to modelToList
-- @param output_path  destination handed to torch.save
function saveModel(model, output_path)
    torch.save(output_path, modelToList(model))
end