"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "459d9b38d3bc03eee8b5a7abad889fc3c6be8527"
Unverified Commit e0cbad4e authored by Satyajith Chilappagari's avatar Satyajith Chilappagari Committed by GitHub
Browse files

[Neuron] Support quantization on neuron (#18283)


Signed-off-by: default avatarSatyajith Chilappagari <satchill@amazon.com>
parent b48d5cca
# SPDX-License-Identifier: Apache-2.0
from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig)
def test_get_supported_act_dtypes():
neuron_quant_config = NeuronQuantConfig()
supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes()
target_list = ["any_dtype1", "any_dtype2"]
for dtype in target_list:
assert dtype in supported_act_dtypes
......@@ -13,6 +13,12 @@ from vllm.model_executor.layers.quantization.base_config import (
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
class AlwaysSupportedDtypes(list):
def __contains__(self, item):
return True
class NeuronQuantConfig(QuantizationConfig):
"""Int8 Quantization Config class for Neuron Backend."""
......@@ -35,7 +41,8 @@ class NeuronQuantConfig(QuantizationConfig):
return "neuron_quant"
def get_supported_act_dtypes(self) -> list[str]:
return SUPPORTED_QUANT_DTYPE_LIST
# Neuron implements custom handling logic for quantization support
return AlwaysSupportedDtypes()
@classmethod
def get_min_capability(cls) -> int:
......
......@@ -28,7 +28,7 @@ class NeuronPlatform(Platform):
device_name: str = "neuron"
device_type: str = "neuron"
ray_device_key: str = "neuron_cores"
supported_quantization: list[str] = ["neuron_quant"]
supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment