export.cpp 582 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#include <torch/extension.h>

#include <optional>

using torch::Tensor;

// Forward declaration of the fused entry point that the pybind11 module
// binds. Its definition is not in this file (presumably a separate .cpp/.cu
// translation unit); this declaration exists so the module can take its
// address.
//
// NOTE(review): every parameter type here (including the by-value Tensor and
// std::optional<Tensor> parameters) must match the definition exactly. A
// mismatch produces a different mangled name and an unresolved symbol at
// link/load time, so change both places together or not at all.
// NOTE(review): positions, query, key and cos_sin_cache are taken by
// non-const reference; query/key are presumably updated in place — confirm
// against the definition.
namespace at::native {

void rms_rotary_embedding_fuse(Tensor& positions,
                               Tensor& query,
                               Tensor& key,
                               int64_t head_size,
                               Tensor& cos_sin_cache,
                               bool is_neox,
                               Tensor weight_q,
                               Tensor weight_k,
                               std::optional<Tensor> residual_q,
                               std::optional<Tensor> residual_k,
                               double epsilon);

}  // namespace at::native

// Python bindings for this extension module.
//
// Every parameter of at::native::rms_rotary_embedding_fuse gets a
// pybind11::arg name (mirroring the C++ parameter names), so Python callers
// may pass arguments by keyword as well as positionally. Positional calls
// behave exactly as before.
//
// The number of pybind11::arg annotations must equal the function's arity
// (11); pybind11 rejects a mismatch at compile time via static_assert.
// No default values are declared, so every argument must still be supplied.
// residual_q / residual_k are std::optional<Tensor> and accept None.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rms_rotary_embedding_fuse", &at::native::rms_rotary_embedding_fuse,
        "rms_rotary_embedding_fuse",
        pybind11::arg("positions"), pybind11::arg("query"),
        pybind11::arg("key"), pybind11::arg("head_size"),
        pybind11::arg("cos_sin_cache"), pybind11::arg("is_neox"),
        pybind11::arg("weight_q"), pybind11::arg("weight_k"),
        pybind11::arg("residual_q"), pybind11::arg("residual_k"),
        pybind11::arg("epsilon"));
}