// misc_kernels.h (committed by Zhekai Zhang)
#pragma once

#include "common.h"
#include "Tensor.h"

/// Declaration only: takes two Tensors by value and returns a new Tensor.
/// NOTE(review): presumably the element-wise sum of `a` and `b`; shape, dtype
/// and broadcasting rules are not visible in this header -- confirm against
/// the definition.
Tensor add(Tensor a, Tensor b);
/// Declaration only: takes three Tensors by value and returns nothing.
/// NOTE(review): presumably updates `x` in place as `x * scale + bias`
/// (which would rely on Tensor being a handle to shared storage) -- confirm
/// against the definition.
void mul_add(Tensor x, Tensor scale, Tensor bias);
/// Declaration only: takes an id Tensor and a lookup Tensor, returns a Tensor.
/// NOTE(review): presumably gathers rows of `lookup` indexed by `input_id`;
/// index dtype and output layout are not visible here -- confirm.
Tensor embedding(Tensor input_id, Tensor lookup);
/// Declaration only: takes a Tensor of logits and returns a Tensor.
/// NOTE(review): presumably greedy sampling (index of the maximum logit);
/// the reduction axis and result dtype are not visible here -- confirm.
Tensor argmax_sample(Tensor logits);
/// Declaration only: takes four Tensors by value and returns nothing.
/// NOTE(review): presumably splits the fused `qkv` Tensor into the
/// caller-provided `q`, `k` and `v` outputs; the packing layout is not
/// visible here -- confirm against the definition.
void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v);
/// Declaration only: takes a Tensor and a float scale, returns a Tensor.
/// NOTE(review): presumably static quantization of `x` with a fixed,
/// precomputed `scale`; output dtype and rounding/clamping behavior are not
/// visible here -- confirm.
Tensor quant_static(Tensor x, float scale);
/// Declaration only: same signature as `quant_static` (Tensor + float scale,
/// returns a Tensor).
/// NOTE(review): presumably applies GELU and static quantization in one fused
/// pass; the order of the two operations is not visible here -- confirm.
Tensor quant_static_fuse_gelu(Tensor x, float scale);

/// Declaration only: takes an input and an output Tensor by value, returns
/// nothing.
/// NOTE(review): presumably converts `input` to the element type of the
/// caller-provided `output` and writes the result there -- confirm against
/// the definition.
void cast(Tensor input, Tensor output);

/// Declaration only: takes a Tensor and an int `k`, returns a single Tensor.
/// NOTE(review): presumably selects the `k` largest entries of `x`; whether
/// the result holds values or indices, and along which axis, is not visible
/// here -- confirm.
Tensor topk(Tensor x, int k);

/// Declaration only: function template parameterized on a compile-time count
/// `N`; takes one Tensor and returns a std::array of exactly `N` Tensors.
/// NOTE(review): presumably splits `input` into `N` parts; the split axis and
/// chunk sizes are not visible here -- confirm.
/// NOTE(review): this header does not include <array> or <cstddef> itself, so
/// `std::array` / `size_t` must come in via "common.h" or "Tensor.h" -- verify.
/// NOTE(review): no definition is visible in this header; callers need a
/// definition or explicit instantiation for each `N` they use -- confirm where
/// it lives.
template<size_t N>
std::array<Tensor, N> split_mod(Tensor input);