// deepseek_v3_weight.cpp
#include "deepseek_v3_impl.hpp"

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>

inline std::shared_ptr<Tensor> getInEmbd(
    const DeepSeekV3Meta *meta) {
    auto shape = std::vector<size_t>({meta->dvoc, meta->d});
    return Tensor::weight(nullptr, meta->dt_logits, shape);
}

inline std::shared_ptr<Tensor> getOutNorm(
    const DeepSeekV3Meta *meta) {
    auto shape = std::vector<size_t>({meta->d});
    return Tensor::weight(nullptr, meta->dt_norm, shape);
}

inline std::shared_ptr<Tensor> getOutEmbd(
    const DeepSeekV3Meta *meta) {

    auto shape = std::vector<size_t>({meta->dvoc, meta->d});
    return Tensor::weight(nullptr, meta->dt_logits, shape)
        ->permute({1, 0});
}

inline std::shared_ptr<Tensor> getMLANorm(
    const DeepSeekV3Meta *meta) {
    auto shape = std::vector<size_t>({meta->d});
    return Tensor::weight(nullptr, meta->dt_norm, shape);
}

/// Creates the three tensors of a quantized linear layer, all with null data:
///   w: {in_dim, out_dim / 8}       dtype I32
///   s: {in_dim / 64, out_dim}      dtype meta->dt_quant_scale
///   z: {in_dim / 64, out_dim / 8}  dtype I32
/// NOTE(review): the divisors look like 8 packed 4-bit values per int32 and a
/// quantization group size of 64 along the input dimension -- confirm against
/// the checkpoint format.
inline std::shared_ptr<QuantLinearWeight> getQuantLinear(
    const DeepSeekV3Meta *meta, size_t in_dim, size_t out_dim) {
    constexpr size_t kOutDivisor = 8;
    constexpr size_t kInDivisor = 64;

    const size_t packed_out = out_dim / kOutDivisor;
    const size_t grouped_in = in_dim / kInDivisor;

    auto linear = std::make_shared<QuantLinearWeight>();
    // Tensors are created in the order w, s, z.
    const std::vector<size_t> w_dims{in_dim, packed_out};
    linear->w = Tensor::weight(nullptr, INFINI_DTYPE_I32, w_dims);
    linear->s = Tensor::weight(nullptr, meta->dt_quant_scale, {grouped_in, out_dim});
    linear->z = Tensor::weight(nullptr, INFINI_DTYPE_I32, {grouped_in, packed_out});
    return linear;
}

// ------------------- MLA Weights -------------------
inline std::shared_ptr<Tensor> getMLPNorm(
    const DeepSeekV3Meta *meta) {
    auto shape = std::vector<size_t>({meta->d});
    return Tensor::weight(nullptr, meta->dt_norm, shape);
}

/// Creates the attention (MLA) weight set for one layer on one device.
///
/// q_a_proj, q_a_norm, kv_a_proj and kv_a_norm use full-size dimensions.
/// q_b_proj, kv_b_proj and o_proj are sized from `nh / ndev` heads.
///
/// @param meta  Model dimensions and dtypes.
/// @param ndev  Number of devices the heads are divided across.
inline std::shared_ptr<MLAWeight> getMLA(const DeepSeekV3Meta *meta, int ndev) {
    // Integer division, evaluated before the per-head multipliers below.
    const auto heads_per_dev = meta->nh / ndev;

    auto attn = std::make_shared<MLAWeight>();

    // Query path: d -> r_q -> heads_per_dev * d_qk.
    attn->q_a_proj = getQuantLinear(meta, meta->d, meta->r_q);
    attn->q_a_norm = Tensor::weight(nullptr, meta->dt_norm, {meta->r_q});
    attn->q_b_proj = getQuantLinear(meta, meta->r_q, heads_per_dev * meta->d_qk);

    // Key/value path: d -> (r_kv + d_rope); r_kv -> heads_per_dev * (d_nope + d_v).
    attn->kv_a_proj = getQuantLinear(meta, meta->d, meta->r_kv + meta->d_rope);
    attn->kv_a_norm = Tensor::weight(nullptr, meta->dt_norm, {meta->r_kv});
    attn->kv_b_proj = getQuantLinear(meta, meta->r_kv, heads_per_dev * (meta->d_nope + meta->d_v));

    // Output projection: heads_per_dev * d_v -> d.
    attn->o_proj = getQuantLinear(meta, heads_per_dev * meta->d_v, meta->d);
    return attn;
}

// ------------------- Dense MLP -------------------

/// Creates the three quantized projections of an MLP block.
///
/// @param meta  Forwarded to getQuantLinear.
/// @param d     Input dimension of gate/up and output dimension of down.
/// @param di    Output dimension of gate/up and input dimension of down.
inline std::shared_ptr<MLPWeight> getMLP(const DeepSeekV3Meta *meta, size_t d, size_t di) {
    auto block = std::make_shared<MLPWeight>();
    // gate and up map d -> di; down maps di -> d.
    block->gate = getQuantLinear(meta, d, di);
    block->up = getQuantLinear(meta, d, di);
    block->down = getQuantLinear(meta, di, d);
    return block;
}

/// Creates a dense-layer MLP whose intermediate dimension is `di / ndev`
/// (integer division).
inline std::shared_ptr<MLPWeight> getDenseMLP(const DeepSeekV3Meta *meta, int ndev) {
    const auto di_per_dev = meta->di / ndev;
    return getMLP(meta, meta->d, di_per_dev);
}

// ------------------- Sparse Route + Experts -------------------

/// Creates the routing weights stored in LayerWeight::route.
///   w: constructed as {nexperts, d} with dtype dt_gate_weight, then passed
///      through permute({1, 0}).
///   b: 1-D tensor of length nexperts with dtype dt_gate_bias.
/// Both are constructed with a null data pointer.
inline std::shared_ptr<GateWeight> getRouteWeight(
    const DeepSeekV3Meta *meta) {
    auto route = std::make_shared<GateWeight>();

    auto raw_w = Tensor::weight(nullptr, meta->dt_gate_weight, {meta->nexperts, meta->d});
    route->w = raw_w->permute({1, 0});

    route->b = Tensor::weight(nullptr, meta->dt_gate_bias, {meta->nexperts});
    return route;
}

/// Creates the shared-expert MLP with intermediate dimension `di_moe / ndev`
/// (integer division).
inline std::shared_ptr<MLPWeight> getShareExpert(const DeepSeekV3Meta *meta, int ndev) {
    const auto di_per_dev = meta->di_moe / ndev;
    return getMLP(meta, meta->d, di_per_dev);
}

inline std::vector<std::shared_ptr<MLPWeight>> getExperts(const DeepSeekV3Meta *meta, int ndev) {
    std::vector<std::shared_ptr<MLPWeight>> experts(meta->nexperts);
    for (size_t i = 0; i < meta->nexperts; i++) {
        experts[i] = getMLP(meta, meta->d, meta->di_moe / ndev);
    }
    return experts;
}

inline std::shared_ptr<Tensor> getSinTable(const DeepSeekV3Meta *meta) {
    auto half_dh = meta->d_rope / 2;
    auto unit = dsize(meta->dt_logits);
    void *table = std::malloc(meta->dctx * half_dh * unit);

    for (size_t i = 0; i < meta->dctx; i++) {
        for (size_t j = 0; j < half_dh; j++) {
            float _sin = std::sin(
                static_cast<float>(i) / std::pow(meta->rope_theta, static_cast<float>(j) / half_dh));
            if (meta->dt_logits == INFINI_DTYPE_F16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin);
            } else if (meta->dt_logits == INFINI_DTYPE_BF16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_sin);
            } else if (meta->dt_logits == INFINI_DTYPE_F32) {
                ((float *)table)[i * half_dh + j] = _sin;
            } else {
                std::cout << "unsupported data type" << std::endl;
                exit(1);
            }
        }
    }
    auto shape = std::vector<size_t>({meta->dctx, half_dh});
    auto tensor = Tensor::weight(table, meta->dt_logits, shape);
    std::free(table);
    return tensor;
}

inline std::shared_ptr<Tensor> getCosTable(const DeepSeekV3Meta *meta) {
    auto half_dh = meta->d_rope / 2;
    auto unit = dsize(meta->dt_logits);
    void *table = std::malloc(meta->dctx * half_dh * unit);

    for (size_t i = 0; i < meta->dctx; i++) {
        for (size_t j = 0; j < half_dh; j++) {
            float _cos = std::cos(
                static_cast<float>(i) / std::pow(meta->rope_theta, static_cast<float>(j) / half_dh));
            if (meta->dt_logits == INFINI_DTYPE_F16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos);
            } else if (meta->dt_logits == INFINI_DTYPE_BF16) {
                ((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_cos);
            } else if (meta->dt_logits == INFINI_DTYPE_F32) {
                ((float *)table)[i * half_dh + j] = _cos;
            } else {
                std::cout << "unsupported data type" << std::endl;
                exit(1);
            }
        }
    }
    auto shape = std::vector<size_t>({meta->dctx, half_dh});
    auto tensor = Tensor::weight(table, meta->dt_logits, shape);
    std::free(table);
    return tensor;
}

/// Creates the weight tensors for every device.
///
/// For each of the `ndev` entries in `dev_ids`, the device is selected with
/// infinirtSetDevice, a load stream is created, and the global tensors
/// (embeddings, output norm, sin/cos tables) plus one LayerWeight per layer
/// are created. The first `n_dense_layer` layers get a dense MLP; the
/// remaining layers get routing weights, a shared expert and the expert list.
///
/// @param meta     Model dimensions and dtypes.
/// @param device   Device type passed to infinirtSetDevice.
/// @param ndev     Number of devices; also the divisor for per-device sizes.
/// @param dev_ids  Array of at least `ndev` device ids.
DeepSeekV3Weights::DeepSeekV3Weights(
    const DeepSeekV3Meta *meta, infiniDevice_t device, int ndev, const int *dev_ids) {
    device_weights.assign(ndev, nullptr);

    for (int dev = 0; dev < ndev; ++dev) {
        const int dev_id = dev_ids[dev];
        RUN_INFINI(infinirtSetDevice(device, dev_id));

        // Publish the slot first, then fill it through a reference.
        auto &dw = device_weights[dev];
        dw = std::make_shared<DeepSeekV3DeviceWeights>();
        dw->device = device;
        dw->dev_id = dev_id;
        RUN_INFINI(infinirtStreamCreate(&dw->load_stream));

        // Global (non-layer) tensors.
        dw->w_in_embd = getInEmbd(meta);
        dw->w_out_norm = getOutNorm(meta);
        dw->w_out_embd = getOutEmbd(meta);
        dw->sin_table = getSinTable(meta);
        dw->cos_table = getCosTable(meta);

        dw->w_layers = std::vector<LayerWeight>(meta->n_dense_layer + meta->n_sparse_layer);

        for (size_t layer = 0; layer < meta->n_dense_layer + meta->n_sparse_layer; ++layer) {
            auto &lw = dw->w_layers[layer];
            lw.mla_norm = getMLANorm(meta);
            lw.mla = getMLA(meta, ndev);
            lw.mlp_norm = getMLPNorm(meta);

            const bool is_dense = layer < meta->n_dense_layer;
            if (is_dense) {
                lw.dense_mlp = getDenseMLP(meta, ndev);
                continue;
            }
            // Sparse layer: router, shared expert, then the routed experts.
            lw.route = getRouteWeight(meta);
            lw.share_expert = getShareExpert(meta, ndev);
            lw.experts = getExperts(meta, ndev);
        }
    }
}

// --- Global
/// Logs the source pointer, then calls load(cpu_ptr, load_stream) on every
/// device's `w_in_embd` tensor after selecting that device.
void load_input_embd(DeepSeekV3Weights *weights, void *cpu_ptr) {
    std::cout << "Loading input embedding from " << cpu_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        dw->w_in_embd->load(cpu_ptr, dw->load_stream);
    }
}

/// Logs the source pointer, then calls load(cpu_ptr, load_stream) on every
/// device's `w_out_norm` tensor after selecting that device.
void load_output_norm(DeepSeekV3Weights *weights, void *cpu_ptr) {
    std::cout << "Loading output norm from " << cpu_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        dw->w_out_norm->load(cpu_ptr, dw->load_stream);
    }
}

/// Logs the source pointer, then calls load(cpu_ptr, load_stream) on every
/// device's `w_out_embd` tensor after selecting that device.
void load_output_embd(DeepSeekV3Weights *weights, void *cpu_ptr) {
    std::cout << "Loading output embedding from " << cpu_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        dw->w_out_embd->load(cpu_ptr, dw->load_stream);
    }
}

// --- Attention
/// Loads `cpu_ptr` into the `mla_norm` tensor of layer `layer` on every
/// device. `layer` indexes w_layers without a bounds check.
void load_attn_norm(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
    std::cout << "Loading attention norm " << layer << " from " << cpu_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        auto &norm = dw->w_layers[layer].mla_norm;
        norm->load(cpu_ptr, dw->load_stream);
    }
}

/// Loads the weight, scale and zero buffers of `q_a_proj` for layer `layer`
/// on every device. Each device receives the same, un-offset host pointers.
/// `layer` indexes w_layers without a bounds check.
void load_attn_q_a_proj(DeepSeekV3Weights *weights,
                        void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
    std::cout << "Loading attention q_a_proj " << layer << " from " << weight_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        const auto &proj = dw->w_layers[layer].mla->q_a_proj;
        // Load order: weights, scales, zeros.
        proj->w->load(weight_ptr, dw->load_stream);
        proj->s->load(scale_ptr, dw->load_stream);
        proj->z->load(zero_ptr, dw->load_stream);
    }
}

/// Loads `cpu_ptr` into the `q_a_norm` tensor of layer `layer` on every
/// device. `layer` indexes w_layers without a bounds check.
void load_attn_q_a_layernorm(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
    std::cout << "Loading attention q_a_layernorm " << layer << " from " << cpu_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        auto &norm = dw->w_layers[layer].mla->q_a_norm;
        norm->load(cpu_ptr, dw->load_stream);
    }
}

/// Loads one device's part of a quantized linear layer from host buffers.
///
/// For each of w, s and z the byte offset into the host buffer is
///   shape[0] * shape[1] / ndev * dev * dsize(dtype)
/// (integer arithmetic, evaluated left to right), and the tensor is loaded
/// from `ptr + offset` on `stream`.
///
/// NOTE(review): callers pass tensors that were already created with
/// per-device dimensions (see the `/ ndev` in getMLA / getDenseMLP), yet the
/// offset divides the element count by `ndev` again. Confirm that this matches
/// the host-side layout when ndev > 1.
///
/// @param w_ptr,s_ptr,z_ptr  Host base pointers for weights, scales, zeros.
/// @param w,s,z              Destination tensors (indexed as 2-D here).
/// @param ndev               Number of devices.
/// @param dev                Index of the device being loaded.
/// @param stream             Stream passed to Tensor::load.
inline void load_dist_linear(void *w_ptr, void *s_ptr, void *z_ptr, std::shared_ptr<Tensor> w, std::shared_ptr<Tensor> s, std::shared_ptr<Tensor> z, size_t ndev, size_t dev, infinirtStream_t stream) {
    const auto byte_offset = [ndev, dev](const std::shared_ptr<Tensor> &t) {
        return t->shape()[0] * t->shape()[1] / ndev * dev * dsize(t->dtype());
    };

    // All offsets are computed before the first load is issued.
    const auto w_skip = byte_offset(w);
    const auto s_skip = byte_offset(s);
    const auto z_skip = byte_offset(z);

    w->load(static_cast<char *>(w_ptr) + w_skip, stream);
    s->load(static_cast<char *>(s_ptr) + s_skip, stream);
    z->load(static_cast<char *>(z_ptr) + z_skip, stream);
}

/// Loads `q_b_proj` (weights, scales, zeros) for layer `layer` on every
/// device through load_dist_linear, which applies a per-device offset to the
/// host pointers. `layer` indexes w_layers without a bounds check.
void load_attn_q_b_proj(DeepSeekV3Weights *weights,
                        void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
    std::cout << "Loading attention q_b_proj " << layer << " from " << weight_ptr << std::endl;
    const int ndev = int(weights->device_weights.size());
    for (int dev = 0; dev < ndev; ++dev) {
        const auto &dw = weights->device_weights[dev];
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        const auto &proj = dw->w_layers[layer].mla->q_b_proj;
        load_dist_linear(weight_ptr, scale_ptr, zero_ptr,
                         proj->w, proj->s, proj->z,
                         weights->device_weights.size(), dev, dw->load_stream);
    }
}

/// Loads the weight, scale and zero buffers of `kv_a_proj` for layer `layer`
/// on every device. Each device receives the same, un-offset host pointers.
/// `layer` indexes w_layers without a bounds check.
void load_attn_kv_a_proj_with_mqa(DeepSeekV3Weights *weights,
                                  void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
    std::cout << "Loading attention kv_a_proj_with_mqa " << layer << " from " << weight_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        const auto &proj = dw->w_layers[layer].mla->kv_a_proj;
        // Load order: weights, scales, zeros.
        proj->w->load(weight_ptr, dw->load_stream);
        proj->s->load(scale_ptr, dw->load_stream);
        proj->z->load(zero_ptr, dw->load_stream);
    }
}

/// Loads `cpu_ptr` into the `kv_a_norm` tensor of layer `layer` on every
/// device. `layer` indexes w_layers without a bounds check.
void load_attn_kv_a_layernorm(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
    std::cout << "Loading attention kv_a_layernorm " << layer << " from " << cpu_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        auto &norm = dw->w_layers[layer].mla->kv_a_norm;
        norm->load(cpu_ptr, dw->load_stream);
    }
}

/// Loads `kv_b_proj` (weights, scales, zeros) for layer `layer` on every
/// device through load_dist_linear, which applies a per-device offset to the
/// host pointers. `layer` indexes w_layers without a bounds check.
void load_attn_kv_b_proj(DeepSeekV3Weights *weights,
                         void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
    std::cout << "Loading attention kv_b_proj " << layer << " from " << weight_ptr << std::endl;
    const int ndev = int(weights->device_weights.size());
    for (int dev = 0; dev < ndev; ++dev) {
        const auto &dw = weights->device_weights[dev];
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        const auto &proj = dw->w_layers[layer].mla->kv_b_proj;
        load_dist_linear(weight_ptr, scale_ptr, zero_ptr,
                         proj->w, proj->s, proj->z,
                         weights->device_weights.size(), dev, dw->load_stream);
    }
}

/// Loads `o_proj` (weights, scales, zeros) for layer `layer` on every device
/// through load_dist_linear, which applies a per-device offset to the host
/// pointers. `layer` indexes w_layers without a bounds check.
void load_attn_o_proj(DeepSeekV3Weights *weights,
                      void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
    std::cout << "Loading attention o_proj " << layer << " from " << weight_ptr << std::endl;
    const int ndev = int(weights->device_weights.size());
    for (int dev = 0; dev < ndev; ++dev) {
        const auto &dw = weights->device_weights[dev];
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        const auto &proj = dw->w_layers[layer].mla->o_proj;
        load_dist_linear(weight_ptr, scale_ptr, zero_ptr,
                         proj->w, proj->s, proj->z,
                         weights->device_weights.size(), dev, dw->load_stream);
    }
}

// --- MLP
/// Loads `cpu_ptr` into the `mlp_norm` tensor of layer `layer` on every
/// device. `layer` indexes w_layers without a bounds check.
void load_mlp_norm(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
    std::cout << "Loading mlp norm " << layer << " from " << cpu_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        auto &norm = dw->w_layers[layer].mlp_norm;
        norm->load(cpu_ptr, dw->load_stream);
    }
}

/// Loads the gate, up and down projections of the dense MLP in layer
/// `layer_id` on every device, each through load_dist_linear.
///
/// `layer_id` indexes w_layers without a bounds check, and `dense_mlp` is
/// dereferenced unconditionally; the constructor only creates it for layers
/// below n_dense_layer.
void load_mlp_dense(DeepSeekV3Weights *weights,
                    void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
                    void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
                    void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
                    size_t layer_id) {
    std::cout << "Loading mlp dense " << layer_id << " from " << gate_weight_ptr << std::endl;
    const int ndev = int(weights->device_weights.size());
    for (int dev = 0; dev < ndev; ++dev) {
        const auto &dw = weights->device_weights[dev];
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));

        const auto &mlp = dw->w_layers[layer_id].dense_mlp;
        const auto &gate = mlp->gate;
        const auto &up = mlp->up;
        const auto &down = mlp->down;

        // Load order: gate, up, down.
        load_dist_linear(gate_weight_ptr, gate_scale_ptr, gate_zero_ptr,
                         gate->w, gate->s, gate->z,
                         weights->device_weights.size(), dev, dw->load_stream);
        load_dist_linear(up_weight_ptr, up_scale_ptr, up_zero_ptr,
                         up->w, up->s, up->z,
                         weights->device_weights.size(), dev, dw->load_stream);
        load_dist_linear(down_weight_ptr, down_scale_ptr, down_zero_ptr,
                         down->w, down->s, down->z,
                         weights->device_weights.size(), dev, dw->load_stream);
    }
}

/// Loads `cpu_ptr` into the routing weight `route->w` of layer `layer` on
/// every device. `layer` indexes w_layers without a bounds check, and `route`
/// is dereferenced unconditionally; the constructor only creates it for
/// layers at or above n_dense_layer.
void load_mlp_gate_weight(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
    std::cout << "Loading mlp gate weight " << layer << " from " << cpu_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        const auto &route = dw->w_layers[layer].route;
        route->w->load(cpu_ptr, dw->load_stream);
    }
}

/// Loads `cpu_ptr` into the routing bias `route->b` of layer `layer` on
/// every device. `layer` indexes w_layers without a bounds check, and `route`
/// is dereferenced unconditionally; the constructor only creates it for
/// layers at or above n_dense_layer.
void load_mlp_gate_bias(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
    std::cout << "Loading mlp gate bias " << layer << " from " << cpu_ptr << std::endl;
    for (const auto &dw : weights->device_weights) {
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));
        const auto &route = dw->w_layers[layer].route;
        route->b->load(cpu_ptr, dw->load_stream);
    }
}

/// Loads the gate, up and down projections of the shared expert in layer
/// `layer_id` on every device, each through load_dist_linear.
///
/// `layer_id` indexes w_layers without a bounds check, and `share_expert` is
/// dereferenced unconditionally; the constructor only creates it for layers
/// at or above n_dense_layer.
void load_mlp_shared_experts(DeepSeekV3Weights *weights,
                             void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
                             void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
                             void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
                             size_t layer_id) {
    std::cout << "Loading mlp shared experts " << layer_id << " from " << gate_weight_ptr << std::endl;
    const int ndev = int(weights->device_weights.size());
    for (int dev = 0; dev < ndev; ++dev) {
        const auto &dw = weights->device_weights[dev];
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));

        const auto &mlp = dw->w_layers[layer_id].share_expert;
        const auto &gate = mlp->gate;
        const auto &up = mlp->up;
        const auto &down = mlp->down;

        // Load order: gate, up, down.
        load_dist_linear(gate_weight_ptr, gate_scale_ptr, gate_zero_ptr,
                         gate->w, gate->s, gate->z,
                         weights->device_weights.size(), dev, dw->load_stream);
        load_dist_linear(up_weight_ptr, up_scale_ptr, up_zero_ptr,
                         up->w, up->s, up->z,
                         weights->device_weights.size(), dev, dw->load_stream);
        load_dist_linear(down_weight_ptr, down_scale_ptr, down_zero_ptr,
                         down->w, down->s, down->z,
                         weights->device_weights.size(), dev, dw->load_stream);
    }
}

/// Loads the gate, up and down projections of expert `expert_id` in layer
/// `layer_id` on every device, each through load_dist_linear.
///
/// Neither `layer_id` nor `expert_id` is bounds-checked; the constructor only
/// fills `experts` for layers at or above n_dense_layer.
void load_mlp_experts(DeepSeekV3Weights *weights,
                      void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
                      void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
                      void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
                      size_t layer_id, size_t expert_id) {
    std::cout << "Loading mlp expert " << layer_id << " expert " << expert_id
              << " from " << gate_weight_ptr << std::endl;
    const int ndev = int(weights->device_weights.size());
    for (int dev = 0; dev < ndev; ++dev) {
        const auto &dw = weights->device_weights[dev];
        RUN_INFINI(infinirtSetDevice(dw->device, dw->dev_id));

        const auto &mlp = dw->w_layers[layer_id].experts[expert_id];
        const auto &gate = mlp->gate;
        const auto &up = mlp->up;
        const auto &down = mlp->down;

        // Load order: gate, up, down.
        load_dist_linear(gate_weight_ptr, gate_scale_ptr, gate_zero_ptr,
                         gate->w, gate->s, gate->z,
                         weights->device_weights.size(), dev, dw->load_stream);
        load_dist_linear(up_weight_ptr, up_scale_ptr, up_zero_ptr,
                         up->w, up->s, up->z,
                         weights->device_weights.size(), dev, dw->load_stream);
        load_dist_linear(down_weight_ptr, down_scale_ptr, down_zero_ptr,
                         down->w, down->s, down->z,
                         weights->device_weights.size(), dev, dw->load_stream);
    }
}

// File-local table of loader callbacks; createDeepSeekV3WeightLoader() returns
// its address. Each member is bound to the same-named function defined above.
// The initializers use designated-initializer syntax, so their order should
// follow the member order of DeepSeekV3WeightLoader (declared elsewhere).
static DeepSeekV3WeightLoader weight_loader = {
    // Global
    .load_input_embd = load_input_embd,
    .load_output_norm = load_output_norm,
    .load_output_embd = load_output_embd,
    // Attention
    .load_attn_norm = load_attn_norm,
    .load_attn_q_a_proj = load_attn_q_a_proj,
    .load_attn_q_a_layernorm = load_attn_q_a_layernorm,
    .load_attn_q_b_proj = load_attn_q_b_proj,
    .load_attn_kv_a_proj_with_mqa = load_attn_kv_a_proj_with_mqa,
    .load_attn_kv_a_layernorm = load_attn_kv_a_layernorm,
    .load_attn_kv_b_proj = load_attn_kv_b_proj,
    .load_attn_o_proj = load_attn_o_proj,
    // MLP
    .load_mlp_norm = load_mlp_norm,
    .load_mlp_dense = load_mlp_dense,
    .load_mlp_gate_weight = load_mlp_gate_weight,
    .load_mlp_gate_bias = load_mlp_gate_bias,
    .load_mlp_shared_experts = load_mlp_shared_experts,
    .load_mlp_experts = load_mlp_experts,
};

439
__INFINI_C DeepSeekV3Weights *
blkmjsian's avatar
blkmjsian committed
440
441
442
443
444
445
446
447
createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
                        infiniDevice_t device,
                        int ndev,
                        const int *dev_ids) {
    auto weights = new DeepSeekV3Weights(meta, device, ndev, dev_ids);
    return weights;
};

448
__INFINI_C DeepSeekV3WeightLoader *
blkmjsian's avatar
blkmjsian committed
449
450
451
createDeepSeekV3WeightLoader() {
    return &weight_loader;
}