Unverified Commit 4746aaea authored by Yineng Zhang, committed by GitHub

fix: support fb fp8 (#9462)

parent 10d34f74
@@ -16,7 +16,6 @@ try:
     )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
@@ -37,9 +36,9 @@ except ImportError as e:
     AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
         ExpertsInt8Config
-    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
-        Int8TpuConfig
-    ) = DummyConfig
+    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
+        DummyConfig
+    )

 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
@@ -85,6 +85,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
+    "fbgemm_fp8": FBGEMMFp8Config,
 }
@@ -109,7 +110,6 @@ VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
-    "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
...
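The registry change above is the user-visible part of the fix: the "fbgemm_fp8" entry moves out of the vLLM-backed VLLM_QUANTIZATION_METHODS table and into BASE_QUANTIZATION_METHODS, so checkpoints whose quantization_config declares quant_method "fbgemm_fp8" resolve to SGLang's own FBGEMMFp8Config instead of requiring the vLLM import path. A minimal sketch of how such a string-to-config registry is typically consulted follows; the resolve_quant_config helper and the toy config classes are illustrative stand-ins, not code from this commit.

from typing import Dict, Type


class QuantizationConfig:
    """Toy stand-in for sglang.srt.layers.quantization.base_config.QuantizationConfig."""

    @classmethod
    def from_config(cls, config: dict) -> "QuantizationConfig":
        return cls()


class FBGEMMFp8Config(QuantizationConfig):
    """Toy stand-in for the class this commit registers under "fbgemm_fp8"."""


# After this commit, the "fbgemm_fp8" entry lives in the base table,
# so the method is available even when vLLM is not installed.
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fbgemm_fp8": FBGEMMFp8Config,
}


def resolve_quant_config(quant_method: str, hf_quant_config: dict) -> QuantizationConfig:
    # Hypothetical helper: map the quant_method string from a checkpoint's
    # quantization_config to a registered config class and instantiate it.
    try:
        cls = BASE_QUANTIZATION_METHODS[quant_method]
    except KeyError:
        raise ValueError(f"Unsupported quantization method: {quant_method}") from None
    return cls.from_config(hf_quant_config)


# Usage: an FBGEMM FP8 checkpoint now resolves through the base registry.
cfg = resolve_quant_config("fbgemm_fp8", {})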
@@ -8,7 +8,7 @@ import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter

-from sglang.srt.layers.linear import LinearBase, LinearMethodBase
+from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -16,6 +16,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     can_auto_enable_marlin_fp8,
@@ -28,7 +29,7 @@ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_fp8_fnuz
+from sglang.srt.utils import get_bool_env_var, is_cuda

 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
@@ -88,6 +89,9 @@ class FBGEMMFp8Config(QuantizationConfig):
             return FBGEMMFp8LinearMethod(self)
         return None

+    def get_scaled_act_names(self) -> List[str]:
+        return []
+

 class FBGEMMFp8LinearMethod(LinearMethodBase):
...
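For context on the scheme this file implements: FBGEMM FP8 checkpoints store linear-layer weights in FP8 with one scale per output channel (hence the ChannelQuantScaleParameter import above), and activations are quantized dynamically at runtime; the _is_fp8_fnuz flag distinguishes ROCm's float8_e4m3fnuz representation from the float8_e4m3fn type used on CUDA. The sketch below shows per-channel FP8 weight quantization in plain PyTorch. It is an illustration of the idea only, not the path FBGEMMFp8LinearMethod actually executes, which goes through apply_fp8_linear and the fused FP8 kernels.

import torch


def quantize_weight_per_channel_fp8(weight: torch.Tensor):
    """Per-output-channel FP8 (e4m3) weight quantization, illustrative only.

    weight: [out_features, in_features] in float32/bfloat16.
    Returns the FP8 weight and one float32 scale per output channel such that
    weight ≈ w_fp8.float() * scale[:, None].
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One scale per row (output channel), clamped to avoid division by zero.
    amax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = (amax / finfo.max).float()
    w_fp8 = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_fp8, scale.squeeze(-1)


# Usage sketch: quantize once at load time; at run time either dequantize or
# feed the FP8 weight plus per-channel scale to an FP8 GEMM kernel.
w = torch.randn(128, 256, dtype=torch.bfloat16)
w_fp8, w_scale = quantize_weight_per_channel_fp8(w)
w_approx = w_fp8.to(torch.float32) * w_scale[:, None]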