Module.h 3.58 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#pragma once

#include "common.h"
#include "Tensor.h"
#include "debug.h"

class Module {
protected:
    enum class ParamFlags : int {
        None = 0,
        Optional = 1,
    };
    struct Param {
        Tensor *tensor;
        ParamFlags flags;
    };

    friend inline ParamFlags operator|(ParamFlags lhs, ParamFlags rhs) {
        return static_cast<ParamFlags>(static_cast<int>(lhs) | static_cast<int>(rhs));
    }
    friend inline ParamFlags operator&(ParamFlags lhs, ParamFlags rhs) {
        return static_cast<ParamFlags>(static_cast<int>(lhs) & static_cast<int>(rhs));
    }

public:
    std::string getFullName() const {
        if (!parent) {
            return name;
        }
        std::string fullName = parent->getFullName();
        if (fullName.empty()) {
            return name;
        } else {
            return fullName + "." + name;
        }
    }

    void traverse(std::function<void(Module *)> func) {
        func(this);
        for (Module *c : this->children) {
            c->traverse(func);
        }
    }

    virtual void loadParams(TensorsProvider &provider, bool partial = false) {
        for (Module *c : children) {
            c->loadParams(provider, partial);
        }
        std::string fullName = getFullName();
        std::string prefix = fullName.empty() ? "" : fullName + ".";
        for (auto &&[key, param] : params) {
            Tensor src = provider.getTensor(prefix + key);
            if (!src.valid()) {
                if (partial || int(param.flags & ParamFlags::Optional)) {
                    continue;
                }
                throw std::runtime_error(spdlog::fmt_lib::format("Tensor {} not found", prefix + key));
            }
            this->loadParam(key, *param.tensor, src);
            // tensor->copy_(src);
        }
    }

    void setName(std::string name) {
        assert(!parent);
        this->name = std::move(name);
    }



protected:
    virtual void loadParam(std::string key, Tensor &dst, Tensor src) {
        dst.copy_(src);
    }

    struct ChildrenRegisterHelper {
        ChildrenRegisterHelper(Module &self) : self(self) {}
        Module &self;
        ChildrenRegisterHelper operator()(Module &module, std::string name) {
            return self.registerChildren(module, name);
        }
    };
    ChildrenRegisterHelper registerChildren(Module &module, std::string name) {
        module.parent = this;
        module.name = name;
        children.push_back(&module);
        return ChildrenRegisterHelper(*this);
    }

    struct ParamsRegisterHelper {
        ParamsRegisterHelper(Module &self) : self(self) {}
        Module &self;
        ParamsRegisterHelper operator()(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) {
            return self.registerParams(param, name, flags);
        }
    };
    ParamsRegisterHelper registerParams(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) {
        if (param.valid()) {
            params[name].tensor = &param;
            params[name].flags = flags;
        }
        return ParamsRegisterHelper(*this);
    }

    void debug(std::string name, Tensor tensor) {
        if (DebugContext::ctxs.empty()) {
            return;
        }
        std::string prefix = getFullName();
        if (!prefix.empty()) {
            prefix += ".";
        }
        tensor = tensor.copy(Device::cpu());
        for (auto &&ctx : DebugContext::ctxs) {
            ctx->tensors[prefix + name] = tensor;
        }
    }

public:
    Module *parent = nullptr;
    std::string name = "";
    std::vector<Module *> children;
    std::map<std::string, Param> params;
};