Module.h 10 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
#pragma once

#include "common.h"
#include "Tensor.h"
#include "debug.h"

class Module {
protected:
    enum class ParamFlags : int {
Muyang Li's avatar
Muyang Li committed
10
        None     = 0,
Zhekai Zhang's avatar
Zhekai Zhang committed
11
        Optional = 1,
muyangli's avatar
muyangli committed
12
13
14
15
16
17
18
19
        LazyLoad = 2,
    };
    struct TensorLazyLoadInfo {
        TensorShape shape;
        Tensor::ScalarType type;
        Device device;

        Tensor src;
Zhekai Zhang's avatar
Zhekai Zhang committed
20
21
    };
    struct Param {
Muyang Li's avatar
Muyang Li committed
22
        Tensor *tensor   = nullptr;
muyangli's avatar
muyangli committed
23
24
25
        ParamFlags flags = ParamFlags::None;

        TensorLazyLoadInfo lazyInfo;
Zhekai Zhang's avatar
Zhekai Zhang committed
26
27
28
29
30
31
32
33
    };

    friend inline ParamFlags operator|(ParamFlags lhs, ParamFlags rhs) {
        return static_cast<ParamFlags>(static_cast<int>(lhs) | static_cast<int>(rhs));
    }
    friend inline ParamFlags operator&(ParamFlags lhs, ParamFlags rhs) {
        return static_cast<ParamFlags>(static_cast<int>(lhs) & static_cast<int>(rhs));
    }
muyangli's avatar
muyangli committed
34
35
36
    static bool checkFlag(ParamFlags flags, ParamFlags target) {
        return int(flags & target);
    }
Zhekai Zhang's avatar
Zhekai Zhang committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50

public:
    std::string getFullName() const {
        if (!parent) {
            return name;
        }
        std::string fullName = parent->getFullName();
        if (fullName.empty()) {
            return name;
        } else {
            return fullName + "." + name;
        }
    }

muyangli's avatar
muyangli committed
51
52
    std::string getPrefix() const {
        std::string fullName = getFullName();
Muyang Li's avatar
Muyang Li committed
53
        std::string prefix   = fullName.empty() ? "" : fullName + ".";
muyangli's avatar
muyangli committed
54
55
56
        return prefix;
    }

Zhekai Zhang's avatar
Zhekai Zhang committed
57
58
59
60
61
62
63
64
65
66
67
    void traverse(std::function<void(Module *)> func) {
        func(this);
        for (Module *c : this->children) {
            c->traverse(func);
        }
    }

    virtual void loadParams(TensorsProvider &provider, bool partial = false) {
        for (Module *c : children) {
            c->loadParams(provider, partial);
        }
muyangli's avatar
muyangli committed
68
        std::string prefix = getPrefix();
Zhekai Zhang's avatar
Zhekai Zhang committed
69
70
71
72
73
74
75
76
        for (auto &&[key, param] : params) {
            Tensor src = provider.getTensor(prefix + key);
            if (!src.valid()) {
                if (partial || int(param.flags & ParamFlags::Optional)) {
                    continue;
                }
                throw std::runtime_error(spdlog::fmt_lib::format("Tensor {} not found", prefix + key));
            }
muyangli's avatar
muyangli committed
77
78
79
80
81
82
            if (enabledLazyLoad && checkFlag(param.flags, ParamFlags::LazyLoad)) {
                param.lazyInfo.src = src;
                if (!param.tensor->valid()) {
                    continue;
                }
                // keep loading params if param is not released
Muyang Li's avatar
Muyang Li committed
83
            }
Zhekai Zhang's avatar
Zhekai Zhang committed
84
85
86
87
88
89
90
91
92
93
            this->loadParam(key, *param.tensor, src);
            // tensor->copy_(src);
        }
    }

    void setName(std::string name) {
        assert(!parent);
        this->name = std::move(name);
    }

muyangli's avatar
muyangli committed
94
95
96
97
98
99
    void loadLazyParams() {
        traverse([](Module *m) {
            for (auto &&[key, param] : m->params) {
                if (!checkFlag(param.flags, ParamFlags::LazyLoad)) {
                    continue;
                }
Zhekai Zhang's avatar
Zhekai Zhang committed
100

muyangli's avatar
muyangli committed
101
                TensorLazyLoadInfo &lazy = param.lazyInfo;
Muyang Li's avatar
Muyang Li committed
102
103
                Tensor &dst              = *param.tensor;
                Tensor src               = lazy.src;
muyangli's avatar
muyangli committed
104
105
106
107
108
109
110

                if (dst.valid()) {
                    continue;
                }
                dst = Tensor::allocate(lazy.shape, lazy.type, lazy.device);

                if (!src.valid() && !checkFlag(param.flags, ParamFlags::Optional)) {
Muyang Li's avatar
Muyang Li committed
111
112
                    throw std::runtime_error(
                        spdlog::fmt_lib::format("Lazy load: Tensor {} has no src", m->getPrefix() + key));
muyangli's avatar
muyangli committed
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
                }
                m->loadParam(key, dst, src);
            }
        });
    }
    void releaseLazyParams() {
        traverse([](Module *m) {
            if (!m->enabledLazyLoad) {
                return;
            }
            for (auto &&[key, param] : m->params) {
                if (checkFlag(param.flags, ParamFlags::LazyLoad)) {
                    *param.tensor = Tensor{};
                }
            }
        });
    }
    void setLazyLoad(bool val) {
Muyang Li's avatar
Muyang Li committed
131
        traverse([val](Module *m) { m->enabledLazyLoad = val; });
muyangli's avatar
muyangli committed
132
    }
133
    void setAutoCastFP16(bool val) {
Muyang Li's avatar
Muyang Li committed
134
        traverse([val](Module *m) { m->enabledAutoCastFP16 = val; });
135
    }
Zhekai Zhang's avatar
Zhekai Zhang committed
136
137
138

protected:
    virtual void loadParam(std::string key, Tensor &dst, Tensor src) {
139
140
141
142
        static const std::set<Tensor::ScalarType> whitelist = {
            Tensor::FP16,
            Tensor::BF16,
        };
Muyang Li's avatar
Muyang Li committed
143
144
        if (enabledAutoCastFP16 && dst.scalar_type() != src.scalar_type() && whitelist.contains(dst.scalar_type()) &&
            whitelist.contains(src.scalar_type())) {
145
146
147
148
            copyWithCast(dst, src);
        } else {
            dst.copy_(src);
        }
Zhekai Zhang's avatar
Zhekai Zhang committed
149
150
151
152
153
154
155
156
157
158
159
    }

    struct ChildrenRegisterHelper {
        ChildrenRegisterHelper(Module &self) : self(self) {}
        Module &self;
        ChildrenRegisterHelper operator()(Module &module, std::string name) {
            return self.registerChildren(module, name);
        }
    };
    ChildrenRegisterHelper registerChildren(Module &module, std::string name) {
        module.parent = this;
Muyang Li's avatar
Muyang Li committed
160
        module.name   = name;
Zhekai Zhang's avatar
Zhekai Zhang committed
161
162
163
164
165
166
167
168
169
170
171
172
173
174
        children.push_back(&module);
        return ChildrenRegisterHelper(*this);
    }

    struct ParamsRegisterHelper {
        ParamsRegisterHelper(Module &self) : self(self) {}
        Module &self;
        ParamsRegisterHelper operator()(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) {
            return self.registerParams(param, name, flags);
        }
    };
    ParamsRegisterHelper registerParams(Tensor &param, std::string name, ParamFlags flags = ParamFlags::None) {
        if (param.valid()) {
            params[name].tensor = &param;
Muyang Li's avatar
Muyang Li committed
175
            params[name].flags  = flags;
muyangli's avatar
muyangli committed
176
177
178

            if (checkFlag(flags, ParamFlags::LazyLoad) && param.valid()) {
                TensorLazyLoadInfo &lazy = params[name].lazyInfo;
Muyang Li's avatar
Muyang Li committed
179
180
181
                lazy.shape               = param.shape;
                lazy.type                = param.dtype();
                lazy.device              = param.device();
muyangli's avatar
muyangli committed
182
            }
Zhekai Zhang's avatar
Zhekai Zhang committed
183
184
185
186
187
        }
        return ParamsRegisterHelper(*this);
    }

    void debug(std::string name, Tensor tensor) {
188
        if (DebugContext::ctxs.empty() || !tensor.valid()) {
Zhekai Zhang's avatar
Zhekai Zhang committed
189
190
191
192
193
194
195
196
197
198
199
200
            return;
        }
        std::string prefix = getFullName();
        if (!prefix.empty()) {
            prefix += ".";
        }
        tensor = tensor.copy(Device::cpu());
        for (auto &&ctx : DebugContext::ctxs) {
            ctx->tensors[prefix + name] = tensor;
        }
    }

201
202
203
private:
    void copyWithCast(Tensor dst, Tensor src);

Zhekai Zhang's avatar
Zhekai Zhang committed
204
public:
Muyang Li's avatar
Muyang Li committed
205
    Module *parent   = nullptr;
Zhekai Zhang's avatar
Zhekai Zhang committed
206
207
208
    std::string name = "";
    std::vector<Module *> children;
    std::map<std::string, Param> params;
muyangli's avatar
muyangli committed
209

Muyang Li's avatar
Muyang Li committed
210
    bool enabledLazyLoad     = false;
211
    bool enabledAutoCastFP16 = true;
Zhekai Zhang's avatar
Zhekai Zhang committed
212
};
muyangli's avatar
muyangli committed
213
214
215
216
217
218
219
220
221
222
223
224
225
226

struct LayerOffloadHelper {
    using func_t = std::function<void(int)>;

    const bool offload;
    const int numLayers;

    func_t funcCompute, funcLoad, funcUnload;

    std::unique_ptr<CUDAStreamWrapper> streamCompute;
    std::unique_ptr<CUDAStreamWrapper> streamLoad;
    std::unique_ptr<CUDAEventWrapper> eventComputeDone;
    std::unique_ptr<CUDAEventWrapper> eventLoadDone;

Muyang Li's avatar
Muyang Li committed
227
228
    LayerOffloadHelper(bool offload, int numLayers, func_t funcCompute, func_t funcLoad, func_t funcUnload)
        : offload(offload), numLayers(numLayers), funcCompute(funcCompute), funcLoad(funcLoad), funcUnload(funcUnload) {
muyangli's avatar
muyangli committed
229
230
        if (offload) {
            streamCompute = std::make_unique<CUDAStreamWrapper>();
Muyang Li's avatar
Muyang Li committed
231
            streamLoad    = std::make_unique<CUDAStreamWrapper>();
Zhekai Zhang's avatar
Zhekai Zhang committed
232
233
234
235
236

            needWorkaround = checkWorkaround();
            if (needWorkaround) {
                spdlog::debug("Offloading helper: use WDDM workaround");
            }
muyangli's avatar
muyangli committed
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
        }
    }

    void run() {
        for (int i = 0; i < numLayers; i++) {
            run(i);
        }
        waitEvent(eventComputeDone.get());
        funcUnload(numLayers - 1);
    }

private:
    void run(int layer) {
        if (!offload) {
            funcCompute(layer);
        } else {
            std::unique_ptr<CUDAEventWrapper> nextComputeDone, nextLoadDone;

            // issue compute kernels first so that we could still overlap compute and memcpy if memory is not pinned
            {
                CUDAStreamContext ctx(streamCompute->stream);
                waitEvent(eventLoadDone.get());
                funcCompute(layer);
                nextComputeDone = std::make_unique<CUDAEventWrapper>();
fengzch-das's avatar
fengzch-das committed
261
                checkCUDA(cudaEventRecord(nextComputeDone->event, getCurrentCUDAStream()));
Zhekai Zhang's avatar
Zhekai Zhang committed
262
                workaroundFlush();
muyangli's avatar
muyangli committed
263
264
265
266
267
268
269
270
271
272
273
274
            }

            {
                CUDAStreamContext ctx(streamLoad->stream);
                waitEvent(eventComputeDone.get());
                if (layer - 1 > 0) {
                    funcUnload(layer - 1);
                }
                if (layer + 1 < numLayers) {
                    funcLoad(layer + 1);
                }
                nextLoadDone = std::make_unique<CUDAEventWrapper>();
fengzch-das's avatar
fengzch-das committed
275
                checkCUDA(cudaEventRecord(nextLoadDone->event, getCurrentCUDAStream()));
Zhekai Zhang's avatar
Zhekai Zhang committed
276
                workaroundFlush();
muyangli's avatar
muyangli committed
277
278
279
            }

            eventComputeDone = std::move(nextComputeDone);
Muyang Li's avatar
Muyang Li committed
280
            eventLoadDone    = std::move(nextLoadDone);
Zhekai Zhang's avatar
Zhekai Zhang committed
281
282

            workaroundSynchronize();
muyangli's avatar
muyangli committed
283
284
285
286
287
288
289
        }
    }

    static void waitEvent(CUDAEventWrapper *event) {
        if (!event) {
            return;
        }
fengzch-das's avatar
fengzch-das committed
290
        checkCUDA(cudaStreamWaitEvent(getCurrentCUDAStream(), event->event));
muyangli's avatar
muyangli committed
291
    }
Zhekai Zhang's avatar
Zhekai Zhang committed
292
293
294
295
296
297
298
299
300
301
302
303

    // WDDM prevents multiple streams run concurrently
    // use flush and synchronize to work around
    bool needWorkaround;
    static bool checkWorkaround() {
        if (char *env = getenv("NUNCHAKU_OFFLOAD_WDDM_WORKAROUND")) {
            if (std::string(env) == "1") {
                return true;
            } else if (std::string(env) == "0") {
                return false;
            }
        }
Muyang Li's avatar
Muyang Li committed
304
305

#ifdef _WIN32
Zhekai Zhang's avatar
Zhekai Zhang committed
306
        return true;
Muyang Li's avatar
Muyang Li committed
307
#else
Zhekai Zhang's avatar
Zhekai Zhang committed
308
        return false;
Muyang Li's avatar
Muyang Li committed
309
#endif
Zhekai Zhang's avatar
Zhekai Zhang committed
310
311
312
313
314
    }
    void workaroundFlush() {
        if (!needWorkaround) {
            return;
        }
fengzch-das's avatar
fengzch-das committed
315
        cudaStreamQuery(getCurrentCUDAStream());
Zhekai Zhang's avatar
Zhekai Zhang committed
316
317
318
319
320
    }
    void workaroundSynchronize() {
        if (!needWorkaround) {
            return;
        }
fengzch-das's avatar
fengzch-das committed
321
        checkCUDA(cudaEventSynchronize(eventComputeDone->event));
Zhekai Zhang's avatar
Zhekai Zhang committed
322
    }
Muyang Li's avatar
Muyang Li committed
323
};