sglang / Commits / Commit 84d96b3a (unverified)

Move FP8 to SGLang (#2370)

Authored Dec 06, 2024 by Yineng Zhang; committed via GitHub on Dec 06, 2024
Co-authored-by: HaiShaw <hixiao@gmail.com>
Parent: 3d32e4a3
Showing 2 changed files with 561 additions and 2 deletions:

  python/sglang/srt/layers/quantization/__init__.py   (+2, -2)
  python/sglang/srt/layers/quantization/fp8.py         (+559, -0)
python/sglang/srt/layers/quantization/__init__.py  View file @ 84d96b3a

@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -100,13 +100,13 @@ def fp8_moe_apply(
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
     from vllm.model_executor.layers.quantization.utils.quant_utils import (
         is_layer_skipped,
     )
     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
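The practical effect of this hunk is that the "fp8" entry in QUANTIZATION_METHODS now resolves to the SGLang-local Fp8Config defined in the new file below, rather than vLLM's. A minimal sketch of what that looks like from the caller's side (not part of the diff; assumes sglang is installed and importable, and the config values are illustrative):

# Hedged sketch: resolving the "fp8" quantization method after this change.
from sglang.srt.layers.quantization import QUANTIZATION_METHODS

fp8_cls = QUANTIZATION_METHODS["fp8"]
quant_config = fp8_cls.from_config(
    {"quant_method": "fp8", "activation_scheme": "dynamic"}
)
# Expected to come from sglang.srt.layers.quantization.fp8, not vllm.
print(type(quant_config).__module__)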
python/sglang/srt/layers/quantization/fp8.py  (new file, 0 → 100644)  View file @ 84d96b3a
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

import logging
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d,
    apply_fp8_linear,
    convert_to_channelwise,
    cutlass_fp8_supported,
    per_tensor_dequantize,
    requantize_with_max_scale,
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

from sglang.srt.layers.fused_moe_triton import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.utils import (
    get_bool_env_var,
    is_hip,
    print_warning_once,
    set_weight_attrs,
)

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)
class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected fp8 checkpoint. Please note that the "
                "format is experimental and subject to change."
            )
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            activation_scheme=activation_scheme,
            ignored_layers=ignored_layers,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
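For context (not part of the commit), Fp8Config.from_config above reads three keys from a checkpoint's quantization config. A hedged example of a config dict it would accept; the key names follow the code above, the values are purely illustrative:

# Illustrative only: keys mirror Fp8Config.from_config; values are made up.
example_quantization_config = {
    "quant_method": "fp8",          # "fp8" marks the checkpoint as FP8-serialized
    "activation_scheme": "static",  # must be one of ACTIVATION_SCHEMES
    "ignored_layers": ["lm_head"],  # optional; these layers stay unquantized
}
config = Fp8Config.from_config(example_quantization_config)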
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
        # Disable marlin for ROCm
        if is_hip():
            self.use_marlin = False

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)

            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths,
                )

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, handle that there are N scales for N
        # shards in a fused module
        else:
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data, requires_grad=False
            )
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(
                    layer.input_scale.data, requires_grad=False
                )
            # If using marlin (w8a16), kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(
                    layer.weight_scale, layer.logical_widths
                )

            # If using w8a8, torch._scaled_mm needs per tensor, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> Quant with max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If ROCm, normalize the weights and scales to e4m3fnuz
                if is_hip():
                    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                        weight=weight,
                        weight_scale=weight_scale,
                        input_scale=layer.input_scale,
                    )
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale, requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(
                    layer.input_scale.max(), requires_grad=False
                )

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias,
            )

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False,
        )
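As a rough illustration of the per-tensor scheme named in the Fp8LinearMethod docstring (and not the optimized kernel that ops.scaled_fp8_quant / apply_fp8_linear actually use), a weight tensor can be quantized to float8_e4m3fn with a single scale along these lines:

import torch

def naive_per_tensor_fp8_quant(w: torch.Tensor):
    # One scale for the whole tensor: map the max magnitude onto the
    # representable range of float8_e4m3fn.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().max().clamp(min=1e-12).float() / finfo.max
    qweight = (w / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    # Dequantize with qweight.float() * scale.
    return qweight, scale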
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
        )
        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8."
                )

            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)

        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    layer.num_experts, dtype=torch.float32, device=w13_weight.device
                ),
                requires_grad=False,
            )
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
                )
                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.w13_input_scale is None or layer.w2_input_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None."
                    )
                if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                    layer.w2_input_scale
                ):
                    print_warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer. "
                    )
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False
                )
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False
                )

            # If ROCm, normalize the weights and scales to e4m3fnuz
            if is_hip():
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                    )
                )
                w2_weight, w2_weight_scale, w2_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
                    )
                )
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False
                    )
                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False
                    )

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )
            return

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_fp8_w8a8=True,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
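To make the w13 requantization step in process_weights_after_loading concrete, here is a small standalone sketch (not the production code path, which uses per_tensor_dequantize and ops.scaled_fp8_quant) of the "dequantize each shard with its own scale, requantize both with the shared max scale" idea for a single expert:

import torch

def requant_two_shards_with_max_scale(w13_q: torch.Tensor, shard_scales: torch.Tensor):
    # w13_q: (2 * shard_size, hidden) in float8_e4m3fn; shard_scales: (2,) float32.
    finfo = torch.finfo(torch.float8_e4m3fn)
    shard_size = w13_q.shape[0] // 2
    max_scale = shard_scales.max()
    out = torch.empty_like(w13_q)
    for shard_id in range(2):
        sl = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        dq = w13_q[sl].float() * shard_scales[shard_id]  # per-shard dequantize
        out[sl] = (dq / max_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # One scale now covers both shards, as the fused MoE kernel expects.
    return out, max_scale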
class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)