Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47f0954a
Unverified
Commit
47f0954a
authored
Jul 03, 2024
by
Michael Goin
Committed by
GitHub
Jul 03, 2024
Browse files
[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)
parent
7cd2ebb0
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1587 additions
and
44 deletions
+1587
-44
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/ops.h
csrc/ops.h
+5
-0
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+1308
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-0
docs/source/quantization/fp8.rst
docs/source/quantization/fp8.rst
+2
-1
docs/source/quantization/supported_hardware.rst
docs/source/quantization/supported_hardware.rst
+1
-1
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+84
-4
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+14
-5
vllm/_custom_ops.py
vllm/_custom_ops.py
+9
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+134
-30
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+25
-3
No files found.
CMakeLists.txt
View file @
47f0954a
...
...
@@ -171,6 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
...
...
csrc/ops.h
View file @
47f0954a
...
...
@@ -93,6 +93,11 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
// FP8 weight-only quantized GEMM using the Marlin kernel.
// Registered with the torch op library under the name "fp8_marlin_gemm".
// NOTE(review): the definition is presumably in
// csrc/quantization/fp8/fp8_marlin.cu -- confirm.
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/fp8/fp8_marlin.cu
0 → 100644
View file @
47f0954a
This diff is collapsed.
Click to expand it.
csrc/torch_bindings.cpp
View file @
47f0954a
...
...
@@ -137,6 +137,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kCUDA
,
&
gptq_marlin_repack
);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops
.
def
(
"fp8_marlin_gemm"
,
&
fp8_marlin_gemm
);
ops
.
impl
(
"fp8_marlin_gemm"
,
torch
::
kCUDA
,
&
fp8_marlin_gemm
);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
ops
.
def
(
...
...
docs/source/quantization/fp8.rst
View file @
47f0954a
...
...
@@ -4,7 +4,8 @@ FP8
==================
vLLM
supports
FP8
(
8
-
bit
floating
point
)
weight
and
activation
quantization
using
hardware
acceleration
on
GPUs
such
as
Nvidia
H100
and
AMD
MI300x
.
Currently
,
only
Hopper
and
Ada
Lovelace
GPUs
are
supported
.
Currently
,
only
Hopper
and
Ada
Lovelace
GPUs
are
officially
supported
for
W8A8
.
Ampere
GPUs
are
supported
for
W8A16
(
weight
-
only
FP8
)
utilizing
Marlin
kernels
.
Quantization
of
models
with
FP8
allows
for
a
2
x
reduction
in
model
memory
requirements
and
up
to
a
1.6
x
improvement
in
throughput
with
minimal
impact
on
accuracy
.
Please
visit
the
HF
collection
of
`
quantized
FP8
checkpoints
of
popular
LLMs
ready
to
use
with
vLLM
<
https
://
huggingface
.
co
/
collections
/
neuralmagic
/
fp8
-
llms
-
for
-
vllm
-
666742
ed2b78b7ac8df13127
>`
_
.
...
...
docs/source/quantization/supported_hardware.rst
View file @
47f0954a
...
...
@@ -11,7 +11,7 @@ Implementation Volta Turing Ampere Ada Hopper AMD GPU Intel GPU x86
AQLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
AWQ ❌ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
DeepSpeedFP ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
FP8 ❌ ❌
❌
✅ ✅ ❌ ❌ ❌ ❌ ❌
FP8 ❌ ❌
✅
✅ ✅ ❌ ❌ ❌ ❌ ❌
Marlin ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
GPTQ ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
SqueezeLLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
...
...
tests/kernels/test_marlin_gemm.py
View file @
47f0954a
...
...
@@ -8,7 +8,8 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
)
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_SUPPORTED_NUM_BITS
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
...
...
@@ -16,7 +17,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
marlin_perm
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
MarlinWorkspace
,
compute_max_diff
,
is_marlin_supported
,
marlin_24_quantize
,
marlin_quantize
,
marlin_weights
)
marlin_quantize
,
marlin_weights
,
pack_fp8_to_int32
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
quantize_weights
,
sort_weights
)
...
...
@@ -38,9 +39,11 @@ MNK_FACTORS = [
(
67
,
13
,
11
),
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
def
rand_data
(
shape
):
return
torch
.
randn
(
shape
,
dtype
=
torch
.
half
,
device
=
"cuda"
)
def
rand_data
(
shape
,
dtype
=
torch
.
float16
):
return
torch
.
randn
(
shape
,
dtype
=
dtype
,
device
=
"cuda"
)
@
pytest
.
mark
.
skipif
(
not
is_marlin_supported
(),
...
...
@@ -217,3 +220,80 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
@pytest.mark.skipif(not is_marlin_supported(),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
    dtype,
):
    """Check the FP8 Marlin GEMM against an unquantized ``torch.matmul``.

    Random activations and weights are generated in ``dtype``. The weights
    are quantized to FP8 with a dynamically computed per-tensor scale, packed
    to int32, repacked to the Marlin layout and multiplied through
    ``ops.fp8_marlin_gemm``. The result must stay within a max-diff of 0.04
    of the reference product.
    """
    m_factor, n_factor, k_factor = mnk_factors
    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    # Activations first, then weights: keeps the RNG draw order stable.
    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    # --- weights: quantize -> pack to int32 (gptq format) -> marlin layout
    fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
    gptq_packed = pack_fp8_to_int32(fp8_weight)
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=gptq_packed,
        perm=torch.empty(0, dtype=torch.int, device="cuda"),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    # --- scales: Marlin doesn't support per-tensor scales, so the single
    # scale is expanded to channelwise before being permuted.
    channel_scales = weight_scale.repeat(1, size_n)
    channel_scales = channel_scales.to(a_input.dtype).to("cuda")
    marlin_scales = marlin_permute_scales(
        s=channel_scales,
        size_k=size_k,
        size_n=size_n,
        group_size=-1,
        num_bits=8,
    )

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    output = ops.fp8_marlin_gemm(
        a=a_input,
        b_q_weight=marlin_qweight,
        b_scales=marlin_scales,
        workspace=workspace.scratch,
        num_bits=num_bits,
        size_m=a_input.shape[0],
        size_n=b_weight.shape[1],
        size_k=a_input.shape[1],
    )
    output_ref = torch.matmul(a_input, b_weight)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04
tests/quantization/test_fp8.py
View file @
47f0954a
...
...
@@ -6,7 +6,7 @@ import pytest
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
.
_custom_ops
import
scaled_fp8_quant
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.fp8
import
Fp8LinearMethod
MODELS
=
[
...
...
@@ -35,7 +35,16 @@ def test_load_fp16_model(vllm_runner) -> None:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
quant_method
,
Fp8LinearMethod
)
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
>=
89
:
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert
fc1
.
weight
.
dtype
==
torch
.
int32
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
...
...
@@ -63,7 +72,7 @@ def test_scaled_fp8_quant(dtype) -> None:
x
=
(
torch
.
randn
(
size
=
(
11
,
11
),
device
=
"cuda"
)
*
13
).
to
(
dtype
)
# Dynamic quantization
ref_y
,
inv_scale
=
scaled_fp8_quant
(
x
,
None
)
ref_y
,
inv_scale
=
ops
.
scaled_fp8_quant
(
x
,
None
)
ref_y
=
per_tensor_dequantize
(
ref_y
,
inv_scale
,
dtype
)
# Reference dynamic quantization
...
...
@@ -71,11 +80,11 @@ def test_scaled_fp8_quant(dtype) -> None:
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
# Static quantization
y
,
_
=
scaled_fp8_quant
(
x
,
inv_scale
)
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
)
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
# Padding
y
,
_
=
scaled_fp8_quant
(
x
,
inv_scale
,
batch_dim_padding
=
17
)
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
,
batch_dim_padding
=
17
)
assert
y
.
shape
[
0
]
==
17
assert
torch
.
allclose
(
ref_y
,
...
...
vllm/_custom_ops.py
View file @
47f0954a
...
...
@@ -271,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_k
,
is_k_full
)
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    """Run the compiled ``fp8_marlin_gemm`` custom op and return its output.

    All arguments are forwarded unchanged, positionally and in declaration
    order, to ``torch.ops._C.fp8_marlin_gemm``.

    NOTE(review): operand layouts are not visible here; presumably ``a`` is
    ``(size_m, size_k)`` and ``b_q_weight`` / ``b_scales`` are already in
    Marlin format -- confirm against the kernel.
    """
    kernel_args = (a, b_q_weight, b_scales, workspace, num_bits, size_m,
                   size_n, size_k)
    return torch.ops._C.fp8_marlin_gemm(*kernel_args)
# fp8
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
47f0954a
...
...
@@ -11,6 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQMarlinState
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
pack_fp8_to_int32
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
...
...
@@ -54,7 +59,7 @@ class Fp8Config(QuantizationConfig):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
8
9
return
8
0
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
...
...
@@ -106,6 +111,12 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
def
_create_scale_param
(
self
,
scale_name
:
str
,
...
...
@@ -139,6 +150,10 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
process_after_load
=
True
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
# WEIGHT
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
...
...
@@ -172,6 +187,65 @@ class Fp8LinearMethod(LinearMethodBase):
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
# For GPUs without FP8 hardware support, we use Marlin for fast
# fused dequantization
if
self
.
use_marlin
:
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
def prepare_layer_for_marlin(self, layer: Module) -> None:
    """Rewrite ``layer`` in place so it can run through the FP8 Marlin GEMM.

    Emits a one-time warning, flips ``layer.marlin_state`` from ``REPACK`` to
    ``READY``, then replaces three attributes on ``layer``:

    * ``weight``       -- packed to int32 and repacked into Marlin layout;
    * ``weight_scale`` -- expanded to one scale per output channel, cast to
      ``layer.orig_dtype`` and permuted for Marlin;
    * ``workspace``    -- a zero-filled int scratch tensor on the weight's
      device.
    """
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition

    # A layer may only be converted once.
    assert layer.marlin_state == GPTQMarlinState.REPACK
    layer.marlin_state = GPTQMarlinState.READY

    target_device = layer.weight.device

    # WEIGHTS: fp8 -> packed int32 (gptq format) -> marlin format.
    # An empty ``perm`` tensor is passed to the repack op.
    gptq_packed = pack_fp8_to_int32(layer.weight)
    repacked_weight = ops.gptq_marlin_repack(
        b_q_weight=gptq_packed,
        perm=torch.empty(0, dtype=torch.int, device=target_device),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )
    layer.weight = Parameter(repacked_weight, requires_grad=False)

    # WEIGHT SCALES: Marlin doesn't support per-tensor scales, so the scale
    # is expanded to channelwise before the Marlin permutation is applied.
    channel_scales = layer.weight_scale.repeat(1, size_n)
    channel_scales = channel_scales.to(layer.orig_dtype).to(target_device)
    permuted_scales = marlin_permute_scales(
        s=channel_scales,
        size_k=size_k,
        size_n=size_n,
        group_size=-1,
        num_bits=8,
    )
    layer.weight_scale = Parameter(permuted_scales, requires_grad=False)

    # WORKSPACE: scratch buffer sized from the output width.
    workspace_len = (size_n //
                     GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
    layer.workspace = torch.zeros(workspace_len,
                                  dtype=torch.int,
                                  device=target_device,
                                  requires_grad=False)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
(
not
hasattr
(
layer
,
"process_after_load"
)
or
not
layer
.
process_after_load
):
...
...
@@ -185,6 +259,8 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
logical_widths
=
None
layer
.
input_scale
=
None
if
self
.
use_marlin
:
self
.
prepare_layer_for_marlin
(
layer
)
return
# If checkpoint is fp8, requantize the separately quantized logical
...
...
@@ -233,44 +309,72 @@ class Fp8LinearMethod(LinearMethodBase):
raise
ValueError
(
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
if
self
.
use_marlin
:
self
.
prepare_layer_for_marlin
(
layer
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
if
self
.
use_marlin
:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
out_shape
=
x
.
shape
[:
-
1
]
+
(
layer
.
output_size_per_partition
,
)
output
=
ops
.
fp8_marlin_gemm
(
a
=
reshaped_x
,
b_q_weight
=
layer
.
weight
,
b_scales
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
num_bits
=
8
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
)
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
)
return
output
.
reshape
(
out_shape
)
else
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
,
batch_dim_padding
=
17
)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
bias
=
bias
,
)
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x
# If static, layer.input_scale is scalar and x_scale is input_scale
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
)
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
)
else
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
,
batch_dim_padding
=
17
)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
bias
=
bias
,
)
return
torch
.
narrow
(
output
,
0
,
0
,
x
.
shape
[
0
])
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
47f0954a
...
...
@@ -14,13 +14,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_pack_factor
,
quantize_weights
,
sort_weights
)
from
vllm.platforms
import
current_platform
__cuda_arch
=
current_platform
.
get_device_capability
()
MARLIN_TILE
=
16
def
is_marlin_supported
():
return
__cuda_arch
[
0
]
>=
8
capability
=
current_platform
.
get_device_capability
()
return
capability
[
0
]
>=
8
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
MARLIN_TILE
):
...
...
@@ -223,3 +222,26 @@ class MarlinWorkspace:
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
def
pack_fp8_to_int32
(
fp8_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert
fp8_tensor
.
dtype
==
torch
.
float8_e4m3fn
assert
fp8_tensor
.
shape
[
0
]
%
4
==
0
# Reshape to prepare for packing
reshaped
=
fp8_tensor
.
reshape
(
-
1
,
4
,
*
fp8_tensor
.
shape
[
1
:])
# Convert fp8 to uint8 (byte) representation
byte_tensor
=
reshaped
.
view
(
torch
.
uint8
)
# Pack 4 uint8 values into one int32
packed
=
(
byte_tensor
[:,
0
].
to
(
torch
.
int32
)
|
(
byte_tensor
[:,
1
].
to
(
torch
.
int32
)
<<
8
)
|
(
byte_tensor
[:,
2
].
to
(
torch
.
int32
)
<<
16
)
|
(
byte_tensor
[:,
3
].
to
(
torch
.
int32
)
<<
24
))
return
packed
.
view
(
fp8_tensor
.
shape
[
0
]
//
4
,
*
fp8_tensor
.
shape
[
1
:]).
contiguous
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment