// moe_ops.h
#pragma once

3
#include <torch/all.h>
4

5
6
7
// Declaration of topk_softmax; no definition is visible in this header.
//
// All four tensors are taken by non-const lvalue reference and the function
// returns void, so results are presumably written in place into the
// arguments rather than returned.
//
// NOTE(review): judging only by the names, this looks like it applies a
// softmax to `gating_output` and stores the top-k weights and indices (plus a
// token-to-expert index mapping) in the first three tensors — confirm against
// the implementation. Expected shapes, dtypes, and device placement are not
// visible here; TODO document them once verified.
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
                  torch::Tensor& gating_output);