Commit 8223f750 authored by luopl
Browse files

feat: implement int8 quantization

parent 34bf6014
This diff is collapsed.
......@@ -33,6 +33,7 @@ QuantizationMethods = Literal[
"ipex",
"quark",
"moe_wna16",
"groupwise-quant",
"torchao",
"auto-round",
"rtn",
......@@ -120,6 +121,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .blockwise_int8 import BlockInt8Config
from .slimquant_w4a8 import SlimQuantW4A8Int8Config
from .slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
from .groupwise_quant import GroupwiseQuantConfig
method_to_config: dict[str, type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
......@@ -152,6 +154,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"auto-round": AutoRoundConfig,
"rtn": RTNConfig,
"blockwise_int8": BlockInt8Config,
"groupwise-quant": GroupwiseQuantConfig,
"slimquant_w4a8":SlimQuantW4A8Int8Config,
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
}
......
......@@ -16,16 +16,13 @@ from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from vllm.utils import is_kme, SUPPORT_TC
from vllm.utils import SUPPORT_TC
if not SUPPORT_TC:
os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
os.environ['VLLM_USE_FLASH_MLA'] = '0'
if is_kme:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
......@@ -190,7 +187,7 @@ class RocmPlatform(Platform):
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf","groupwise-quant",
"quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","slimquant_w4a8","awq_marlin","slimquant_w4a8_marlin"
]
......@@ -304,8 +301,6 @@ class RocmPlatform(Platform):
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
if is_kme:
os.environ['VLLM_USE_TRITON_FLASH_ATTN'] = '1'
logger.info("Using ROCmFlashAttention backend.")
return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment