// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"

// Data-type aliases for this example. ADataType, BDataType, CDataType and
// AccDataType are the type arguments handed to ReferenceGemmInstance further
// down in this file; A, B and C are all ck::f8_t while accumulation is float.
using ADataType        = ck::f8_t;
using BDataType        = ck::f8_t;
using CDataType        = ck::f8_t;
using AccDataType      = float;
// NOTE(review): no use of CShuffleDataType is visible in this file, and the
// DeviceGemmFactory instances list F8 (not float) directly after F32 in their
// type arguments -- confirm which type is intended for the C-shuffle stage.
using CShuffleDataType = float;

14
15
using F8      = ck::f8_t;
using F32     = float;
16
// Layout aliases for the A, B and C operands; all three are Row.
// NOTE(review): Row is not declared in this file -- presumably it comes from
// common.hpp; confirm. These aliases are not referenced elsewhere in this
// file, so they are presumably consumed by run_gemm_example.inc.
using ALayout = Row;
using BLayout = Row;
using CLayout = Row;

// Element-op aliases for the A, B and C operands (all PassThrough). They are
// the last three type arguments of ReferenceGemmInstance below.
// NOTE(review): PassThrough is not declared in this file -- presumably it
// comes from common.hpp; confirm.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// Shorthand for GemmSpecialization::Default; passed as a template argument to
// every DeviceGemm_Xdl_CShuffle instance in DeviceGemmFactory.
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceGemmFactory = std::tuple<
#if 1
    ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<Row,
                                                          Row,
                                                          Row,
                                                          F8,
                                                          F8,
                                                          F8,
                                                          F32,
                                                          F8,
                                                          PassThrough,
                                                          PassThrough,
                                                          PassThrough,
                                                          GemmDefault,
                                                          1,
                                                          256,
                                                          256,
                                                          128,
                                                          64,
                                                          16,
                                                          4,
                                                          32,
                                                          32,
                                                          4,
                                                          2,
                                                          S<4, 64, 1>,
                                                          S<1, 0, 2>,
                                                          S<1, 0, 2>,
                                                          2,
                                                          16,
                                                          16,
                                                          1,
                                                          S<8, 32, 1>,
                                                          S<0, 2, 1>,
                                                          S<0, 2, 1>,
                                                          1,
                                                          4,
                                                          4,
                                                          0,
                                                          1,
                                                          1,
                                                          S<1, 64, 1, 4>,
                                                          16,
                                                          ck::LoopScheduler::Interwave,
                                                          ck::PipelineVersion::v1>,
    ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<Row,
                                                          Row,
                                                          Row,
                                                          F8,
                                                          F8,
                                                          F8,
                                                          F32,
                                                          F8,
                                                          PassThrough,
                                                          PassThrough,
                                                          PassThrough,
                                                          GemmDefault,
                                                          1,
                                                          256,
                                                          256,
                                                          128,
                                                          64,
                                                          16,
                                                          16,
                                                          32,
                                                          32,
                                                          4,
                                                          2,
                                                          S<4, 64, 1>,
                                                          S<1, 0, 2>,
                                                          S<1, 0, 2>,
                                                          2,
                                                          16,
                                                          16,
                                                          1,
                                                          S<4, 64, 1>,
                                                          S<0, 2, 1>,
                                                          S<0, 2, 1>,
                                                          1,
                                                          2,
                                                          16,
                                                          1,
                                                          1,
                                                          1,
                                                          S<1, 64, 1, 4>,
                                                          16,
                                                          ck::LoopScheduler::Interwave,
                                                          ck::PipelineVersion::v1>,
#endif
    ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<Row,
                                                          Row,
                                                          Row,
                                                          F8,
                                                          F8,
                                                          F8,
                                                          F32,
                                                          F8,
                                                          PassThrough,
                                                          PassThrough,
                                                          PassThrough,
                                                          GemmDefault,
                                                          1,
                                                          256,
                                                          256,
                                                          128,
                                                          64,
                                                          16,
                                                          8,
                                                          32,
                                                          32,
                                                          4,
                                                          2,
                                                          S<4, 64, 1>,
                                                          S<1, 0, 2>,
                                                          S<1, 0, 2>,
                                                          2,
                                                          16,
                                                          16,
                                                          1,
                                                          S<8, 32, 1>,
                                                          S<0, 2, 1>,
                                                          S<0, 2, 1>,
                                                          1,
                                                          4,
                                                          8,
                                                          1,
                                                          1,
                                                          1,
                                                          S<1, 64, 1, 4>,
                                                          16,
                                                          ck::LoopScheduler::Interwave,
                                                          ck::PipelineVersion::v1>>;
// Host-side ReferenceGemm instantiated with this file's data-type and
// element-op aliases.
// NOTE(review): presumably used by run_gemm_example.inc to check the device
// result -- confirm against that file.
using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

#include "run_gemm_example.inc"

// Program entry point: forwards the command line to run_gemm_example and maps
// its result to the process exit status (0 when the result is truthy, 1
// otherwise).
int main(int argc, char* argv[])
{
    const auto example_result = run_gemm_example(argc, argv);
    return !example_result ? 1 : 0;
}