gemm_batched.h 292 Bytes
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
#pragma once

#include "common.h"
#include "Tensor.h"

Muyang Li's avatar
Muyang Li committed
6
7
8
9
/**
 * Batched FP16 matrix multiplication (GEMM) over Tensor operands.
 *
 * The last two dimensions of each tensor are the matrix dimensions; any
 * leading dimensions are batch dimensions.
 *
 * @param a   FP16, row-major, shape [(... batch ...), M, K].
 * @param b   FP16, col-major, shape [(... batch ...), N, K].
 * @param out FP32, row-major, shape [(... batch ...), M, N].
 * @return    A Tensor. NOTE(review): presumably the result that corresponds
 *            to `out`; the definition is not visible from this header, so
 *            confirm against the implementation.
 */
Tensor gemm_batched_fp16(Tensor a, Tensor b, Tensor out);