pos_encoding.h 194 Bytes
Newer Older
Casper's avatar
Casper committed
1
2
3
4
5
6
7
8
9
#pragma once
#include <torch/extension.h>

/// @brief Declaration of the rotary positional-embedding entry point.
///
/// The "neox" suffix presumably refers to the GPT-NeoX rotation layout
/// (rotating the two halves of each head rather than interleaved pairs).
/// NOTE(review): this is inferred from the name only — confirm against the
/// definition, which is not in this header.
///
/// Every tensor argument is taken by non-const reference and the return type
/// is `void`. NOTE(review): presumably `query` and `key` are modified in place
/// while `positions` and `cos_sin_cache` are only read — confirm in the
/// implementation before relying on it.
///
/// @param positions      Token positions (shape/dtype not shown here — TODO confirm).
/// @param query          Query tensor to which the embedding is applied.
/// @param key            Key tensor to which the embedding is applied.
/// @param head_size      Per-head size, passed as a plain `int`.
/// @param cos_sin_cache  Presumably a precomputed cos/sin table indexed by
///                       position (layout not shown here — TODO confirm).
void rotary_embedding_neox(torch::Tensor& positions,
                           torch::Tensor& query,
                           torch::Tensor& key,
                           int head_size,
                           torch::Tensor& cos_sin_cache);