// gemm_batched.h (committed by Zhekai Zhang)
#pragma once

#include "common.h"
#include "Tensor.h"

/**
 * Batched FP16 GEMM entry point.
 *
 * Only the declaration lives here; the operand layouts listed below are the
 * ones annotated on the parameters. The leading "(... batch ...)" dimensions
 * are written identically for all three tensors.
 * NOTE(review): the exact product computed (e.g. whether `b` is treated as
 * transposed) is not visible from this header; confirm against the definition.
 *
 * @param a    FP16, row-major, shape [(... batch ...), M, K].
 * @param b    FP16, col-major, shape [(... batch ...), N, K].
 * @param out  FP32, row-major, shape [(... batch ...), M, N].
 * @return A Tensor; presumably the result written to `out`. TODO confirm
 *         against the implementation.
 */
Tensor gemm_batched_fp16(Tensor a, Tensor b, Tensor out);