// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#ifndef ops_H
#define ops_H

#include <assert.h>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <stdio.h>
#include <string>
#include <vector>

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <cusparse.h>

// Evaluate a CUDA runtime call exactly once and terminate the process if it
// did not return cudaSuccess. Before exit(1), the error description from
// cudaGetErrorString, the line and the file are written to stderr.
//
// The expansion is wrapped in do { ... } while (0) so that
// `CUDA_CHECK_RETURN(x);` is a single statement. It therefore composes with
// an unbraced `if (...) CUDA_CHECK_RETURN(x); else ...`. The argument is
// parenthesized so expressions with low-precedence operators expand correctly.
#define CUDA_CHECK_RETURN(value)                                                                                       \
    do {                                                                                                               \
        cudaError_t _m_cudaStat = (value);                                                                             \
        if (_m_cudaStat != cudaSuccess) {                                                                              \
            fprintf(stderr, "Error %s at line %d in file %s\n", cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__);  \
            exit(1);                                                                                                   \
        }                                                                                                              \
    } while (0)

// Evaluate a cuSPARSE call exactly once and terminate the process if it did
// not return CUSPARSE_STATUS_SUCCESS. Before exit(1), the description from
// cusparseGetErrorString, the line and the file are written to stderr.
//
// do { ... } while (0) makes the expansion a single statement, so it is safe
// inside an unbraced if/else. (value) is parenthesized to avoid operator-
// precedence surprises when an expression is passed.
#define CHECK_CUSPARSE(value)                                                                                          \
    do {                                                                                                               \
        cusparseStatus_t _m_cudaStat = (value);                                                                        \
        if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) {                                                                  \
            fprintf(                                                                                                   \
                stderr, "Error %s at line %d in file %s\n", cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__    \
            );                                                                                                         \
            exit(1);                                                                                                   \
        }                                                                                                              \
    } while (0)

// Throw std::logic_error when a CUDA runtime call did not return cudaSuccess;
// do nothing otherwise.
//
// The numeric status and its description are printed to stderr and are also
// embedded in the exception message. This keeps the cause available to
// whoever catches the exception, not only in the console output.
inline void checkCudaStatus(cudaError_t status) {
    if (status == cudaSuccess) {
        return;
    }
    const char* description = cudaGetErrorString(status);
    fprintf(stderr, "cuda API failed with status %d: %s\n", static_cast<int>(status), description);
    throw std::logic_error(std::string("cuda API failed: ") + description);
}

// Report a failed cuBLAS / cuBLASLt call without throwing or aborting.
//
// Returns 0 when `status` is CUBLAS_STATUS_SUCCESS. Otherwise it prints the
// numeric status to stderr and returns 1, leaving it to the caller to react.
// This helper intentionally does not throw; an earlier revision had the
// throw commented out.
inline int checkCublasStatus(cublasStatus_t status) {
    if (status == CUBLAS_STATUS_SUCCESS) {
        return 0;
    }
    fprintf(stderr, "cuBLAS API failed with status %d\n", static_cast<int>(status));
    return 1;
}

// Operation selector; currently only a single value is defined.
// NOTE(review): no function declared in this header takes an Operations_t;
// confirm where it is consumed before extending it.
typedef enum Operations_t {
    ksmul = 0,
} Operations_t;

// Optimizer identifiers.
// Presumably these are the values used for the `int OPTIMIZER` template
// argument of optimizer32bit / optimizerStatic8bit /
// optimizerStatic8bitBlockwise. TODO confirm against the implementation.
// The numeric values are spelled out explicitly. NOTE(review): code outside
// this file may depend on them, so do not renumber without checking.
typedef enum Optimizer_t {
    ADAM = 0,
    MOMENTUM = 1,
    RMSPROP = 2,
    LARS = 3,
    ADAGRAD = 4,
    LION = 5,
    ADEMAMIX = 6
} Optimizer_t;

// Matrix layout tags. The names suggest row-major, column-major and tiled
// column formats (COL32, COL_TURING, COL_AMPERE); this is an inference from
// the names only.
// NOTE(review): no function declared in this header takes a Transform_t;
// confirm where it is consumed.
typedef enum Transform_t {
    ROW = 0,
    COL = 1,
    COL32 = 2,
    COL_TURING = 3,
    COL_AMPERE = 4,
} Transform_t;

// Quantization data-type identifiers.
// Presumably these are the values used for the `int DATA_TYPE` template
// argument of quantizeBlockwise / dequantizeBlockwise. TODO confirm.
typedef enum DataType_t {
    General8bit = 0,
    FP4 = 1,
    NF4 = 2,
} DataType_t;

// Element-wise function selectors.
// Presumably these are the values used for the `int FUNC` template argument
// of func(). TODO confirm.
// NOTE: `_MUL` starts with an underscore followed by an uppercase letter.
// Such identifiers are reserved for the implementation in C++. The name is
// kept unchanged for compatibility with existing callers.
typedef enum Funcs_t {
    FILL = 0,
    ARANGE = 1,
    _MUL = 2,
} Funcs_t;

// Wrapper around a cuBLAS handle; gemmex / strided_gemmex take a Context*.
class Context {
  public:
    // Valid handle after successful construction, nullptr if creation failed.
    cublasHandle_t m_handle = nullptr;

    // Creates the cuBLAS handle. A failing status is reported via the
    // non-throwing checkCublasStatus. On failure m_handle stays nullptr
    // instead of holding an indeterminate value.
    // NOTE(review): the class has no destructor, so the handle is never
    // released here. Presumably a Context lives for the whole process; confirm.
    Context() {
        cublasHandle_t handle = nullptr;
        if (checkCublasStatus(cublasCreate_v2(&handle)) == 0) {
            m_handle = handle;
        }
    }
};

// Wrapper around a cuBLASLt handle.
class ContextLt {
  public:
    // Valid handle after successful construction, nullptr if creation failed.
    cublasLtHandle_t m_handle = nullptr;

    // Creates the cuBLASLt handle. cublasLtCreate returns a cublasStatus_t,
    // which is reported via the non-throwing checkCublasStatus. On failure
    // m_handle stays nullptr instead of holding an indeterminate value.
    // NOTE(review): no destructor, so the handle is never released here.
    // Presumably the object lives for the whole process; confirm.
    ContextLt() {
        cublasLtHandle_t handle = nullptr;
        if (checkCublasStatus(cublasLtCreate(&handle)) == 0) {
            m_handle = handle;
        }
    }
};

// Wrapper around a cuSPARSE handle.
class ContextCusparse {
  public:
    // Valid handle after successful construction, nullptr if creation failed.
    cusparseHandle_t m_handle = nullptr;

    // Creates the cuSPARSE handle. On failure the cuSPARSE error description
    // is printed to stderr, and m_handle stays nullptr instead of holding an
    // indeterminate value.
    // NOTE(review): no destructor, so the handle is never released here.
    // Presumably the object lives for the whole process; confirm.
    ContextCusparse() {
        cusparseHandle_t handle = nullptr;
        cusparseStatus_t status = cusparseCreate(&handle);
        if (status == CUSPARSE_STATUS_SUCCESS) {
            m_handle = handle;
        } else {
            fprintf(stderr, "cusparseCreate failed: %s\n", cusparseGetErrorString(status));
        }
    }
};

// ---------------------------------------------------------------------------
// Host-side entry points (declarations only).
// The definitions and template instantiations are not in this header;
// presumably they live in the accompanying .cu file. The grouping below is
// by function name only.
// ---------------------------------------------------------------------------

// --- quantize / dequantize -------------------------------------------------
void quantize(float* code, float* A, unsigned char* out, int n);
void dequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream);

template <typename T, int STOCHASTIC, int DATA_TYPE>
void quantizeBlockwise(float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset,
                       int blocksize, const int n);

template <typename T, int DATA_TYPE>
void dequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n,
                         cudaStream_t stream);

// --- optimizers --------------------------------------------------------------
template <typename T, int OPTIMIZER>
void optimizer32bit(T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm,
                    float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step,
                    float lr, const float gnorm_scale, bool skip_zeros, int n);

template <typename T, int OPTIMIZER>
void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm,
                         float param_norm, float beta1, float beta2, float eps, int step, float lr,
                         float* quantiles1, float* quantiles2, float* max1, float* max2, float* new_max1,
                         float* new_max2, float weight_decay, const float gnorm_scale, int n);

template <typename T, int OPTIMIZER>
void optimizerStatic8bitBlockwise(T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1,
                                  float beta2, float beta3, float alpha, float eps, int step, float lr,
                                  float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
                                  float weight_decay, const float gnorm_scale, bool skip_zeros, int n);

template <typename T>
void percentileClipping(T* g, float* gnorm_vec, int step, const int n);

// --- GEMM --------------------------------------------------------------------
void gemmex(Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C,
            int lda, int ldb, int ldc);

void strided_gemmex(Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B,
                    void* C, int lda, int ldb, int ldc, long long int strideA, long long int strideB,
                    long long int strideC, int batchCount);

template <int DTYPE_OUT, int SCALE_ROWS>
int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C,
            float* row_scale, int lda, int ldb, int ldc, cudaStream_t stream);

void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
                   int ldb, int ldc);

// --- int8 row/column statistics and dequantization ---------------------------
void dequant_mm_int32_fp16(int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows,
                           int numCols, cudaStream_t stream);

void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, cudaStream_t stream);

void int8VectorQuant(half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols,
                     cudaStream_t stream);

// --- sparse (COO) matmul -----------------------------------------------------
void spmm_coo(cusparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows,
              int A_cols, int B_cols, int ldb, half* B, int ldc, half* C, bool transposed_B);

template <typename T, int BITS>
void spmm_coo_very_sparse_naive(int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx,
                                half* values, T* B, half* out, float* dequant_stats, int nnz_rows, int nnz,
                                int rowsA, int rowsB, int colsB);

// --- 4-bit matmul / inference ------------------------------------------------
void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB);

template <typename T>
void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits);

template <typename T>
void gemm_4bit_inference(int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb,
                         int ldc, int blocksize);

template <typename T, int BITS>
void gemm_4bit_inference_naive(int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out,
                               int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);

// --- element-wise helper (FUNC selects the operation) -----------------------
template <typename T, int FUNC>
void func(T* A, T* B, T value, long n);

#endif