Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
572cd426
Commit
572cd426
authored
Dec 15, 2025
by
zhuwenwen
Browse files
restore the initial fp8 related implementation
parent
c441dda9
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
143 additions
and
145 deletions
+143
-145
CMakeLists.txt
CMakeLists.txt
+3
-3
cmake/utils.cmake
cmake/utils.cmake
+1
-1
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+0
-2
csrc/ops.h
csrc/ops.h
+17
-17
csrc/quantization/fused_kernels/quant_conversions.cuh
csrc/quantization/fused_kernels/quant_conversions.cuh
+1
-1
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+17
-17
vllm/_custom_ops.py
vllm/_custom_ops.py
+60
-60
vllm/compilation/fix_functionalization.py
vllm/compilation/fix_functionalization.py
+9
-9
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+25
-25
vllm/compilation/sequence_parallelism.py
vllm/compilation/sequence_parallelism.py
+10
-10
No files found.
CMakeLists.txt
View file @
572cd426
...
@@ -255,15 +255,15 @@ set(VLLM_EXT_SRC
...
@@ -255,15 +255,15 @@ set(VLLM_EXT_SRC
"csrc/attention/attention_with_mask_kernels_opt.cu"
"csrc/attention/attention_with_mask_kernels_opt.cu"
"csrc/attention/attention_with_mask_kernels_opt_tc.cu"
"csrc/attention/attention_with_mask_kernels_opt_tc.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
#
"csrc/layernorm_quant_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/sampler.cu"
"csrc/cuda_view.cu"
"csrc/cuda_view.cu"
# "csrc/quantization/gptq/q_gemm.cu"
# "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
#
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
#
"csrc/quantization/activation_kernels.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/custom_all_reduce.cu"
"csrc/custom_all_reduce.cu"
...
...
cmake/utils.cmake
View file @
572cd426
...
@@ -123,7 +123,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
...
@@ -123,7 +123,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list
(
APPEND GPU_FLAGS
list
(
APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DUSE_ROCM"
#
"-DENABLE_FP8"
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-Werror=unused-variable"
"-Werror=unused-variable"
...
...
csrc/layernorm_quant_kernels.cu
View file @
572cd426
...
@@ -6,9 +6,7 @@
...
@@ -6,9 +6,7 @@
*/
*/
#include "type_convert.cuh"
#include "type_convert.cuh"
#ifndef USE_ROCM
#include "quantization/fp8/common.cuh"
#include "quantization/fp8/common.cuh"
#endif
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <torch/cuda.h>
...
...
csrc/ops.h
View file @
572cd426
...
@@ -221,15 +221,15 @@ void apply_repetition_penalties_(torch::Tensor& logits,
...
@@ -221,15 +221,15 @@ void apply_repetition_penalties_(torch::Tensor& logits,
const
torch
::
Tensor
&
output_mask
,
const
torch
::
Tensor
&
output_mask
,
const
torch
::
Tensor
&
repetition_penalties
);
const
torch
::
Tensor
&
repetition_penalties
);
//
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
//
torch::Tensor& weight, torch::Tensor& scale,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
scale
,
//
double epsilon);
double
epsilon
);
//
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
//
torch::Tensor& input,
torch
::
Tensor
&
input
,
//
torch::Tensor& residual,
torch
::
Tensor
&
residual
,
//
torch::Tensor& weight,
torch
::
Tensor
&
weight
,
//
torch::Tensor& scale, double epsilon);
torch
::
Tensor
&
scale
,
double
epsilon
);
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input
,
...
@@ -258,8 +258,8 @@ void rotary_embedding_tgi(
...
@@ -258,8 +258,8 @@ void rotary_embedding_tgi(
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
//
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
//
torch::Tensor& scale);
torch
::
Tensor
&
scale
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
mul_and_silu
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
@@ -443,15 +443,15 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
...
@@ -443,15 +443,15 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
//
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
//
torch::Tensor const& scale);
torch
::
Tensor
const
&
scale
);
//
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
//
torch::Tensor& scale);
torch
::
Tensor
&
scale
);
//
void dynamic_per_token_scaled_fp8_quant(
void
dynamic_per_token_scaled_fp8_quant
(
//
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
,
//
std::optional<torch::Tensor> const& scale_ub);
std
::
optional
<
torch
::
Tensor
>
const
&
scale_ub
);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
...
...
csrc/quantization/fused_kernels/quant_conversions.cuh
View file @
572cd426
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
// TODO(luka/varun):refactor common.cuh to use this file instead
//
#include "quantization/fp8/common.cuh"
#include "quantization/fp8/common.cuh"
namespace
vllm
{
namespace
vllm
{
...
...
csrc/torch_bindings.cpp
View file @
572cd426
...
@@ -258,8 +258,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -258,8 +258,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"silu_and_mul"
,
torch
::
kCUDA
,
&
silu_and_mul
);
ops
.
impl
(
"silu_and_mul"
,
torch
::
kCUDA
,
&
silu_and_mul
);
// ops.def(
// ops.def(
//
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"
);
//
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
ops
.
impl
(
"silu_and_mul_quant"
,
torch
::
kCUDA
,
&
silu_and_mul_quant
);
ops
.
def
(
"mul_and_silu(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"mul_and_silu(Tensor! out, Tensor input) -> ()"
);
ops
.
impl
(
"mul_and_silu"
,
torch
::
kCUDA
,
&
mul_and_silu
);
ops
.
impl
(
"mul_and_silu"
,
torch
::
kCUDA
,
&
mul_and_silu
);
...
@@ -737,25 +737,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -737,25 +737,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor.
// Compute FP8 quantized tensor for given scaling factor.
//
ops.def(
ops
.
def
(
//
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
//
"()");
"()"
);
//
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
ops
.
impl
(
"static_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
static_scaled_fp8_quant
);
// // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
//
ops.def(
ops
.
def
(
//
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
//
"-> "
"-> "
//
"()");
"()"
);
//
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
ops
.
impl
(
"dynamic_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_fp8_quant
);
// // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
//
ops.def(
ops
.
def
(
//
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
//
"Tensor! scale, Tensor? scale_ub) -> "
"Tensor! scale, Tensor? scale_ub) -> "
//
"()");
"()"
);
//
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
ops
.
impl
(
"dynamic_per_token_scaled_fp8_quant"
,
torch
::
kCUDA
,
//
&dynamic_per_token_scaled_fp8_quant);
&
dynamic_per_token_scaled_fp8_quant
);
// Compute int8 quantized tensor for given scaling factor.
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
ops
.
def
(
...
...
vllm/_custom_ops.py
View file @
572cd426
...
@@ -1692,66 +1692,66 @@ def scaled_fp4_experts_quant(
...
@@ -1692,66 +1692,66 @@ def scaled_fp4_experts_quant(
# fp8
# fp8
#
def scaled_fp8_quant(
def
scaled_fp8_quant
(
#
input: torch.Tensor,
input
:
torch
.
Tensor
,
#
scale: Optional[torch.Tensor] = None,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
#
num_token_padding: Optional[int] = None,
num_token_padding
:
Optional
[
int
]
=
None
,
#
scale_ub: Optional[torch.Tensor] = None,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
#
use_per_token_if_dynamic: bool = False,
use_per_token_if_dynamic
:
bool
=
False
,
#
output: Optional[torch.Tensor] = None,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
#
) -> tuple[torch.Tensor, torch.Tensor]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
#
"""
"""
#
Quantize input tensor to FP8 and return quantized tensor and scale.
Quantize input tensor to FP8 and return quantized tensor and scale.
#
This function supports both static and dynamic quantization: If you
This function supports both static and dynamic quantization: If you
#
provide the scale, it will use static scaling and if you omit it,
provide the scale, it will use static scaling and if you omit it,
#
the scale will be determined dynamically. The function also allows
the scale will be determined dynamically. The function also allows
#
optional padding of the output tensors for downstream kernels that
optional padding of the output tensors for downstream kernels that
#
will benefit from padding.
will benefit from padding.
#
Args:
Args:
#
input: The input tensor to be quantized to FP8
input: The input tensor to be quantized to FP8
#
scale: Optional scaling factor for the FP8 quantization
scale: Optional scaling factor for the FP8 quantization
#
scale_ub: Optional upper bound for scaling factor in dynamic
scale_ub: Optional upper bound for scaling factor in dynamic
#
per token case
per token case
#
num_token_padding: If specified, pad the first dimension
num_token_padding: If specified, pad the first dimension
#
of the output to at least this value.
of the output to at least this value.
#
use_per_token_if_dynamic: Whether to do per_tensor or per_token
use_per_token_if_dynamic: Whether to do per_tensor or per_token
#
in the dynamic quantization case.
in the dynamic quantization case.
#
Returns:
Returns:
#
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
#
scaling factor.
scaling factor.
#
"""
"""
#
# This code assumes batch_dim and num_tokens are flattened
# This code assumes batch_dim and num_tokens are flattened
#
assert (input.ndim == 2)
assert
(
input
.
ndim
==
2
)
#
shape: Union[tuple[int, int], torch.Size] = input.shape
shape
:
Union
[
tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
#
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
#
out_dtype: torch.dtype = current_platform.fp8_dtype()
out_dtype
:
torch
.
dtype
=
current_platform
.
fp8_dtype
()
#
if num_token_padding:
if
num_token_padding
:
#
shape = (max(num_token_padding, input.shape[0]), shape[1])
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
#
if output is None:
if
output
is
None
:
#
output = torch.empty(shape, device=input.device, dtype=out_dtype)
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
#
else:
else
:
#
assert num_token_padding is None, \
assert
num_token_padding
is
None
,
\
#
"padding not supported if output passed in"
"padding not supported if output passed in"
#
assert output.dtype == out_dtype
assert
output
.
dtype
==
out_dtype
#
if scale is None:
if
scale
is
None
:
#
if use_per_token_if_dynamic:
if
use_per_token_if_dynamic
:
#
scale = torch.empty((shape[0], 1),
scale
=
torch
.
empty
((
shape
[
0
],
1
),
#
device=input.device,
device
=
input
.
device
,
#
dtype=torch.float32)
dtype
=
torch
.
float32
)
#
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
(
#
output, input.contiguous(), scale, scale_ub)
output
,
input
.
contiguous
(),
scale
,
scale_ub
)
#
else:
else
:
#
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
#
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
#
else:
else
:
#
assert scale.numel() == 1, f"{scale.shape}"
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
#
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
#
return output, scale
return
output
,
scale
# gptq allspark
# gptq allspark
...
...
vllm/compilation/fix_functionalization.py
View file @
572cd426
...
@@ -62,9 +62,9 @@ class FixFunctionalizationPass(VllmInductorPass):
...
@@ -62,9 +62,9 @@ class FixFunctionalizationPass(VllmInductorPass):
elif
at_target
==
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
:
elif
at_target
==
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
:
mutated_args
=
{
1
:
'input'
,
2
:
'residual'
}
mutated_args
=
{
1
:
'input'
,
2
:
'residual'
}
self
.
defunctionalize
(
graph
,
node
,
mutated_args
)
self
.
defunctionalize
(
graph
,
node
,
mutated_args
)
#
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
elif
at_target
==
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
.
default
:
# noqa: E501
#
mutated_args = {1: 'result', 2: 'residual'}
mutated_args
=
{
1
:
'result'
,
2
:
'residual'
}
#
self.defunctionalize(graph, node, mutated_args)
self
.
defunctionalize
(
graph
,
node
,
mutated_args
)
elif
at_target
==
torch
.
ops
.
_C
.
rms_norm_dynamic_per_token_quant
.
default
:
# noqa: E501
elif
at_target
==
torch
.
ops
.
_C
.
rms_norm_dynamic_per_token_quant
.
default
:
# noqa: E501
mutated_args
=
{
1
:
'result'
,
2
:
'scale'
,
3
:
'residual'
}
mutated_args
=
{
1
:
'result'
,
2
:
'scale'
,
3
:
'residual'
}
self
.
defunctionalize
(
graph
,
node
,
mutated_args
)
self
.
defunctionalize
(
graph
,
node
,
mutated_args
)
...
@@ -83,12 +83,12 @@ class FixFunctionalizationPass(VllmInductorPass):
...
@@ -83,12 +83,12 @@ class FixFunctionalizationPass(VllmInductorPass):
node
,
node
,
mutated_args
,
mutated_args
,
args
=
(
'result'
,
'input'
))
args
=
(
'result'
,
'input'
))
#
elif at_target == torch.ops._C.silu_and_mul_quant.default:
elif
at_target
==
torch
.
ops
.
_C
.
silu_and_mul_quant
.
default
:
#
mutated_args = {1: 'result'}
mutated_args
=
{
1
:
'result'
}
#
self.defunctionalize(graph,
self
.
defunctionalize
(
graph
,
#
node,
node
,
#
mutated_args,
mutated_args
,
#
args=('result', 'input', 'scale'))
args
=
(
'result'
,
'input'
,
'scale'
))
else
:
else
:
continue
# skip the count
continue
# skip the count
...
...
vllm/compilation/fusion.py
View file @
572cd426
...
@@ -82,17 +82,17 @@ class QuantKey(NamedTuple):
...
@@ -82,17 +82,17 @@ class QuantKey(NamedTuple):
f
"
{
'a'
if
not
self
.
symmetric
else
''
}
symmetric)"
)
f
"
{
'a'
if
not
self
.
symmetric
else
''
}
symmetric)"
)
#
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8StaticTensorSym
=
QuantKey
(
FP8_DTYPE
,
True
,
GroupShape
.
PER_TENSOR
,
True
)
#
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym
=
QuantKey
(
FP8_DTYPE
,
False
,
GroupShape
.
PER_TENSOR
,
True
)
#
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
kFp8DynamicTokenSym
=
QuantKey
(
FP8_DTYPE
,
False
,
GroupShape
.
PER_TOKEN
,
True
)
QUANT_OPS
:
dict
[
QuantKey
,
OpOverload
]
=
{
QUANT_OPS
:
dict
[
QuantKey
,
OpOverload
]
=
{
#
kFp8StaticTensorSym:
kFp8StaticTensorSym
:
#
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
,
# noqa: E501
#
kFp8DynamicTensorSym:
kFp8DynamicTensorSym
:
#
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
.
default
,
# noqa: E501
#
kFp8DynamicTokenSym:
kFp8DynamicTokenSym
:
#
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
.
default
,
# noqa: E501
}
}
...
@@ -111,14 +111,14 @@ class FusedRMSQuantKey(NamedTuple):
...
@@ -111,14 +111,14 @@ class FusedRMSQuantKey(NamedTuple):
FUSED_OPS
:
dict
[
FusedRMSQuantKey
,
OpOverload
]
=
{
FUSED_OPS
:
dict
[
FusedRMSQuantKey
,
OpOverload
]
=
{
#
FusedRMSQuantKey(kFp8StaticTensorSym, False):
FusedRMSQuantKey
(
kFp8StaticTensorSym
,
False
):
#
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
torch
.
ops
.
_C
.
rms_norm_static_fp8_quant
.
default
,
# noqa: E501
#
FusedRMSQuantKey(kFp8StaticTensorSym, True):
FusedRMSQuantKey
(
kFp8StaticTensorSym
,
True
):
#
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
.
default
,
# noqa: E501
#
FusedRMSQuantKey(kFp8DynamicTokenSym, False):
FusedRMSQuantKey
(
kFp8DynamicTokenSym
,
False
):
#
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
torch
.
ops
.
_C
.
rms_norm_dynamic_per_token_quant
.
default
,
# noqa: E501
#
FusedRMSQuantKey(kFp8DynamicTokenSym, True):
FusedRMSQuantKey
(
kFp8DynamicTokenSym
,
True
):
#
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
torch
.
ops
.
_C
.
rms_norm_dynamic_per_token_quant
.
default
,
# noqa: E501
}
}
...
@@ -586,23 +586,23 @@ class FusionPass(VllmInductorPass):
...
@@ -586,23 +586,23 @@ class FusionPass(VllmInductorPass):
for
epsilon
in
[
1e-5
,
1e-6
]:
for
epsilon
in
[
1e-5
,
1e-6
]:
# Fuse rms_norm + static fp8 quant
# Fuse rms_norm + static fp8 quant
#
RMSNormStaticQuantPattern(epsilon,
RMSNormStaticQuantPattern
(
epsilon
,
#
FP8_DTYPE).register(self.patterns)
FP8_DTYPE
).
register
(
self
.
patterns
)
# Matches for patterns below have 2 or more outputs,
# Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches)
# so we need to process them manually (see process_matches)
# Fuse rms_norm + static fp8 quant
# Fuse rms_norm + static fp8 quant
#
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
FusedAddRMSNormStaticQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
#
self.patterns, self.record_match)
self
.
patterns
,
self
.
record_match
)
# Fuse rms_norm + dynamic per-token fp8 quant
# Fuse rms_norm + dynamic per-token fp8 quant
#
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
RMSNormDynamicQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
#
self.patterns, self.record_match)
self
.
patterns
,
self
.
record_match
)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
#
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
FusedAddRMSNormDynamicQuantPattern
(
epsilon
,
FP8_DTYPE
).
register
(
#
self.patterns, self.record_match)
self
.
patterns
,
self
.
record_match
)
# WARNING: This is a hack to clear the pattern matcher cache
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
# and allow multiple values of epsilon.
...
...
vllm/compilation/sequence_parallelism.py
View file @
572cd426
...
@@ -444,16 +444,16 @@ class SequenceParallelismPass(VllmInductorPass):
...
@@ -444,16 +444,16 @@ class SequenceParallelismPass(VllmInductorPass):
for
epsilon
in
[
1e-5
,
1e-6
]:
for
epsilon
in
[
1e-5
,
1e-6
]:
# RMSNorm + Static FP8 quantization patterns
# RMSNorm + Static FP8 quantization patterns
#
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
fp8_quant_op
=
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
#
FirstAllReduceRMSNormStaticFP8Pattern(
FirstAllReduceRMSNormStaticFP8Pattern
(
#
epsilon, self.model_dtype, self.device,
epsilon
,
self
.
model_dtype
,
self
.
device
,
#
fp8_quant_op).register(self.patterns)
fp8_quant_op
).
register
(
self
.
patterns
)
#
MiddleAllReduceRMSNormStaticFP8Pattern(
MiddleAllReduceRMSNormStaticFP8Pattern
(
#
epsilon, self.model_dtype, self.device,
epsilon
,
self
.
model_dtype
,
self
.
device
,
#
fp8_quant_op).register(self.patterns)
fp8_quant_op
).
register
(
self
.
patterns
)
#
LastAllReduceRMSNormStaticFP8Pattern(
LastAllReduceRMSNormStaticFP8Pattern
(
#
epsilon, self.model_dtype, self.device,
epsilon
,
self
.
model_dtype
,
self
.
device
,
#
fp8_quant_op).register(self.patterns)
fp8_quant_op
).
register
(
self
.
patterns
)
# Normal RMSNorm patterns
# Normal RMSNorm patterns
FirstAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
FirstAllReduceRMSNormPattern
(
epsilon
,
self
.
model_dtype
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment