// pos_encoding.h — last committed by Haotian Tang.
#pragma once
#include <torch/extension.h>

4
void rotary_embedding_neox(
Haotian Tang's avatar
Haotian Tang committed
5
6
7
8
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
9
  torch::Tensor& cos_sin_cache);