"vscode:/vscode.git/clone" did not exist on "289f08af6af01e7d89644ed334435333fdd769b3"
grouped_gemm_fp16.cpp 2.13 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

zjing14's avatar
zjing14 committed
4
#include <iostream>
5
#include <random>
6

7
#include "profiler/profile_grouped_gemm_impl.hpp"
zjing14's avatar
zjing14 committed
8
9
10
11
12
13
14
15

namespace {

using ADataType   = ck::half_t;
using BDataType   = ck::half_t;
using CDataType   = ck::half_t;
using AccDataType = float;

16
17
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
zjing14's avatar
zjing14 committed
18

19
20
template <typename ALayout, typename BLayout, typename CLayout>
bool TestGroupedGemm()
zjing14's avatar
zjing14 committed
21
{
22
23
24
25

    std::mt19937 gen(19391);
    std::uniform_int_distribution<> distrib(1, 10);
    int group_count = distrib(gen);
zjing14's avatar
zjing14 committed
26
27

    // GEMM shape
28
    std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
zjing14's avatar
zjing14 committed
29
30
31
    std::vector<const void*> p_a, p_b;
    std::vector<void*> p_c;

32
    std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideCs;
zjing14's avatar
zjing14 committed
33
34
35

    for(int i = 0; i < group_count; i++)
    {
36
37
38
        Ms.push_back(256 + 256 * distrib(gen));
        Ns.push_back(256 + 256 * distrib(gen));
        Ks.push_back(128 + 128 * distrib(gen));
zjing14's avatar
zjing14 committed
39

40
41
42
        StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
        StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
        StrideCs.push_back(std::is_same<Row, CLayout>::value ? Ns[i] : Ms[i]);
zjing14's avatar
zjing14 committed
43
44
    }

45
46
47
48
49
50
51
52
    return ck::profiler::profile_grouped_gemm_impl<ADataType,
                                                   BDataType,
                                                   CDataType,
                                                   AccDataType,
                                                   ALayout,
                                                   BLayout,
                                                   CLayout>(
        true, 1, false, 1, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
zjing14's avatar
zjing14 committed
53
54
55
56
57
58
59
60
}

} // anonymous namespace

/// @brief Runs the grouped-GEMM test for four A/B layout combinations (C is
///        always row-major), prints a one-line verdict and returns 0 on
///        success or 1 on failure.
int main()
{
    // The && chain short-circuits: once one layout combination fails, the
    // remaining ones are not executed.
    const bool res = TestGroupedGemm<Row, Row, Row>() &&
                     TestGroupedGemm<Row, Col, Row>() &&
                     TestGroupedGemm<Col, Row, Row>() &&
                     TestGroupedGemm<Col, Col, Row>();

    std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;

    return res ? 0 : 1;
}