balancing.cu
#include "balancing.cuh"
#include <torch/extension.h>

/*
 * Compute the number of tokens each expert can actually accept, limited by its
 * capacity. Note that due to the limitations of CUDA atomic operators, the
 * capacity tensor must be int32 (see the illustrative kernel sketch after this
 * function).
 */
torch::Tensor _limit_by_capacity(
        torch::Tensor expert_count, torch::Tensor capacity,
        long n_expert, long n_worker) {
    CHECK_INPUT(expert_count);
    CHECK_INPUT(capacity);
    auto expert_count_ack = torch::empty_like(expert_count);
    auto smgr = getCudaStreamManager(expert_count.device().index());
    fmoe_cuda_limit_by_capacity_impl(
            expert_count.data_ptr<long>(),
            capacity.data_ptr<int>(),
            expert_count_ack.data_ptr<long>(),
            n_expert, n_worker, smgr);
    return expert_count_ack;
}
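
/*
 * Illustrative sketch only (an assumption, not the actual
 * fmoe_cuda_limit_by_capacity_impl): CUDA provides atomicSub for int but not
 * for 64-bit integers, which is why the capacity tensor has to be int32.
 * A kernel along these lines could atomically reserve capacity for each
 * (worker, expert) entry and acknowledge only the granted part of the request.
 */
__global__ void limit_by_capacity_sketch(const long* expert_count, int* capacity,
        long* expert_count_ack, long n_expert, long n_worker) {
    long i = blockIdx.x * blockDim.x + threadIdx.x;  // one thread per (worker, expert) entry
    if (i < n_expert * n_worker) {
        int eid = (int)(i % n_expert);            // capacity is shared per expert across workers
        int cnt = (int)expert_count[i];
        int old = atomicSub(capacity + eid, cnt); // reserve capacity (int32 atomics only)
        if (old >= cnt) {
            expert_count_ack[i] = cnt;            // request fully granted
        } else if (old > 0) {
            expert_count_ack[i] = old;            // partially granted
        } else {
            expert_count_ack[i] = 0;              // expert already full
        }
    }
}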

/*
 * Prune gate (expert) indices of tokens that exceed the per-expert capacity.
 */
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker) {
    auto smgr = getCudaStreamManager(expert_count.device().index());
    auto batch_size = gate_idx.numel();
    auto opt = torch::TensorOptions()
        .dtype(gate_idx.dtype())
        .device(gate_idx.device());
    auto new_gate_idx = torch::empty(gate_idx.sizes(), opt);
    fmoe_cuda_prune_gate_by_capacity_impl(
            gate_idx.data_ptr<long>(),
            new_gate_idx.data_ptr<long>(),
            expert_count.data_ptr<int>(),
            batch_size, n_expert, n_worker, smgr);
    return new_gate_idx;
}
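
/*
 * Sketch (an assumption, not part of this file): helpers like these are
 * typically exported to Python from a separate binding file of the torch
 * extension, roughly as follows.
 *
 *     #include <torch/extension.h>
 *
 *     PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 *         m.def("limit_by_capacity", &_limit_by_capacity,
 *               "Limit expert counts by per-expert capacity (CUDA)");
 *         m.def("prune_gate_by_capacity", &_prune_gate_by_capacity,
 *               "Prune gate indices that exceed expert capacity (CUDA)");
 *     }
 */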