/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cuda_fp8_utils.h"

namespace fastertransformer {
#ifdef ENABLE_FP8

// Multiplies every element of `input` by a scale and converts it to T_OUT.
// PER_CHANNEL mode reads one scale per column (i % n) through the read-only
// cache; PER_TENSOR mode reads a single scalar scale. A grid-stride loop
// lets the fixed launch configuration cover any `size`.
template<typename T_OUT, typename T_IN, QUANTIZE_MODE quantize_mode>
__global__ void quantizeMatrix(T_OUT* output, float const* input_scale, T_IN const* input, uint32_t size, uint32_t n)
{
    for (uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += blockDim.x * gridDim.x) {
        if (quantize_mode == QUANTIZE_MODE::PER_CHANNEL) {
            output[i] = T_OUT((float)(input[i]) * __ldg(input_scale + (i % n)));
        }
        else {
            output[i] = T_OUT((float)(input[i]) * __ldg(input_scale));
        }
    }
}

template<typename T_OUT, typename T_IN, QUANTIZE_MODE quantize_mode>
void invokeQuantizeMatrix(
    T_OUT* output, float const* input_scale, T_IN const* input, uint32_t size, uint32_t n, cudaStream_t stream)
{
    dim3 grid(32);
    dim3 block(256);
    quantizeMatrix<T_OUT, T_IN, quantize_mode><<<grid, block, 0, stream>>>(output, input_scale, input, size, n);
}

#define defineinvokeQuantizeMatrix(type_out, type_in, mode)                                                            \
    template void invokeQuantizeMatrix<type_out, type_in, mode>(type_out * output,                                     \
                                                                float const* input_scale,                              \
                                                                type_in const* input,                                  \
                                                                uint32_t size,                                         \
                                                                uint32_t n,                                            \
                                                                cudaStream_t stream);

defineinvokeQuantizeMatrix(__nv_fp8_e4m3, float, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, float, QUANTIZE_MODE::PER_TENSOR);
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, half, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, half, QUANTIZE_MODE::PER_TENSOR);
defineinvokeQuantizeMatrix(half, __nv_fp8_e4m3, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(half, __nv_fp8_e4m3, QUANTIZE_MODE::PER_TENSOR);
defineinvokeQuantizeMatrix(float, __nv_fp8_e4m3, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(float, __nv_fp8_e4m3, QUANTIZE_MODE::PER_TENSOR);
#ifdef ENABLE_BF16
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, __nv_bfloat16, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(__nv_fp8_e4m3, __nv_bfloat16, QUANTIZE_MODE::PER_TENSOR);
defineinvokeQuantizeMatrix(__nv_bfloat16, __nv_fp8_e4m3, QUANTIZE_MODE::PER_CHANNEL);
defineinvokeQuantizeMatrix(__nv_bfloat16, __nv_fp8_e4m3, QUANTIZE_MODE::PER_TENSOR);
#endif
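/*
 * Usage sketch (illustrative only; the device buffers below are
 * hypothetical, not declared in this file): quantizing an [m, n] float
 * matrix to FP8 with a single per-tensor scale.
 *
 *   float*         d_input;  // device, m * n elements
 *   __nv_fp8_e4m3* d_output; // device, m * n elements
 *   float*         d_scale;  // device, 1 element, e.g. FP8_E4M3_MAX / amax
 *   invokeQuantizeMatrix<__nv_fp8_e4m3, float, QUANTIZE_MODE::PER_TENSOR>(
 *       d_output, d_scale, d_input, m * n, n, stream);
 *
 * With QUANTIZE_MODE::PER_CHANNEL, d_scale must instead hold n scales, one
 * per column, and element i is multiplied by d_scale[i % n].
 */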
// Fake quantization: casts src through the intermediate type T_FAKE (e.g.
// FP8) and back, so dst keeps its original storage type but carries the
// rounding error of T_FAKE.
template<typename T_OUT, typename T_IN, typename T_FAKE>
__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int size)
{
    for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) {
        T_FAKE tmp = (T_FAKE)((float)src[tid]);
        dst[tid]   = (T_OUT)((float)tmp);
    }
}

template<typename T_OUT, typename T_IN, typename T_FAKE>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int size, cudaStream_t stream)
{
    fakeQuantize<T_OUT, T_IN, T_FAKE><<<256, 256, 0, stream>>>(dst, src, size);
}

template void
invokeFakeQuantize<float, float, __nv_fp8_e4m3>(float* dst, const float* src, const int size, cudaStream_t stream);
template void
invokeFakeQuantize<half, half, __nv_fp8_e4m3>(half* dst, const half* src, const int size, cudaStream_t stream);
template void invokeFakeQuantize<__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3>(__nv_bfloat16* dst,
                                                                              const __nv_bfloat16* src,
                                                                              const int size,
                                                                              cudaStream_t stream);

// Computes one dequantization scale per column of a [k, n] weight matrix:
// max(amax / FP8_E4M3_MAX, 1/32), where amax is the column's absolute
// maximum. The 1/32 floor keeps the reciprocal quantization multiplier
// bounded for all-zero or near-zero columns. One thread handles one column.
template<typename T_W>
__global__ void computeFP8QuantizeScale(float* quant_ptr, const T_W* weights, const int k, const int n)
{
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= n) {
        return;  // guard the tail block when n is not a multiple of blockDim.x
    }
    float amax = 0.f;
    for (int i = 0; i < k; i++) {
        amax = fmaxf(amax, fabsf((float)weights[i * n + col]));
    }
    quant_ptr[col] = fmaxf(amax / FP8_E4M3_MAX, 1.0f / 32.f);
}

template<typename T_W>
void invokeComputeFP8QuantizeScale(float* quant_ptr, const T_W* weights, const int k, const int n, cudaStream_t stream)
{
    dim3 block(256);
    dim3 grid((n + block.x - 1) / block.x);
    computeFP8QuantizeScale<T_W><<<grid, block, 0, stream>>>(quant_ptr, weights, k, n);
}

#ifdef ENABLE_BF16
template void invokeComputeFP8QuantizeScale(
    float* quant_ptr, const __nv_bfloat16* weights, const int k, const int n, cudaStream_t stream);
#endif
template void
invokeComputeFP8QuantizeScale(float* quant_ptr, const float* weights, const int k, const int n, cudaStream_t stream);

#endif  // ENABLE_FP8
}  // namespace fastertransformer
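/*
 * End-to-end sketch (illustrative only; all buffer names below are
 * hypothetical): per-channel FP8 quantization of a [k, n] weight matrix.
 *
 *   float*         d_weights; // device, k * n elements
 *   float*         d_scales;  // device, n elements
 *   __nv_fp8_e4m3* d_fp8;     // device, k * n elements
 *
 *   // 1. Per-column dequantization scales: max(amax / FP8_E4M3_MAX, 1/32).
 *   invokeComputeFP8QuantizeScale(d_scales, d_weights, k, n, stream);
 *
 *   // 2. quantizeMatrix multiplies by the scale it is given, so the
 *   //    quantization step needs the reciprocal of each dequantization
 *   //    scale; an elementwise-reciprocal pass producing the hypothetical
 *   //    d_scales_recip buffer (not shown) is assumed in between.
 *   invokeQuantizeMatrix<__nv_fp8_e4m3, float, QUANTIZE_MODE::PER_CHANNEL>(
 *       d_fp8, d_scales_recip, d_weights, k * n, n, stream);
 */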