"docs/vscode:/vscode.git/clone" did not exist on "d79d9eaaff90801668613a4e3d5d8a0004963f21"
Unverified Commit e0cbad4e authored by Satyajith Chilappagari's avatar Satyajith Chilappagari Committed by GitHub
Browse files

[Neuron] Support quantization on neuron (#18283)


Signed-off-by: default avatarSatyajith Chilappagari <satchill@amazon.com>
parent b48d5cca
# SPDX-License-Identifier: Apache-2.0
from vllm.model_executor.layers.quantization.neuron_quant import (
NeuronQuantConfig)
def test_get_supported_act_dtypes():
neuron_quant_config = NeuronQuantConfig()
supported_act_dtypes = neuron_quant_config.get_supported_act_dtypes()
target_list = ["any_dtype1", "any_dtype2"]
for dtype in target_list:
assert dtype in supported_act_dtypes
...@@ -13,6 +13,12 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -13,6 +13,12 @@ from vllm.model_executor.layers.quantization.base_config import (
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
class AlwaysSupportedDtypes(list):
def __contains__(self, item):
return True
class NeuronQuantConfig(QuantizationConfig): class NeuronQuantConfig(QuantizationConfig):
"""Int8 Quantization Config class for Neuron Backend.""" """Int8 Quantization Config class for Neuron Backend."""
...@@ -35,7 +41,8 @@ class NeuronQuantConfig(QuantizationConfig): ...@@ -35,7 +41,8 @@ class NeuronQuantConfig(QuantizationConfig):
return "neuron_quant" return "neuron_quant"
def get_supported_act_dtypes(self) -> list[str]: def get_supported_act_dtypes(self) -> list[str]:
return SUPPORTED_QUANT_DTYPE_LIST # Neuron implements custom handling logic for quantization support
return AlwaysSupportedDtypes()
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
......
...@@ -28,7 +28,7 @@ class NeuronPlatform(Platform): ...@@ -28,7 +28,7 @@ class NeuronPlatform(Platform):
device_name: str = "neuron" device_name: str = "neuron"
device_type: str = "neuron" device_type: str = "neuron"
ray_device_key: str = "neuron_cores" ray_device_key: str = "neuron_cores"
supported_quantization: list[str] = ["neuron_quant"] supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES" device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment