// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>

#include "profiler/profile_batched_gemm_impl.hpp"

#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"

namespace {
// Element types of the A, B and C operands handed to the profiler; this test
// uses int8_t for all three.
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;

// Short names for the two matrix layout tags used to select operand layouts.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

// Element-wise operation type passed for the A, B and C element-op slots
// (presumably an identity op, going by its name -- confirm in CK).
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace

int main()
{
    int M          = 256;
    int N          = 256;
    int K          = 128;
    int BatchCount = 3;

    bool pass = true;

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
    using namespace ck::tensor_operation::device;

    pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
                                                           BDataType,
                                                           CDataType,
                                                           Row,
                                                           Row,
                                                           Row,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough,
                                                           DeviceBatchedGemm<Row,
                                                                             Row,
                                                                             Row,
                                                                             ADataType,
                                                                             BDataType,
                                                                             CDataType,
                                                                             PassThrough,
                                                                             PassThrough,
                                                                             PassThrough>>(
                       true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
    pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
                                                           BDataType,
                                                           CDataType,
                                                           Row,
                                                           Col,
                                                           Row,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough,
                                                           DeviceBatchedGemm<Row,
                                                                             Col,
                                                                             Row,
                                                                             ADataType,
                                                                             BDataType,
                                                                             CDataType,
                                                                             PassThrough,
                                                                             PassThrough,
                                                                             PassThrough>>(
                       true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
71

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
    pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
                                                           BDataType,
                                                           CDataType,
                                                           Col,
                                                           Row,
                                                           Row,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough,
                                                           DeviceBatchedGemm<Col,
                                                                             Row,
                                                                             Row,
                                                                             ADataType,
                                                                             BDataType,
                                                                             CDataType,
                                                                             PassThrough,
                                                                             PassThrough,
                                                                             PassThrough>>(
                       true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
91

92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
    pass = pass && ck::profiler::profile_batched_gemm_impl<ADataType,
                                                           BDataType,
                                                           CDataType,
                                                           Col,
                                                           Col,
                                                           Row,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough,
                                                           DeviceBatchedGemm<Col,
                                                                             Col,
                                                                             Row,
                                                                             ADataType,
                                                                             BDataType,
                                                                             CDataType,
                                                                             PassThrough,
                                                                             PassThrough,
                                                                             PassThrough>>(
                       true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
111
112
113
114

    std::cout << "test BatchedGEMM int8: " << (pass ? "Pass" : "Fail") << std::endl;
    return pass ? 0 : 1;
}