batched_gemm_fp16.cpp 1.45 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

4
#include <iostream>
5

Chao Liu's avatar
Chao Liu committed
6
#include "profiler/include/profile_batched_gemm_impl.hpp"
7

8
namespace {

// Element types of the A, B and C tensors handed to the batched GEMM profiler.
// All three use ck::half_t for this test.
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;

// Short aliases for the two GEMM tensor layouts exercised by the test below.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

} // namespace

int main()
{
19
20
21
22
    int M          = 512;
    int N          = 256;
    int K          = 128;
    int BatchCount = 3;
23

24
    bool pass = true;
25

26
27
    pass = pass &&
           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
28
               true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
29
30
31

    pass = pass &&
           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
32
               true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
33
34
35

    pass = pass &&
           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
36
               true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
37

38
39
    pass = pass &&
           ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
40
               true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
41

42
    std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;
43
    return pass ? 0 : 1;
44
}