Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99caa491
Unverified
Commit
99caa491
authored
May 16, 2024
by
Jinzhen Lin
Committed by
GitHub
May 16, 2024
Browse files
[Kernel] add bfloat16 support for gptq marlin kernel (#4788)
parent
5c342570
Changes
4
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
246 additions
and
73 deletions
+246
-73
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+173
-67
csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh
csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh
+62
-0
tests/models/test_gptq_marlin.py
tests/models/test_gptq_marlin.py
+7
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+4
-4
No files found.
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
99caa491
This diff is collapsed.
Click to expand it.
csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh
0 → 100644
View file @
99caa491
#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace
gptq_marlin
{
// Primary template: intentionally empty. The type aliases and conversion
// helpers are supplied by the explicit specializations of ScalarType.
template <typename scalar_t>
class ScalarType {};
// ScalarType specialization for CUDA `half`: exposes the scalar / packed-pair
// types, the fragment vector types, and thin wrappers around the conversion
// intrinsics from <cuda_fp16.h>.
template <>
class ScalarType<half> {
 public:
  using scalar_t = half;
  using scalar_t2 = half2;

  // Matrix fragments for tensor core instructions; their precise layout is
  // documented here:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
  using FragA = Vec<scalar_t2, 4>;
  using FragB = Vec<scalar_t2, 2>;
  using FragC = Vec<float, 4>;
  using FragS = Vec<scalar_t2, 1>;

  // Converts a single half to float via __half2float.
  static __device__ inline float num2float(const scalar_t value) {
    return __half2float(value);
  }

  // Builds a half2 from one half via __half2half2.
  static __device__ inline scalar_t2 num2num2(const scalar_t value) {
    return __half2half2(value);
  }

  // Packs two halves into one half2 via __halves2half2 (argument order is
  // forwarded unchanged).
  static __device__ inline scalar_t2 nums2num2(const scalar_t first,
                                               const scalar_t second) {
    return __halves2half2(first, second);
  }

  // Converts a float to half via __float2half; callable from host and device.
  static __host__ __device__ inline scalar_t float2num(const float value) {
    return __float2half(value);
  }
};
template
<
>
class
ScalarType
<
nv_bfloat16
>
{
public:
using
scalar_t
=
nv_bfloat16
;
using
scalar_t2
=
nv_bfloat162
;
using
FragA
=
Vec
<
nv_bfloat162
,
4
>
;
using
FragB
=
Vec
<
nv_bfloat162
,
2
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragS
=
Vec
<
nv_bfloat162
,
1
>
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static
__device__
float
inline
num2float
(
const
nv_bfloat16
x
)
{
return
__bfloat162float
(
x
);
}
static
__device__
nv_bfloat162
inline
num2num2
(
const
nv_bfloat16
x
)
{
return
__bfloat162bfloat162
(
x
);
}
static
__device__
nv_bfloat162
inline
nums2num2
(
const
nv_bfloat16
x1
,
const
nv_bfloat16
x2
)
{
return
__halves2bfloat162
(
x1
,
x2
);
}
static
__host__
__device__
nv_bfloat16
inline
float2num
(
const
float
x
)
{
return
__float2bfloat16
(
x
);
}
#endif
};
}
#endif
tests/models/test_gptq_marlin.py
View file @
99caa491
...
@@ -14,6 +14,7 @@ import pytest
...
@@ -14,6 +14,7 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.rotary_embedding
import
_ROPE_DICT
from
.utils
import
check_logprobs_close
from
.utils
import
check_logprobs_close
...
@@ -52,7 +53,7 @@ MODELS = [
...
@@ -52,7 +53,7 @@ MODELS = [
@
pytest
.
mark
.
skipif
(
gptq_marlin_not_supported
,
@
pytest
.
mark
.
skipif
(
gptq_marlin_not_supported
,
reason
=
"gptq_marlin is not supported on this GPU type."
)
reason
=
"gptq_marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
def
test_models
(
...
@@ -76,11 +77,15 @@ def test_models(
...
@@ -76,11 +77,15 @@ def test_models(
gptq_marlin_outputs
=
gptq_marlin_model
.
generate_greedy_logprobs
(
gptq_marlin_outputs
=
gptq_marlin_model
.
generate_greedy_logprobs
(
example_prompts
[:
-
1
],
max_tokens
,
num_logprobs
)
example_prompts
[:
-
1
],
max_tokens
,
num_logprobs
)
del
gptq_marlin_model
del
gptq_marlin_model
_ROPE_DICT
.
clear
()
# clear rope cache to avoid rope dtype error
# Run gptq.
# Run gptq.
# The naive gptq kernel doesn't support bf16 yet.
# Here we always compare fp16/bf16 gptq marlin kernel
# to fp16 gptq kernel.
gptq_model
=
vllm_runner
(
model_name
=
model_name
,
gptq_model
=
vllm_runner
(
model_name
=
model_name
,
revision
=
revision
,
revision
=
revision
,
dtype
=
dtype
,
dtype
=
"half"
,
quantization
=
"gptq"
,
quantization
=
"gptq"
,
max_model_len
=
MAX_MODEL_LEN
,
max_model_len
=
MAX_MODEL_LEN
,
tensor_parallel_size
=
1
)
tensor_parallel_size
=
1
)
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
99caa491
...
@@ -99,7 +99,7 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -99,7 +99,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
...
@@ -186,9 +186,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -186,9 +186,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
group_size
=
input_size
group_size
=
input_size
# Validate dtype
# Validate dtype
if
params_dtype
!=
torch
.
float16
:
if
params_dtype
not
in
[
torch
.
float16
,
torch
.
bfloat16
]
:
raise
ValueError
(
raise
ValueError
(
f
"The params dtype must be float16 "
f
"The params dtype must be
float16, but got
{
params_dtype
}
"
)
f
"or b
float16, but got
{
params_dtype
}
"
)
# Validate output_size_per_partition
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment