// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

zjing14's avatar
zjing14 committed
4
#include <iostream>
5
6

#include "profiler/include/profile_grouped_gemm_impl.hpp"
zjing14's avatar
zjing14 committed
7
8
9
10
11
12
13
14

namespace {

using ADataType   = ck::half_t;
using BDataType   = ck::half_t;
using CDataType   = ck::half_t;
using AccDataType = float;

15
16
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
zjing14's avatar
zjing14 committed
17

18
19
template <typename ALayout, typename BLayout, typename CLayout>
bool TestGroupedGemm()
zjing14's avatar
zjing14 committed
20
{
zjing14's avatar
zjing14 committed
21
    int group_count = rand() % 10 + 1;
zjing14's avatar
zjing14 committed
22
23

    // GEMM shape
24
    std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
zjing14's avatar
zjing14 committed
25
26
27
    std::vector<const void*> p_a, p_b;
    std::vector<void*> p_c;

28
    std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideCs;
zjing14's avatar
zjing14 committed
29
30
31

    for(int i = 0; i < group_count; i++)
    {
32
33
34
        Ms.push_back(256 + 256 * (rand() % 10));
        Ns.push_back(256 + 256 * (rand() % 10));
        Ks.push_back(128 + 128 * (rand() % 10));
zjing14's avatar
zjing14 committed
35

36
37
38
        StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
        StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
        StrideCs.push_back(std::is_same<Row, CLayout>::value ? Ns[i] : Ms[i]);
zjing14's avatar
zjing14 committed
39
40
    }

41
42
43
44
45
46
47
48
    return ck::profiler::profile_grouped_gemm_impl<ADataType,
                                                   BDataType,
                                                   CDataType,
                                                   AccDataType,
                                                   ALayout,
                                                   BLayout,
                                                   CLayout>(
        true, 1, false, 1, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
zjing14's avatar
zjing14 committed
49
50
51
52
53
54
55
56
}

} // anonymous namespace

// Entry point: runs the grouped GEMM test for four A/B layout combinations
// (C is always Row), prints an overall verdict, and returns 0 on success or
// 1 on failure.
int main()
{
    // The && chain short-circuits left to right, so a failing combination
    // prevents the remaining ones from running.
    const bool all_passed = TestGroupedGemm<Row, Row, Row>() &&
                            TestGroupedGemm<Row, Col, Row>() &&
                            TestGroupedGemm<Col, Row, Row>() &&
                            TestGroupedGemm<Col, Col, Row>();

    const char* const verdict = all_passed ? "SUCCESS" : "FAILURE";
    std::cout << "TestGroupedGemm ..... " << verdict << std::endl;

    if(all_passed)
    {
        return 0;
    }
    return 1;
}