// batched_gemm_fp16.cpp
#include <iostream>
#include "profiler/include/profile_batched_gemm_impl.hpp"
namespace {

// Element types passed as the first three template arguments of
// profile_batched_gemm_impl; this test uses ck::half_t for all of them.
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;

// Short names for the layout tags passed as the last three template
// arguments of profile_batched_gemm_impl.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

} // namespace

/// Runs profile_batched_gemm_impl for four layout combinations of the first
/// two layout tags (Row/Row, Row/Col, Col/Row, Col/Col), always with Row as
/// the third layout tag, using ck::half_t for every data type.
///
/// Prints "test BatchedGEMM fp16: Pass" or "... Fail" to stdout.
///
/// @return 0 when every executed configuration returned true, 1 otherwise.
int main()
{
    // Problem dimensions shared by every configuration below.
    int M          = 512;
    int N          = 256;
    int K          = 128;
    int BatchCount = 3;

    bool pass = true;

    // NOTE(review): the four leading arguments (true, 1, false, 1) are
    // presumably verification / init-method / logging / repeat-count switches,
    // and the three values after M, N, K are presumably the leading-dimension
    // strides of A, B and C for the chosen layouts -- confirm against the
    // declaration of profile_batched_gemm_impl.
    //
    // `pass && ...` short-circuits: once one configuration fails, the
    // remaining ones are not executed.

    // A: Row, B: Row, C: Row
    pass = pass &&
           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
               true, 1, false, 1, M, N, K, K, N, N, BatchCount);

    // A: Row, B: Col, C: Row
    pass = pass &&
           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
               true, 1, false, 1, M, N, K, K, K, N, BatchCount);

    // A: Col, B: Row, C: Row
    pass = pass &&
           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
               true, 1, false, 1, M, N, K, M, N, N, BatchCount);

    // A: Col, B: Col, C: Row
    pass = pass &&
           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
               true, 1, false, 1, M, N, K, M, K, N, BatchCount);

    std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;

    // Exit status mirrors the overall result so a test runner can detect failure.
    return pass ? 0 : 1;
}