Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a0aa077
Unverified
Commit
0a0aa077
authored
Jan 09, 2026
by
Lucas Wilkinson
Committed by
GitHub
Jan 09, 2026
Browse files
[Quant] Make static quant support all group shapes (#30833)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
f9e2a75a
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
339 additions
and
47 deletions
+339
-47
csrc/ops.h
csrc/ops.h
+4
-2
csrc/quantization/w8a8/fp8/common.cu
csrc/quantization/w8a8/fp8/common.cu
+178
-22
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+5
-2
tests/kernels/quantization/test_fp8_quant.py
tests/kernels/quantization/test_fp8_quant.py
+103
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+13
-4
vllm/model_executor/layers/quantization/input_quant_fp8.py
vllm/model_executor/layers/quantization/input_quant_fp8.py
+17
-11
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+19
-4
No files found.
csrc/ops.h
View file @
0a0aa077
...
...
@@ -2,6 +2,7 @@
#include <optional>
#include <torch/library.h>
#include <tuple>
#include "core/scalar_type.hpp"
...
...
@@ -346,8 +347,9 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
,
std
::
optional
<
std
::
tuple
<
int64_t
,
int64_t
>>
group_shape
=
std
::
nullopt
);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
);
...
...
csrc/quantization/w8a8/fp8/common.cu
View file @
0a0aa077
...
...
@@ -4,28 +4,77 @@
#include "quantization/vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
#include <tuple>
namespace
vllm
{
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
scaled_fp8_quant_kernel_strided
(
// STRIDE_I_ZERO: true if scale_stride_i == 0 (per-tensor or per-channel)
// STRIDE_J_ZERO: true if scale_stride_j == 0 (per-tensor or per-token)
template
<
typename
scalar_t
,
typename
fp8_type
,
bool
STRIDE_I_ZERO
,
bool
STRIDE_J_ZERO
>
__global__
void
scaled_fp8_quant_kernel_strided_group_shape
(
fp8_type
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
float
*
__restrict__
scale
,
int
hidden_size
,
int64_t
in_row_stride
,
int64_t
out_row_stride
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
// one token per block
int64_t
out_row_stride
,
int
group_m
,
int
group_n
,
int64_t
scale_stride_i
,
int64_t
scale_stride_j
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int
tid
=
threadIdx
.
x
;
const
scalar_t
*
token_in
=
input
+
token_idx
*
in_row_stride
;
fp8_type
*
token_out
=
out
+
token_idx
*
out_row_stride
;
const
float
inv_scale
=
1.0
f
/
(
*
scale
);
vectorize_with_alignment
<
16
>
(
token_in
,
token_out
,
hidden_size
,
tid
,
blockDim
.
x
,
// Precompute row-level base offset for scale access (compile-time eliminated
// when STRIDE_I_ZERO)
const
int64_t
scale_row_base
=
STRIDE_I_ZERO
?
0
:
static_cast
<
int
>
(
token_idx
)
/
group_m
*
scale_stride_i
;
auto
get_inv_scale
=
[
&
](
int
gj
)
{
return
1.0
f
/
scale
[
scale_row_base
+
gj
*
scale_stride_j
];
};
int
cached_gj
=
-
1
;
float
cached_inv_scale
=
0.0
f
;
auto
get_inv_scale_cached
=
[
&
](
int
gj
)
{
if
(
gj
!=
cached_gj
)
{
cached_inv_scale
=
1.0
f
/
scale
[
scale_row_base
+
gj
*
scale_stride_j
];
cached_gj
=
gj
;
}
return
cached_inv_scale
;
};
constexpr
int
VEC_SIZE
=
16
;
// FP8 so vectorize to 128 bits
auto
scaled_fp8_conversion_vectorized
=
[
&
](
const
scalar_t
*
in
,
fp8_type
*
out
,
int
size
,
float
inv_scale
)
{
vectorize_with_alignment
<
VEC_SIZE
>
(
in
,
out
,
size
,
tid
,
blockDim
.
x
,
[
=
]
__device__
(
fp8_type
&
dst
,
const
scalar_t
&
src
)
{
dst
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
static_cast
<
float
>
(
src
),
inv_scale
);
});
};
if
(
STRIDE_J_ZERO
&&
hidden_size
%
VEC_SIZE
==
0
)
{
// Per-tensor or per-token: single scale per row, vectorize full row
scaled_fp8_conversion_vectorized
(
token_in
,
token_out
,
hidden_size
,
get_inv_scale
(
0
));
}
else
if
(
group_n
%
VEC_SIZE
==
0
)
{
// Multiple column groups with vectorization
const
int
num_groups_n
=
hidden_size
/
group_n
;
for
(
int
gj
=
0
;
gj
<
num_groups_n
;
gj
++
)
{
scaled_fp8_conversion_vectorized
(
token_in
+
gj
*
group_n
,
token_out
+
gj
*
group_n
,
group_n
,
get_inv_scale
(
gj
));
}
}
else
{
// Scalar path for small column groups (group_n < VEC_SIZE)
for
(
int
n
=
tid
;
n
<
hidden_size
;
n
+=
blockDim
.
x
)
{
const
int
gj
=
n
/
group_n
;
token_out
[
n
]
=
scaled_fp8_conversion
<
true
,
fp8_type
>
(
static_cast
<
float
>
(
token_in
[
n
]),
get_inv_scale_cached
(
gj
));
}
}
}
template
<
typename
scalar_t
,
typename
fp8_type
>
...
...
@@ -133,17 +182,116 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
}
// namespace vllm
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
const
&
scale
)
// [1]
torch
::
Tensor
const
&
scale
,
// various shapes
std
::
optional
<
std
::
tuple
<
int64_t
,
int64_t
>>
opt_group_shape
)
// optional explicit (group_m, group_n)
{
TORCH_CHECK
(
input
.
stride
(
-
1
)
==
1
,
"last dimension of input must be contiguous"
);
TORCH_CHECK
(
out
.
stride
(
-
1
)
==
1
,
"last dimension of output must be contiguous"
);
const
int
hidden_size
=
input
.
size
(
-
1
);
const
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
const
int
hidden_size
=
input
.
size
(
-
1
);
// N (columns)
const
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
// M (rows)
// Determine group_m, group_n, and scale strides from scale shape
// Scale indexing: scale[gi * scale_stride_j + gj * scale_stride_i]
// where gi = m / group_m, gj = n / group_n
int
group_m
,
group_n
;
int64_t
scale_stride_i
,
scale_stride_j
;
if
(
scale
.
dim
()
==
0
||
scale
.
numel
()
==
1
)
{
// Per-tensor: one scale for the entire tensor
group_m
=
num_tokens
;
group_n
=
hidden_size
;
scale_stride_i
=
0
;
scale_stride_j
=
0
;
}
else
if
(
scale
.
dim
()
==
1
)
{
// 1D scale: require explicit group_shape to disambiguate per-channel vs
// per-token (avoids edge case where num_tokens == hidden_size)
TORCH_CHECK
(
opt_group_shape
.
has_value
(),
"1D scale requires explicit group_shape to disambiguate "
"per-channel vs per-token quantization. "
"Use group_shape=(-1, 1) for per-channel or group_shape=(1, "
"-1) for per-token."
);
const
auto
&
[
opt_group_m
,
opt_group_n
]
=
opt_group_shape
.
value
();
group_m
=
opt_group_m
==
-
1
?
num_tokens
:
static_cast
<
int
>
(
opt_group_m
);
group_n
=
opt_group_n
==
-
1
?
hidden_size
:
static_cast
<
int
>
(
opt_group_n
);
// Validate the explicit group shape matches the 1D scale
const
int64_t
scale_len
=
scale
.
numel
();
const
int64_t
expected_scale_m
=
num_tokens
/
group_m
;
const
int64_t
expected_scale_n
=
hidden_size
/
group_n
;
const
int64_t
expected_scale_numel
=
expected_scale_m
*
expected_scale_n
;
TORCH_CHECK
(
scale_len
==
expected_scale_numel
,
"1D scale length ("
,
scale_len
,
") does not match expected size ("
,
expected_scale_numel
,
") for group_shape ("
,
opt_group_m
,
", "
,
opt_group_n
,
") with input shape ("
,
num_tokens
,
", "
,
hidden_size
,
")"
);
// For 1D scale, determine strides based on which dim is trivial
// Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j]
// where gi = m / group_m (row group), gj = n / group_n (col group)
if
(
expected_scale_m
==
1
)
{
// Per-channel style: one scale in M dim, scale varies along N
// gi = 0 always, gj varies, so stride_1 traverses the scale
scale_stride_i
=
0
;
scale_stride_j
=
scale
.
stride
(
0
);
}
else
if
(
expected_scale_n
==
1
)
{
// Per-token style: one scale in N dim, scale varies along M
// gj = 0 always, gi varies, so stride_0 traverses the scale
scale_stride_i
=
scale
.
stride
(
0
);
scale_stride_j
=
0
;
}
else
{
TORCH_CHECK
(
false
,
"1D scale can only be used when one of the scale dimensions is 1. "
"For 2D group scaling, use a 2D scale tensor."
);
}
}
else
if
(
scale
.
dim
()
==
2
)
{
// 2D scale: infer group sizes from scale dimensions (or use explicit if
// provided)
const
int64_t
scale_size_0
=
scale
.
size
(
0
);
const
int64_t
scale_size_1
=
scale
.
size
(
1
);
TORCH_CHECK
(
num_tokens
%
scale_size_0
==
0
,
"num_tokens ("
,
num_tokens
,
") must be divisible by scale.size(0) ("
,
scale_size_0
,
")"
);
TORCH_CHECK
(
hidden_size
%
scale_size_1
==
0
,
"hidden_size ("
,
hidden_size
,
") must be divisible by scale.size(1) ("
,
scale_size_1
,
")"
);
// Infer from 2D scale shape
int
inferred_group_m
=
num_tokens
/
scale_size_0
;
int
inferred_group_n
=
hidden_size
/
scale_size_1
;
// Use explicit if provided, otherwise use inferred
if
(
opt_group_shape
.
has_value
())
{
const
auto
&
[
opt_group_m
,
opt_group_n
]
=
opt_group_shape
.
value
();
group_m
=
opt_group_m
==
-
1
?
num_tokens
:
static_cast
<
int
>
(
opt_group_m
);
group_n
=
opt_group_n
==
-
1
?
hidden_size
:
static_cast
<
int
>
(
opt_group_n
);
// Validate explicit matches inferred
TORCH_CHECK
(
group_m
==
inferred_group_m
&&
group_n
==
inferred_group_n
,
"Explicit group_shape ("
,
opt_group_m
,
", "
,
opt_group_n
,
") does not match inferred group shape ("
,
inferred_group_m
,
", "
,
inferred_group_n
,
") from 2D scale tensor shape ("
,
scale_size_0
,
", "
,
scale_size_1
,
")"
);
}
else
{
group_m
=
inferred_group_m
;
group_n
=
inferred_group_n
;
}
scale_stride_i
=
scale
.
stride
(
0
);
scale_stride_j
=
scale
.
stride
(
1
);
}
else
{
TORCH_CHECK
(
false
,
"scale must be 0D, 1D, or 2D tensor, but got "
,
scale
.
dim
(),
"D"
);
}
const
int
block_size
=
256
;
dim3
grid
(
num_tokens
);
dim3
block
(
block_size
);
...
...
@@ -153,15 +301,23 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
// Dispatch to template-specialized kernel based on stride pattern
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"scaled_fp8_quant_kernel_scalar_type"
,
[
&
]
{
VLLM_DISPATCH_FP8_TYPES
(
out
.
scalar_type
(),
"scaled_fp8_quant_kernel_fp8_type"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel_strided
<
scalar_t
,
fp8_t
>
VLLM_DISPATCH_BOOL
(
scale_stride_i
==
0
,
S0_ZERO
,
[
&
]
{
VLLM_DISPATCH_BOOL
(
scale_stride_j
==
0
,
S1_ZERO
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel_strided_group_shape
<
scalar_t
,
fp8_t
,
S0_ZERO
,
S1_ZERO
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
hidden_size
,
in_row_stride
,
out_row_stride
);
out_row_stride
,
group_m
,
group_n
,
scale_stride_i
,
scale_stride_j
);
});
});
});
});
}
...
...
csrc/torch_bindings.cpp
View file @
0a0aa077
...
...
@@ -599,9 +599,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// Compute FP8 quantized tensor for given scaling factor.
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group
// scaling. Optional group_m/group_n specify the group shape explicitly;
// required for 1D scales to disambiguate per-channel vs per-token.
ops
.
def
(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale
) ->
"
"()"
);
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale
,
"
"
(int, int)? group_shape=None) ->
()"
);
ops
.
impl
(
"static_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
static_scaled_fp8_quant
);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
...
...
tests/kernels/quantization/test_fp8_quant.py
View file @
0a0aa077
...
...
@@ -11,6 +11,10 @@ from tests.kernels.quant_utils import (
ref_dynamic_per_token_quant
,
)
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_quantize
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -21,10 +25,18 @@ SEEDS = [0]
def
opcheck_fp8_quant
(
output
,
input
,
scale
=
None
,
scale_ub
=
None
,
use_per_token_if_dynamic
=
False
output
,
input
,
scale
=
None
,
scale_ub
=
None
,
use_per_token_if_dynamic
=
False
,
group_shape
=
None
,
):
if
scale
is
not
None
:
opcheck
(
torch
.
ops
.
_C
.
static_scaled_fp8_quant
,
(
output
,
input
,
scale
))
opcheck
(
torch
.
ops
.
_C
.
static_scaled_fp8_quant
,
(
output
,
input
,
scale
,
group_shape
),
)
elif
use_per_token_if_dynamic
:
scale
=
torch
.
empty
(
(
input
.
shape
[
0
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
...
...
@@ -118,3 +130,92 @@ def test_fp8_quant_large(seed: int) -> None:
ops_out
=
ops_out
.
to
(
dtype
=
dtype
)
torch
.
testing
.
assert_close
(
ref_out
,
ops_out
)
# Test static FP8 quantization with 2D group scales
GROUP_SHAPES_2D
=
[
(
-
1
,
-
1
),
# Per-tensor
(
-
1
,
1
),
# Per-channel
(
1
,
-
1
),
# Per-token
(
-
1
,
128
),
# Per-head quantization
(
1
,
128
),
# DeepSeek-style per-token-per-group (group_m=1, group_n=128)
(
128
,
128
),
# DeepSeek-style block quantization
(
1
,
64
),
# Smaller group size
(
1
,
16
),
# Small group (scalar path in kernel)
(
4
,
256
),
# Non-trivial both dimensions
]
# Use sizes divisible by all group shapes
NUM_TOKENS_GROUP
=
[
128
,
512
]
HIDDEN_SIZES_GROUP
=
[
256
,
1024
,
2048
]
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS_GROUP
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES_GROUP
)
@
pytest
.
mark
.
parametrize
(
"group_shape"
,
GROUP_SHAPES_2D
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
torch
.
inference_mode
()
def
test_static_fp8_quant_group_2d
(
num_tokens
:
int
,
hidden_size
:
int
,
group_shape
:
tuple
[
int
,
int
],
dtype
:
torch
.
dtype
,
seed
:
int
,
)
->
None
:
"""Test static FP8 quantization with 2D group scales using scaled_quantize."""
# Normalize group_shape (-1 means full extent)
norm_group_m
=
num_tokens
if
group_shape
[
0
]
==
-
1
else
group_shape
[
0
]
norm_group_n
=
hidden_size
if
group_shape
[
1
]
==
-
1
else
group_shape
[
1
]
# Skip if sizes are not divisible by group shape
if
num_tokens
%
norm_group_m
!=
0
or
hidden_size
%
norm_group_n
!=
0
:
pytest
.
skip
(
f
"Skipping: (
{
num_tokens
}
,
{
hidden_size
}
) not divisible by "
f
"group_shape (
{
group_shape
[
0
]
}
,
{
group_shape
[
1
]
}
)"
)
current_platform
.
seed_everything
(
seed
)
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
ref_out
,
scale
=
scaled_quantize
(
x
,
group_shape
,
FP8_DTYPE
,
compute_dtype
=
torch
.
float32
)
ops_out
,
ops_scale
=
ops
.
scaled_fp8_quant
(
x
,
scale
=
scale
,
group_shape
=
group_shape
)
torch
.
testing
.
assert_close
(
scale
,
ops_scale
)
torch
.
testing
.
assert_close
(
ref_out
.
float
(),
ops_out
.
float
(),
rtol
=
0.12
,
atol
=
0.0
)
opcheck_fp8_quant
(
ops_out
,
x
,
scale
=
scale
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS_GROUP
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES_GROUP
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"group_shape"
,
[(
1
,
-
1
),
(
-
1
,
1
)])
# per-token, per-channel
@
torch
.
inference_mode
()
def
test_static_fp8_quant_1d_scale
(
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
group_shape
:
tuple
[
int
,
int
],
)
->
None
:
"""Test static FP8 quantization with 1D scale (per-token or per-channel)."""
current_platform
.
seed_everything
(
seed
)
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
ref_out
,
scale_2d
=
scaled_quantize
(
x
,
group_shape
,
FP8_DTYPE
,
compute_dtype
=
torch
.
float32
)
# Flatten scale to 1D for testing 1D scale path
scale_1d
=
scale_2d
.
flatten
()
ops_out
,
ops_scale
=
ops
.
scaled_fp8_quant
(
x
,
scale
=
scale_1d
,
group_shape
=
group_shape
)
torch
.
testing
.
assert_close
(
scale_1d
,
ops_scale
)
torch
.
testing
.
assert_close
(
ref_out
.
float
(),
ops_out
.
float
(),
rtol
=
0.12
,
atol
=
0.0
)
opcheck_fp8_quant
(
ops_out
,
x
,
scale
=
scale_1d
,
group_shape
=
group_shape
)
vllm/_custom_ops.py
View file @
0a0aa077
...
...
@@ -1752,6 +1752,7 @@ def scaled_fp8_quant(
scale_ub
:
torch
.
Tensor
|
None
=
None
,
use_per_token_if_dynamic
:
bool
=
False
,
output
:
torch
.
Tensor
|
None
=
None
,
group_shape
:
tuple
[
int
,
int
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
...
...
@@ -1763,14 +1764,23 @@ def scaled_fp8_quant(
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
input: The input tensor to be quantized to FP8 (must be 2D: [M, N])
scale: Optional scaling factor for the FP8 quantization. Supports:
- 0D or [1]: per-tensor scaling
- 1D: requires explicit group_shape to disambiguate per-channel
vs per-token (use (-1, 1) for per-channel, (1, -1) for per-token)
- 2D [M/group_m, N/group_n]: group scaling (e.g. [M, N/128] for
DeepSeek-style (1,128) groups, or [M/128, N/128] for (128,128))
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
group_shape: Optional tuple (group_m, group_n) specifying the group
shape for static quantization. Use -1 for "full extent" (e.g.,
(-1, -1) for per-tensor, (-1, 1) for per-channel, etc.)
Required for 1D scales; optional for 2D scales.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
...
...
@@ -1799,8 +1809,7 @@ def scaled_fp8_quant(
scale
=
torch
.
empty
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
,
group_shape
)
return
output
,
scale
...
...
vllm/model_executor/layers/quantization/input_quant_fp8.py
View file @
0a0aa077
...
...
@@ -10,6 +10,7 @@ from vllm.model_executor.custom_op import CustomOp
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
get_fp8_min_max
,
group_broadcast
,
)
from
vllm.platforms
import
current_platform
...
...
@@ -22,7 +23,7 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
@
CustomOp
.
register
(
"quant_fp8"
)
class
QuantFP8
(
CustomOp
):
"""
Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
Quantize input tensor to FP8 (per-tensor, per-token,
per-channel,
or per-group).
This CustomOp supports both static and dynamic quantization.
"""
...
...
@@ -57,14 +58,14 @@ class QuantFP8(CustomOp):
self
.
is_group_quant
=
group_shape
.
is_per_group
()
if
self
.
is_group_quant
:
assert
not
static
,
"Group quantization only supports dynamic mode"
self
.
group_size
=
group_shape
.
col
else
:
assert
group_shape
in
{
GroupShape
.
PER_TOKEN
,
GroupShape
.
PER_TENSOR
}
assert
not
static
or
group_shape
==
GroupShape
.
PER_TENSOR
,
(
"Only per-tensor scales supported for static quantization."
)
self
.
use_per_token_if_dynamic
=
group_shape
==
GroupShape
.
PER_TOKEN
if
not
static
:
assert
group_shape
in
(
GroupShape
.
PER_TOKEN
,
GroupShape
.
PER_TENSOR
),
(
"Only per-token or per-tensor scales are supported for dynamic "
"non-group quantization."
)
def
forward_cuda
(
self
,
...
...
@@ -72,8 +73,8 @@ class QuantFP8(CustomOp):
scale
:
torch
.
Tensor
|
None
=
None
,
scale_ub
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
self
.
is_group_quant
:
assert
scale
is
None
,
"
G
roup quantization
is always dynamic
"
if
self
.
is_group_quant
and
not
self
.
static
:
assert
scale
is
None
,
"
Dynamic g
roup quantization
does not use scale
"
from
vllm.model_executor.layers.quantization.utils
import
fp8_utils
return
fp8_utils
.
per_token_group_quant_fp8
(
...
...
@@ -90,12 +91,14 @@ class QuantFP8(CustomOp):
and
self
.
group_shape
==
GroupShape
.
PER_TOKEN
and
scale_ub
.
numel
()
==
1
)
return
ops
.
scaled_fp8_quant
(
x
,
scale
,
num_token_padding
=
self
.
num_token_padding
,
scale_ub
=
scale_ub
,
use_per_token_if_dynamic
=
self
.
use_per_token_if_dynamic
,
group_shape
=
self
.
group_shape
if
self
.
static
else
None
,
)
def
forward_hip
(
...
...
@@ -131,8 +134,8 @@ class QuantFP8(CustomOp):
scale
:
torch
.
Tensor
|
None
=
None
,
scale_ub
:
torch
.
Tensor
|
None
=
None
,
):
if
self
.
is_group_quant
:
assert
scale
is
None
,
"
G
roup quantization
is always dynamic
"
if
self
.
is_group_quant
and
not
self
.
static
:
assert
scale
is
None
,
"
Dynamic g
roup quantization
does not use scale
"
return
self
.
_quantize_group_native
(
x
)
assert
(
scale
is
not
None
)
==
self
.
static
...
...
@@ -155,7 +158,10 @@ class QuantFP8(CustomOp):
# Even for dynamic per-token scales,
# reciprocal performs slightly better than division
out
=
x
.
to
(
torch
.
float32
)
*
scale
.
reciprocal
()
out
=
(
x
.
to
(
torch
.
float32
)
*
group_broadcast
(
scale
.
to
(
torch
.
float32
),
x
.
shape
[
-
2
:]).
reciprocal
()
)
out
=
out
.
clamp
(
_FP8_MIN
,
_FP8_MAX
).
to
(
_FP8_DTYPE
)
# This currently generates an extra Triton kernel in compilation.
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
0a0aa077
...
...
@@ -158,11 +158,14 @@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# with an extent of 1, since this can be done implicitly by pytorch
def
group_broadcast
(
t
,
shape
):
for
i
,
s
in
enumerate
(
shape
):
if
t
.
shape
[
i
]
!=
s
and
t
.
shape
[
i
]
!=
1
:
assert
s
%
t
.
shape
[
i
]
==
0
# If tensor has fewer dimensions than target shape, treat missing
# dimensions as size 1 (standard PyTorch broadcasting behavior)
t_dim_size
=
t
.
shape
[
i
]
if
i
<
t
.
ndim
else
1
if
t_dim_size
!=
s
and
t_dim_size
!=
1
:
assert
s
%
t_dim_size
==
0
t
=
(
t
.
unsqueeze
(
i
+
1
)
.
expand
(
*
t
.
shape
[:
i
+
1
],
s
//
t
.
shape
[
i
]
,
*
t
.
shape
[
i
+
1
:])
.
expand
(
*
t
.
shape
[:
i
+
1
],
s
//
t
_dim_size
,
*
t
.
shape
[
i
+
1
:])
.
flatten
(
i
,
i
+
1
)
)
return
t
...
...
@@ -180,7 +183,16 @@ def scaled_quantize(
x
:
torch
.
Tensor
,
group_shape
:
GroupShape
,
quant_dtype
:
torch
.
dtype
,
compute_dtype
:
torch
.
dtype
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args:
x: Input tensor to quantize
group_shape: Shape of quantization groups
quant_dtype: Target quantized dtype (e.g., torch.float8_e4m3fn)
compute_dtype: Optional dtype for intermediate computations.
If None, uses input dtype. Use torch.float32 for higher precision.
"""
group_shape
=
_normalize_quant_group_shape
(
x
,
group_shape
)
assert
quant_dtype
.
is_floating_point
,
(
"currently `scaled_quantize` only supports floating point dtypes "
...
...
@@ -189,11 +201,14 @@ def scaled_quantize(
finfo
=
torch
.
finfo
(
quant_dtype
)
# Convert to compute dtype if specified
x_compute
=
x
if
compute_dtype
is
None
else
x
.
to
(
compute_dtype
)
# Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
assert
x
.
ndim
==
2
assert
x
.
shape
[
0
]
%
group_shape
[
0
]
==
0
and
x
.
shape
[
1
]
%
group_shape
[
1
]
==
0
blk_m
,
blk_n
=
x
.
shape
[
0
]
//
group_shape
[
0
],
x
.
shape
[
1
]
//
group_shape
[
1
]
x_blkd
=
x
.
reshape
(
blk_m
,
group_shape
[
0
],
blk_n
,
group_shape
[
1
])
x_blkd
=
x
_compute
.
reshape
(
blk_m
,
group_shape
[
0
],
blk_n
,
group_shape
[
1
])
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
x_blkd_permd
=
x_blkd
.
permute
(
0
,
2
,
1
,
3
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment