pos_encoding.h 205 Bytes
Newer Older
Haotian Tang's avatar
Haotian Tang committed
1
2
3
#pragma once
#include <torch/extension.h>

Casper Hansen's avatar
Casper Hansen committed
4
void rotary_embedding(
Haotian Tang's avatar
Haotian Tang committed
5
6
7
8
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
Casper Hansen's avatar
Casper Hansen committed
9
10
  torch::Tensor& cos_sin_cache,
  bool is_neox);