Unverified Commit 99caa491 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel] add bfloat16 support for gptq marlin kernel (#4788)

parent 5c342570
#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace gptq_marlin {
// Primary template for per-scalar-type traits. It declares no members of its
// own; the type aliases and conversion helpers come from the explicit
// specializations (ScalarType<half>, ScalarType<nv_bfloat16>).
template <typename scalar_t>
class ScalarType {};
// Traits for `half`: scalar/packed type aliases, tensor-core fragment
// aliases, and thin wrappers around the CUDA fp16 conversion intrinsics.
template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<half2, 4>;
  using FragB = Vec<half2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<half2, 1>;

  // Converts a single half to float via __half2float.
  static __device__ inline float num2float(const half value) {
    return __half2float(value);
  }

  // Duplicates one half into both lanes of a half2 via __half2half2.
  static __device__ inline half2 num2num2(const half value) {
    return __half2half2(value);
  }

  // Packs two halves into one half2 via __halves2half2 (argument order is
  // forwarded unchanged).
  static __device__ inline half2 nums2num2(const half first,
                                           const half second) {
    return __halves2half2(first, second);
  }

  // Converts a float to half via __float2half; callable from host and device.
  static __host__ __device__ inline half float2num(const float value) {
    return __float2half(value);
  }
};
// Traits for `nv_bfloat16`: scalar/packed type aliases, tensor-core fragment
// aliases, and thin wrappers around the CUDA bf16 conversion intrinsics.
template <>
class ScalarType<nv_bfloat16> {
 public:
  using scalar_t = nv_bfloat16;
  using scalar_t2 = nv_bfloat162;

  using FragA = Vec<nv_bfloat162, 4>;
  using FragB = Vec<nv_bfloat162, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<nv_bfloat162, 1>;

  // The helpers are compiled out only for device passes targeting an
  // architecture below SM80. They must stay visible in the host pass (where
  // __CUDA_ARCH__ is undefined): otherwise the __host__ __device__ float2num
  // would not exist on the host, and the class would have different members
  // in the host and device passes of the same translation unit.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
  // Converts a single bfloat16 to float via __bfloat162float.
  static __device__ float inline num2float(const nv_bfloat16 x) {
    return __bfloat162float(x);
  }

  // Duplicates one bfloat16 into both lanes of an nv_bfloat162.
  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
  }

  // Packs two bfloat16 values into one nv_bfloat162 (argument order is
  // forwarded unchanged to __halves2bfloat162).
  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
                                                  const nv_bfloat16 x2) {
    return __halves2bfloat162(x1, x2);
  }

  // Converts a float to bfloat16 via __float2bfloat16; callable from host
  // and device.
  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
    return __float2bfloat16(x);
  }
#endif
};
}
#endif
......@@ -14,6 +14,7 @@ import pytest
import torch
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
from .utils import check_logprobs_close
......@@ -52,7 +53,7 @@ MODELS = [
@pytest.mark.skipif(gptq_marlin_not_supported,
reason="gptq_marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
......@@ -76,11 +77,15 @@ def test_models(
gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
example_prompts[:-1], max_tokens, num_logprobs)
del gptq_marlin_model
_ROPE_DICT.clear() # clear rope cache to avoid rope dtype error
# Run gptq.
# The naive gptq kernel doesn't support bf16 yet.
# Here we always compare the fp16/bf16 gptq marlin kernel
# to the fp16 gptq kernel.
gptq_model = vllm_runner(model_name=model_name,
revision=revision,
dtype=dtype,
dtype="half",
quantization="gptq",
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1)
......
......@@ -99,7 +99,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
......@@ -186,9 +186,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
group_size = input_size
# Validate dtype
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
if params_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(f"The params dtype must be float16 "
f"or bfloat16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment