nvfp4_quant_kernels.cu 6.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*
 * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <cuda_runtime_api.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp8.h>
26
#include "dispatch_utils.h"
27
28

#include "cuda_utils.h"
29
#include "launch_bounds_utils.h"
30
#include "nvfp4_utils.cuh"
31

32
namespace vllm {
33

34
35
36
37
38
39
40
// Returns the smallest multiple of `y` that is >= `x` (for non-negative `x`
// and positive `y`), computed as a ceiling division followed by a multiply.
// Instantiation with a non-integral type is rejected at compile time.
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
  static_assert(std::is_integral_v<Int>,
                "round_up argument must be integral type");
  // Ceiling division: how many `y`-sized chunks are needed to cover `x`.
  auto const num_chunks = (x + y - 1) / y;
  return num_chunks * y;
}

41
42
// Use UE4M3 by default.
//
// Converts a [numRows x numCols] matrix of `Type` values into packed FP4 and
// writes the matching scale factors.
//
//   numRows, numCols - logical extent of `in`; each row is contiguous.
//   in       - input values, read CVT_FP4_ELTS_PER_THREAD at a time as one
//              PackedVec<Type>.
//   SFScale  - optional pointer to the global scale; 1.0f is used when null.
//   out      - packed FP4 output, one uint32_t per PackedVec read from `in`.
//   SFout    - scale-factor output; the slot for each vector is resolved by
//              cvt_quant_to_fp4_get_sf_out_offset.
//   UE8M0_SF - forwarded to cvt_warp_fp16_to_fp4 to pick the scale-factor
//              encoding.
//
// Blocks stride over rows (blockIdx.x / gridDim.x) and threads stride over the
// packed vectors of a row (threadIdx.x / blockDim.x). The kernel is compiled
// for at most 512 threads per block.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                    float const* SFScale, uint32_t* out, uint32_t* SFout) {
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // Padding pass over the rows between numRows and the next multiple of 128.
  // NOTE(review): `col` starts at sf_n_unpadded (a count of scale factors)
  // while the bound sf_n_int is that count ceil-divided by 4, so
  // `col < sf_n_int` is never true and nothing is zeroed here. Left unchanged
  // because the intended SF layout is not visible in this file - confirm
  // against the scale-factor offset helper before changing.
  int sf_m = round_up<int>(numRows, 128);
  int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
  int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
  for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
    // Each thread writes 4 uint32_t elements.
    for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
         col += blockDim.x * 4) {
      SFout[row * sf_n_int + col] = 0x00;
    }
  }

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];

  // Number of PackedVec entries in one input row (loop invariant).
  int const vecsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    // Widen before multiplying: rowIdx * vecsPerRow can exceed INT32_MAX for
    // large inputs, and a 32-bit product would wrap before being stored in
    // the 64-bit offset.
    int64_t const rowOffset = static_cast<int64_t>(rowIdx) * vecsPerRow;
    for (int colIdx = threadIdx.x; colIdx < vecsPerRow; colIdx += blockDim.x) {
      int64_t inOffset = rowOffset + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements are packed into one uint32_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                             CVT_FP4_NUM_THREADS_PER_SF>(
              rowIdx, colIdx, numCols, SFout);

      out_pos =
          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
    }
  }
}

// Host-side launcher for cvt_fp16_to_fp4.
//
//   m, n     - rows and columns of `input`.
//   input    - device pointer to the values to quantize.
//   SFScale  - device pointer to the global scale (forwarded to the kernel).
//   output   - device pointer to the packed FP4 output; reinterpreted as
//              uint32_t* for the kernel.
//   SFOuput  - device pointer to the scale-factor output; reinterpreted as
//              uint32_t* for the kernel.
//   useUE8M0 - selects the UE8M0_SF=true kernel instantiation.
//   multiProcessorCount - SM count used to cap the grid size.
//   stream   - CUDA stream the kernel is enqueued on.
//
// The launch is asynchronous; no synchronization is performed here.
template <typename T>
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
                           int64_t* output, int32_t* SFOuput, bool useUE8M0,
                           int multiProcessorCount, cudaStream_t stream) {
  // An empty input would yield a zero-sized grid (m == 0) or block
  // (n < ELTS_PER_THREAD), which is an invalid launch configuration. There is
  // nothing to convert in that case, so skip the launch entirely.
  if (m <= 0 || n < ELTS_PER_THREAD) {
    return;
  }

  // Grid, Block size.
  // Each thread converts 8 values.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
  // Get number of blocks per SM
  int const numBlocksPerSM =
      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
  // Cap the grid at what the device can keep resident, but never launch fewer
  // than one block.
  dim3 grid(
      std::max(1, std::min(int(m), multiProcessorCount * numBlocksPerSM)));

  // The kernel addresses both outputs as 32-bit words.
  auto* const out_words = reinterpret_cast<uint32_t*>(output);
  auto* const sf_words = reinterpret_cast<uint32_t*>(SFOuput);

  // Launch the cvt kernel.
  if (useUE8M0) {
    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, out_words, sf_words);
  } else {
    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
        m, n, input, SFScale, out_words, sf_words);
  }
}

// Explicit instantiations of invokeFP4Quantization for half and
// __nv_bfloat16 inputs, so the template definition is emitted in this
// translation unit for both element types.
template void invokeFP4Quantization<half>(
    int m, int n, half const* input, float const* SFScale, int64_t* output,
    int32_t* SFOuput, bool useUE8M0, int multiProcessorCount,
    cudaStream_t stream);

template void invokeFP4Quantization<__nv_bfloat16>(
    int m, int n, __nv_bfloat16 const* input, float const* SFScale,
    int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount,
    cudaStream_t stream);

127
128
}  // namespace vllm

129
// Torch entry point: quantizes `input` to FP4 and writes the packed values to
// `output` and the scale factors to `output_sf`.
//
//   output    - destination for the packed FP4 data (accessed via data_ptr()).
//   input     - tensor whose first two sizes are read as m and n; must be
//               Half or BFloat16, and n must be a multiple of 16.
//   output_sf - destination for the scale factors (accessed via data_ptr()).
//   input_sf  - global scale; its data pointer is read as float const*.
//
// Raises (via TORCH_CHECK) when n is not a multiple of 16 or the input dtype
// is unsupported. The kernel is enqueued on the current CUDA stream of the
// input's device.
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
                             torch::Tensor const& input,
                             torch::Tensor const& output_sf,
                             torch::Tensor const& input_sf) {
  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported input data type for quantize_to_fp4.");

  // Make the input's device current before any device-relative query, so the
  // SM count below describes the GPU the kernel will actually run on.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

  // NOTE(review): device id -1 presumably means "current device" - verify in
  // cuda_utils.h.
  int multiProcessorCount =
      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);

  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

  // We don't support e8m0 scales at this moment.
  bool useUE8M0 = false;

  VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
    using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
    auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
    vllm::invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr,
                                sf_out, useUE8M0, multiProcessorCount, stream);
  });
}