Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db23fcac
Commit
db23fcac
authored
Jan 12, 2026
by
SAC_fanth
Browse files
适配block和channel fp8
parent
3f5983bf
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
92 additions
and
80 deletions
+92
-80
vllm/_custom_ops.py
vllm/_custom_ops.py
+63
-61
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+6
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+3
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+19
-17
No files found.
vllm/_custom_ops.py
View file @
db23fcac
...
@@ -13,7 +13,8 @@ from vllm.scalar_type import ScalarType
...
@@ -13,7 +13,8 @@ from vllm.scalar_type import ScalarType
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
try
:
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_ops
from
lmslim
import
quant_tools
from
lmslim
import
quant_tools
from
lmslim.layers.gemm.fp8_utils
import
per_token_quant_fp8
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
try
:
try
:
...
@@ -1692,66 +1693,67 @@ def scaled_fp4_experts_quant(
...
@@ -1692,66 +1693,67 @@ def scaled_fp4_experts_quant(
# fp8
# fp8
# def scaled_fp8_quant(
def
scaled_fp8_quant
(
# input: torch.Tensor,
input
:
torch
.
Tensor
,
# scale: Optional[torch.Tensor] = None,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
# num_token_padding: Optional[int] = None,
num_token_padding
:
Optional
[
int
]
=
None
,
# scale_ub: Optional[torch.Tensor] = None,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
# use_per_token_if_dynamic: bool = False,
use_per_token_if_dynamic
:
bool
=
False
,
# output: Optional[torch.Tensor] = None,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
# ) -> tuple[torch.Tensor, torch.Tensor]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# """
"""
# Quantize input tensor to FP8 and return quantized tensor and scale.
Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
the scale will be determined dynamically. The function also allows
# optional padding of the output tensors for downstream kernels that
optional padding of the output tensors for downstream kernels that
# will benefit from padding.
will benefit from padding.
# Args:
Args:
# input: The input tensor to be quantized to FP8
input: The input tensor to be quantized to FP8
# scale: Optional scaling factor for the FP8 quantization
scale: Optional scaling factor for the FP8 quantization
# scale_ub: Optional upper bound for scaling factor in dynamic
scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
per token case
# num_token_padding: If specified, pad the first dimension
num_token_padding: If specified, pad the first dimension
# of the output to at least this value.
of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
in the dynamic quantization case.
# Returns:
Returns:
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
scaling factor.
# """
"""
# # This code assumes batch_dim and num_tokens are flattened
# This code assumes batch_dim and num_tokens are flattened
# assert (input.ndim == 2)
assert
(
input
.
ndim
==
2
)
# shape: Union[tuple[int, int], torch.Size] = input.shape
shape
:
Union
[
tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# out_dtype: torch.dtype = current_platform.fp8_dtype()
out_dtype
:
torch
.
dtype
=
current_platform
.
fp8_dtype
()
# if num_token_padding:
if
num_token_padding
:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
# if output is None:
if
output
is
None
:
# output = torch.empty(shape, device=input.device, dtype=out_dtype)
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
# else:
else
:
# assert num_token_padding is None, \
assert
num_token_padding
is
None
,
\
# "padding not supported if output passed in"
"padding not supported if output passed in"
# assert output.dtype == out_dtype
assert
output
.
dtype
==
out_dtype
# if scale is None:
if
scale
is
None
:
# if use_per_token_if_dynamic:
if
use_per_token_if_dynamic
:
# scale = torch.empty((shape[0], 1),
scale
=
torch
.
empty
((
shape
[
0
],
1
),
# device=input.device,
device
=
input
.
device
,
# dtype=torch.float32)
dtype
=
torch
.
float32
)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
# output, input.contiguous(), scale, scale_ub)
# else:
output
,
scale
=
per_token_quant_fp8
(
input
.
contiguous
())
# scale = torch.zeros(1, device=input.device, dtype=torch.float32)
else
:
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
# else:
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
# assert scale.numel() == 1, f"{scale.shape}"
else
:
# torch.ops._C.static_scaled_fp8_quant(output, input, scale)
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
# return output, scale
return
output
,
scale
# gptq allspark
# gptq allspark
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
db23fcac
...
@@ -654,6 +654,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -654,6 +654,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
=
False
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
db23fcac
...
@@ -140,7 +140,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -140,7 +140,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def
apply_weights
(
self
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
,
)
->
torch
.
Tensor
:
return
self
.
fp8_linear
.
apply
(
input
=
x
,
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
db23fcac
...
@@ -857,7 +857,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -857,7 +857,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
enable_eplb
:
bool
=
False
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
if
enable_eplb
:
assert
expert_load_view
is
not
None
assert
expert_load_view
is
not
None
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
db23fcac
...
@@ -11,7 +11,7 @@ from vllm.config import CompilationLevel, get_current_vllm_config
...
@@ -11,7 +11,7 @@ from vllm.config import CompilationLevel, get_current_vllm_config
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
from
lmslim.layers.gemm.fp8_utils
import
triton_scaled_mm_fp8
# Input scaling factors are no longer optional in _scaled_mm starting
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY
=
None
TORCH_DEVICE_IDENTITY
=
None
...
@@ -278,25 +278,27 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
...
@@ -278,25 +278,27 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
# GEMM
# GEMM
# This computes C = (X * W).
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
# Output in fp32 to allow subsequent ops to happen in-place
output
=
torch
.
_scaled_mm
(
qinput
,
qinput
=
qinput
.
view
(
-
1
,
qinput
.
shape
[
-
1
])
output
=
triton_scaled_mm_fp8
(
qinput
,
weight
,
weight
,
scale_a
=
TORCH_DEVICE_IDENTITY
,
scale_a
=
scale_a
,
scale_b
=
TORCH_DEVICE_IDENTITY
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float32
)
out_dtype
=
out_dtype
,
bias
=
bias
)
# A fix for discrepancy in scaled_mm which returns tuple
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
# for torch < 2.5 and a single value in torch >= 2.5
if
type
(
output
)
is
tuple
and
len
(
output
)
==
2
:
#
if type(output) is tuple and len(output) == 2:
output
=
output
[
0
]
#
output = output[0]
# Unpad (undo num_token_padding)
#
#
Unpad (undo num_token_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input_2d
.
shape
[
0
])
#
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale
=
torch
.
narrow
(
scale_a
,
0
,
0
,
input_2d
.
shape
[
0
])
#
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
#
# DQ
#
#
DQ
# C = sw * sx * (X * W) + bias
#
#
C = sw * sx * (X * W) + bias
output
=
output
*
x_scale
*
scale_b
.
t
()
#
output = output * x_scale * scale_b.t()
if
bias
is
not
None
:
#
if bias is not None:
output
=
output
+
bias
#
output = output + bias
return
output
.
to
(
out_dtype
).
view
(
*
output_shape
)
return
output
.
view
(
*
output_shape
)
def
dispatch_w8a8_scaled_mm
(
def
dispatch_w8a8_scaled_mm
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment