cuda_fp8_utils.cu 5.88 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cuda_fp8_utils.h"

lvhan028's avatar
lvhan028 committed
19
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#ifdef ENABLE_FP8

/**
 * Element-wise scaled conversion kernel:
 *   output[i] = T_OUT(float(input[i]) * scale)
 *
 * When quantize_mode is QUANTIZE_MODE::PER_CHANNEL the scale is
 * input_scale[i % n]; for every other mode the single value input_scale[0]
 * is used for all elements.
 *
 * Each thread starts at its global thread index and advances by the total
 * number of launched threads until `size` is reached, so any 1-D launch
 * shape covers all `size` elements.
 *
 * @param output       destination buffer, at least `size` elements
 * @param input_scale  scale factor(s), read through __ldg
 * @param input        source buffer, at least `size` elements
 * @param size         number of elements to convert
 * @param n            modulus applied to the element index in PER_CHANNEL mode
 */
template<typename T_OUT, typename T_IN, QUANTIZE_MODE quantize_mode>
__global__ void quantizeMatrix(T_OUT* output, float const* input_scale, T_IN const* input, uint32_t size, uint32_t n)
{
    const uint32_t stride = blockDim.x * gridDim.x;
    uint32_t       idx    = threadIdx.x + blockIdx.x * blockDim.x;

    while (idx < size) {
        // Pick the address of the scale first; the load itself stays inside the loop.
        float const* scale_ptr =
            (quantize_mode == QUANTIZE_MODE::PER_CHANNEL) ? input_scale + (idx % n) : input_scale;
        output[idx] = T_OUT((float)(input[idx]) * __ldg(scale_ptr));
        idx += stride;
    }
}

/**
 * Host-side launcher for quantizeMatrix.
 *
 * Enqueues the kernel on `stream` with a fixed shape of 32 blocks x 256
 * threads; the kernel's own loop walks all `size` elements regardless of the
 * launch shape. The call is asynchronous: no synchronization or error query
 * is performed here.
 *
 * @param output       device destination buffer (`size` elements)
 * @param input_scale  device pointer to the scale factor(s)
 * @param input        device source buffer (`size` elements)
 * @param size         number of elements to convert
 * @param n            modulus used to select the scale in PER_CHANNEL mode
 * @param stream       CUDA stream to launch on
 */
template<typename T_OUT, typename T_IN, QUANTIZE_MODE quantize_mode>
void invokeQuantizeMatrix(
    T_OUT* output, float const* input_scale, T_IN const* input, uint32_t size, uint32_t n, cudaStream_t stream)
{
    constexpr unsigned int kNumBlocks       = 32;
    constexpr unsigned int kThreadsPerBlock = 256;

    quantizeMatrix<T_OUT, T_IN, quantize_mode>
        <<<dim3(kNumBlocks), dim3(kThreadsPerBlock), 0, stream>>>(output, input_scale, input, size, n);
}

// Expands to one explicit instantiation of invokeQuantizeMatrix<type_out, type_in, mode>.
// The terminating semicolon is supplied at the point of use.
#define defineinvokeQuantizeMatrix(type_out, type_in, mode)                                                            \
    template void invokeQuantizeMatrix<type_out, type_in, mode>(                                                       \
        type_out*, float const*, type_in const*, uint32_t, uint32_t, cudaStream_t)

// float / half -> __nv_fp8_e4m3
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, float, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, float, QUANTIZE_MODE::PER_TENSOR);
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, half, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, half, QUANTIZE_MODE::PER_TENSOR);

// __nv_fp8_e4m3 -> half / float
defineinvokeQuantizeMatrix(half, __nv_fp8_e4m3, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(half, __nv_fp8_e4m3, QUANTIZE_MODE::PER_TENSOR);
defineinvokeQuantizeMatrix(float, __nv_fp8_e4m3, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(float, __nv_fp8_e4m3, QUANTIZE_MODE::PER_TENSOR);

// __nv_bfloat16 <-> __nv_fp8_e4m3, only when bf16 support is compiled in
#ifdef ENABLE_BF16
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, __nv_bfloat16, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, __nv_bfloat16, QUANTIZE_MODE::PER_TENSOR);
defineinvokeQuantizeMatrix(__nv_bfloat16, __nv_fp8_e4m3, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(__nv_bfloat16, __nv_fp8_e4m3, QUANTIZE_MODE::PER_TENSOR);
#endif

/**
 * Round-trips every element of `src` through T_FAKE and stores it in `dst`:
 *   dst[i] = T_OUT(float(T_FAKE(float(src[i]))))
 *
 * Each thread starts at its global thread index and advances by the total
 * number of launched threads until `size` is reached.
 *
 * @param dst   destination buffer, at least `size` elements
 * @param src   source buffer, at least `size` elements
 * @param size  number of elements to process
 */
template<typename T_OUT, typename T_IN, typename T_FAKE>
__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int size)
{
    const unsigned int stride = blockDim.x * gridDim.x;
    int                idx    = threadIdx.x + blockIdx.x * blockDim.x;

    while (idx < size) {
        const float  widened  = (float)src[idx];
        const T_FAKE narrowed = (T_FAKE)(widened);  // precision is dropped here
        dst[idx]              = (T_OUT)((float)narrowed);
        idx += stride;
    }
}

/**
 * Host-side launcher for fakeQuantize.
 *
 * Enqueues the kernel on `stream` with 256 blocks of 256 threads. The call is
 * asynchronous: no synchronization or error query is performed here.
 *
 * @param dst     device destination buffer (`size` elements)
 * @param src     device source buffer (`size` elements)
 * @param size    number of elements to process
 * @param stream  CUDA stream to launch on
 */
template<typename T_OUT, typename T_IN, typename T_FAKE>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int size, cudaStream_t stream)
{
    constexpr int kNumBlocks       = 256;
    constexpr int kThreadsPerBlock = 256;

    fakeQuantize<T_OUT, T_IN, T_FAKE><<<kNumBlocks, kThreadsPerBlock, 0, stream>>>(dst, src, size);
}

// Explicit instantiations of invokeFakeQuantize. In every case T_OUT == T_IN
// and the intermediate (T_FAKE) type is __nv_fp8_e4m3.
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(float*, const float*, const int, cudaStream_t);
template void invokeFakeQuantize<half, half, __nv_fp8_e4m3>(half*, const half*, const int, cudaStream_t);
// NOTE(review): unlike the other __nv_bfloat16 instantiations in this file,
// this one is not wrapped in #ifdef ENABLE_BF16 -- confirm that is intentional.
template void invokeFakeQuantize<__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3>(__nv_bfloat16*,
                                                                              const __nv_bfloat16*,
                                                                              const int,
                                                                              cudaStream_t);

/**
 * Computes one FP8 (E4M3) scale per column of a weight matrix.
 *
 * Element (i, col) is read at weights[i * n + col]. Each thread owns one
 * column, scans its k entries for the largest absolute value, and writes
 *   quant_ptr[col] = max(max_abs / FP8_E4M3_MAX, 1/32)
 *
 * Launch layout: 1-D grid and 1-D block with at least n threads in total.
 * Threads whose column index is >= n exit without touching memory, so the
 * grid may be rounded up to a whole number of blocks.
 *
 * NaN weights are ignored by fmaxf and therefore do not affect the scale.
 *
 * @param quant_ptr  output, n floats (one scale per column)
 * @param weights    input matrix, k * n elements
 * @param k          number of rows scanned per column
 * @param n          number of columns / number of output scales
 */
template<typename T_W>
__global__ void computeFP8QuantizeScale(float* quant_ptr, const T_W* weights, const int k, const int n)
{
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    // The grid is rounded up to a multiple of the block size; skip the padding threads.
    if (col >= n) {
        return;
    }

    float max_abs = 0.f;
    for (int i = 0; i < k; i++) {
        // 64-bit offset: i * n can exceed INT_MAX for large matrices.
        const size_t offset = static_cast<size_t>(i) * static_cast<size_t>(n) + static_cast<size_t>(col);
        max_abs             = fmaxf(max_abs, fabsf((float)weights[offset]));
    }

    // Lower bound of 1/32 keeps the scale away from zero for all-zero / tiny columns.
    quant_ptr[col] = fmaxf(max_abs / FP8_E4M3_MAX, 1.0f / 32.f);
}

/**
 * Host-side launcher for computeFP8QuantizeScale.
 *
 * Launches one thread per column using blocks of 256 threads and
 * ceil(n / 256) blocks on `stream`. The call is asynchronous: no
 * synchronization or error query is performed here.
 *
 * When n <= 0 there is nothing to compute and no kernel is launched (a grid
 * dimension of zero is not a valid launch configuration).
 *
 * @param quant_ptr  device output buffer, n floats
 * @param weights    device input matrix, k * n elements
 * @param k          number of rows
 * @param n          number of columns
 * @param stream     CUDA stream to launch on
 */
template<typename T_W>
void invokeComputeFP8QuantizeScale(float* quant_ptr, const T_W* weights, const int k, const int n, cudaStream_t stream)
{
    if (n <= 0) {
        return;
    }

    constexpr unsigned int kThreadsPerBlock = 256;

    const dim3 block(kThreadsPerBlock);
    // Ceil-div in unsigned arithmetic so that n + 255 cannot overflow int.
    const dim3 grid((static_cast<unsigned int>(n) + kThreadsPerBlock - 1) / kThreadsPerBlock);

    computeFP8QuantizeScale<T_W><<<grid, block, 0, stream>>>(quant_ptr, weights, k, n);
}

// Explicit instantiations of invokeComputeFP8QuantizeScale: float weights
// always, __nv_bfloat16 weights only when ENABLE_BF16 is defined.
template void
invokeComputeFP8QuantizeScale<float>(float*, const float*, const int, const int, cudaStream_t);
#ifdef ENABLE_BF16
template void
invokeComputeFP8QuantizeScale<__nv_bfloat16>(float*, const __nv_bfloat16*, const int, const int, cudaStream_t);
#endif

#endif  // ENABLE_FP8
lvhan028's avatar
lvhan028 committed
124
}  // namespace turbomind