cublasAlgoMap.cc 7.7 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cublasAlgoMap.h"

lvhan028's avatar
lvhan028 committed
19
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225

cublasAlgoMap::cublasAlgoMap(const std::string filename, const std::string sp_config_filename):
    config_filename_(filename), sp_config_filename_(sp_config_filename)
{
    loadGemmConfig();
    loadSpGemmConfig();
}

cublasAlgoMap::cublasAlgoMap(const cublasAlgoMap& algo_map):
    config_filename_(algo_map.config_filename_),
    sp_config_filename_(algo_map.sp_config_filename_),
    algo_map_(algo_map.algo_map_),
    sp_algo_map_(algo_map.sp_algo_map_)
{
}

// Destructor. Nothing to release explicitly: algo_map_, sp_algo_map_ and the filename
// strings are members and are destroyed automatically. (The previous explicit
// algo_map_.clear() was redundant and inconsistently skipped sp_algo_map_.)
cublasAlgoMap::~cublasAlgoMap() {}

void cublasAlgoMap::loadGemmConfig()
{
    FILE* fd;
    fd = fopen(config_filename_.c_str(), "r");
    if (fd == NULL) {
        std::cout << "[WARNING] " << config_filename_ << " is not found; using default GEMM algo" << std::endl;
        return;
    }

    int   batchCount2, m2, n2, k2, algoId, customOption, tile, splitK_val;
    int   batch_size, seq_len, head_num, size_per_head, dataType;
    int   swizzle, reductionScheme, workspaceSize, stages;
    int   inner_shapeId, cluster_shapeId, mma_shapeId, cga_shapeId, sche_mode;
    float exec_time;
    char  tmp[1024];
    if (!fgets(tmp, 1024, fd)) {
        printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__);
        exit(-1);
    }
    while (fscanf(fd,
                  "%d %d %d %d %d ### %d %d %d %d %d %d %d %d %d %d %d %d "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
                  "%d %d "
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
                  "%d %d %d "
#endif
                  "%f\n",
                  &batch_size,
                  &seq_len,
                  &head_num,
                  &size_per_head,
                  &dataType,
                  &batchCount2,
                  &n2,
                  &m2,
                  &k2,
                  &algoId,
                  &customOption,
                  &tile,
                  &splitK_val,
                  &swizzle,
                  &reductionScheme,
                  &workspaceSize,
                  &stages,
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
                  &inner_shapeId,
                  &cluster_shapeId,
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
                  &mma_shapeId,
                  &cga_shapeId,
                  &sche_mode,
#endif
                  &exec_time)
           != EOF) {
        if (dataType != FLOAT_DATATYPE && dataType != HALF_DATATYPE && dataType != BFLOAT16_DATATYPE
            && dataType != INT8_DATATYPE && dataType != FP8_DATATYPE) {
            printf("[WARNING][readAlgoFromConfig] wrong dataType %d!\n", dataType);
            continue;
        }
        cublasAlgoConfig_t markStr{batchCount2, m2, n2, k2, static_cast<CublasDataType>(dataType)};
        // workspaceSize should be zero
        if (algo_map_.find(markStr) == algo_map_.end()) {
            algo_map_[markStr].algoId          = algoId;
            algo_map_[markStr].customOption    = customOption;
            algo_map_[markStr].tile            = tile;
            algo_map_[markStr].splitK_val      = splitK_val;
            algo_map_[markStr].swizzle         = swizzle;
            algo_map_[markStr].reductionScheme = reductionScheme;
            algo_map_[markStr].workspaceSize   = workspaceSize;
            algo_map_[markStr].stages          = stages;
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
            algo_map_[markStr].inner_shapeId   = (uint16_t)inner_shapeId;
            algo_map_[markStr].cluster_shapeId = (uint16_t)cluster_shapeId;
#elif (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH < 3)
            algo_map_[markStr].mma_shapeId = (uint16_t)mma_shapeId;
            algo_map_[markStr].cga_shapeId = (uint16_t)cga_shapeId;
            algo_map_[markStr].sche_mode   = (uint16_t)sche_mode;
#endif
            algo_map_[markStr].exec_time = exec_time;
        }
    }
    fclose(fd);
}

bool cublasAlgoMap::isExist(
    const int batch_count, const int m, const int n, const int k, const CublasDataType data_type)
{
    cublasAlgoConfig_t mark{batch_count, n, m, k, data_type};
    return algo_map_.find(mark) != algo_map_.end();
}

// Return the tuned algorithm description for this GEMM problem.
//
// The lookup key places n in the second slot and m in the third (same as isExist).
// When no entry exists, a fallback is returned:
//   - algoId is CUBLAS_GEMM_DEFAULT for FLOAT_DATATYPE, CUBLAS_GEMM_DEFAULT_TENSOR_OP otherwise;
//   - every other tuning field is -1, and exec_time is -1.0f.
cublasLtMatmulAlgo_info
cublasAlgoMap::getAlgo(const int batch_count, const int m, const int n, const int k, const CublasDataType data_type)
{
    cublasAlgoConfig_t mark{batch_count, n, m, k, data_type};
    // Single lookup instead of find() followed by operator[].
    const auto it = algo_map_.find(mark);
    if (it != algo_map_.end()) {
        return it->second;
    }

    // Value-initialize so that fields not assigned below (such as the cuBLAS-11.11
    // shape-id fields that loadGemmConfig fills in) are zero rather than indeterminate.
    cublasLtMatmulAlgo_info tmp_algo{};
    tmp_algo.algoId =
        static_cast<int>(data_type == FLOAT_DATATYPE ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    tmp_algo.customOption    = -1;
    tmp_algo.tile            = -1;
    tmp_algo.splitK_val      = -1;
    tmp_algo.swizzle         = -1;
    tmp_algo.reductionScheme = -1;
    tmp_algo.workspaceSize   = -1;
    tmp_algo.stages          = -1;
    tmp_algo.exec_time       = -1.0f;
    return tmp_algo;
}

// Parse the sparse GEMM tuning file named by sp_config_filename_ into sp_algo_map_.
//
// Does nothing when no sparse config path was configured. If the file cannot be opened,
// a warning is printed and sp_algo_map_ is left unchanged. Otherwise sp_algo_map_ is
// cleared and the first (header) line is discarded; failing to read it terminates the
// process. Every following line of the form
//   batch_size seq_len head_num size_per_head data_type ### batchCount m n k algoId exec_time
// stores algoId under the key "batchCount_m_n_k". A later line with the same key
// overwrites an earlier one. Parsing stops at the first line that does not match this
// layout, and a warning is printed in that case.
void cublasAlgoMap::loadSpGemmConfig()
{
    if (sp_config_filename_.empty()) {
        return;
    }
    FILE* fd = fopen(sp_config_filename_.c_str(), "r");
    if (fd == NULL) {
        printf("[WARNING] %s is not found; using SPGEMM algo id 0\n", sp_config_filename_.c_str());
        return;
    }
    sp_algo_map_.clear();
    int   batch_size, seq_len, head_num, size_per_head, data_type;
    int   batchCount, m, n, k, algoId;
    float exec_time;
    char  tmp[1024];
    // Skip the header line.
    if (!fgets(tmp, 1024, fd)) {
        printf("[ERROR] fgets fail at %s:%d \n", __FILE__, __LINE__);
        fclose(fd);
        exit(-1);
    }
    // A well-formed line yields exactly this many conversions. Comparing against EOF
    // alone would loop forever on a malformed line (fscanf returns 0 without consuming
    // input) and would store stale values from a partially matched line.
    const int kFieldsPerLine = 11;
    int       parsed         = 0;
    while ((parsed = fscanf(fd,
                            "%d %d %d %d %d ### %d %d %d %d %d %f\n",
                            &batch_size,
                            &seq_len,
                            &head_num,
                            &size_per_head,
                            &data_type,
                            &batchCount,
                            &m,
                            &n,
                            &k,
                            &algoId,
                            &exec_time))
           == kFieldsPerLine) {
        char mark[256];
        snprintf(mark, sizeof(mark), "%d_%d_%d_%d", batchCount, m, n, k);
        std::string markStr(mark);
        sp_algo_map_[markStr] = algoId;
    }
    // A clean end of file makes fscanf return EOF; anything else means a malformed line.
    if (parsed != EOF) {
        printf("[WARNING] %s: stopped parsing at a malformed line\n", sp_config_filename_.c_str());
    }
    fclose(fd);
}

// Return the tuned sparse-GEMM algorithm id recorded for the shape
// (batch_count, m, n, k), or 0 when sp_algo_map_ has no entry for that shape.
int cublasAlgoMap::getSpAlgo(const int batch_count, const int m, const int n, const int k)
{
    char mark[256];
    snprintf(mark, sizeof(mark), "%d_%d_%d_%d", batch_count, m, n, k);
    // Single lookup instead of find() followed by operator[].
    const auto it = sp_algo_map_.find(mark);
    if (it != sp_algo_map_.end()) {
        return it->second;
    }
    // Untuned shape (e.g. after remove-padding): fall back to algo id 0 for simplicity,
    // the same default announced by loadSpGemmConfig when its file is missing.
    return 0;
}

// Decide whether the sparse GEMM path should be used for shape (batch_count, m, n, k).
//
// Returns false when m, n or k is not a multiple of 8, since cuSPARSELt is not used for
// such shapes. Otherwise it consults sp_algo_map_:
//   - a recorded algo id of -1 means "do not use sparse";
//   - any other id means "use sparse";
//   - no recorded entry means "use sparse" (defer to the caller's sparse flag).
bool cublasAlgoMap::isUseSparse(const int batch_count, const int m, const int n, const int k)
{
    const bool dims_aligned = (m % 8 == 0) && (n % 8 == 0) && (k % 8 == 0);
    if (!dims_aligned) {
        return false;
    }
    char key[256];
    sprintf(key, "%d_%d_%d_%d", batch_count, m, n, k);
    const auto found = sp_algo_map_.find(key);
    return found == sp_algo_map_.end() || found->second != -1;
}

lvhan028's avatar
lvhan028 committed
226
}  // namespace turbomind