// misc_kernels.h
#pragma once

#include <array>
#include <cstddef>

#include "common.h"
#include "Tensor.h"

muyangli's avatar
muyangli committed
6
7
namespace nunchaku::kernels {

Zhekai Zhang's avatar
Zhekai Zhang committed
8
9
Tensor add(Tensor a, Tensor b);
void mul_add(Tensor x, Tensor scale, Tensor bias);
muyangli's avatar
muyangli committed
10
11
void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, Tensor bias, bool batch_bias);

Zhekai Zhang's avatar
Zhekai Zhang committed
12
13
14
15
16
17
Tensor embedding(Tensor input_id, Tensor lookup);
Tensor argmax_sample(Tensor logits);
void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v);
Tensor quant_static(Tensor x, float scale);
Tensor quant_static_fuse_gelu(Tensor x, float scale);

18
19
void cast(Tensor input, Tensor output);

Zhekai Zhang's avatar
Zhekai Zhang committed
20
21
22
Tensor topk(Tensor x, int k);

template<size_t N>
muyangli's avatar
muyangli committed
23
24
std::array<Tensor, N> split_mod(Tensor input);

Muyang Li's avatar
Muyang Li committed
25
}; // namespace nunchaku::kernels