Commit 68af5f6c (unverified), authored by rasmith, committed by GitHub
Browse files

[AMD][FP8][BugFix] Remove V1 check in arg_utils.py for FP8 since it is not necessary (#17215)


Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 8de2901f
......@@ -1368,23 +1368,6 @@ class EngineArgs:
recommend_to_remove=False)
return False
if current_platform.is_rocm():
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
load_config = self.create_load_config()
quantization_config = VllmConfig.get_quantization_config(
model_config, load_config)
if isinstance(quantization_config, Fp8Config):
_raise_or_fallback(feature_name="fp8 for ROCm",
recommend_to_remove=False)
return False
from vllm.model_executor.layers.quantization.quark.quark import (
QuarkConfig)
if isinstance(quantization_config, QuarkConfig
) and quantization_config.has_fp8_layer_weights():
_raise_or_fallback(feature_name="Quark fp8 for ROCm",
recommend_to_remove=False)
# No Fp8 KV cache so far.
if self.kv_cache_dtype != "auto":
fp8_attention = self.kv_cache_dtype.startswith("fp8")
......
......@@ -307,18 +307,6 @@ class QuarkConfig(QuantizationConfig):
# If no matches, return None
return None
def has_fp8_layer_weights(self):
    """Report whether any q/k/v projection weight is configured as fp8.

    Looks up the ``"*v_proj"``, ``"*k_proj"`` and ``"*q_proj"`` entries of
    ``self.quant_config["layer_quant_config"]`` and inspects each entry's
    ``["weight"]["dtype"]`` value.

    Returns:
        True if at least one of those dtype strings contains ``"fp8"``,
        False otherwise (including when any level of the nested
        configuration is missing or empty).
    """
    # Every level may be absent or None; fall back to an empty mapping so a
    # partially specified config yields False instead of raising TypeError
    # on `'fp8' in None`.
    layer_quant_config = self.quant_config.get("layer_quant_config") or {}
    for layer_name in ("*v_proj", "*k_proj", "*q_proj"):
        layer_config = layer_quant_config.get(layer_name) or {}
        weight_config = layer_config.get("weight") or {}
        dtype = weight_config.get("dtype")
        # Only string dtypes can meaningfully contain the "fp8" marker.
        if isinstance(dtype, str) and "fp8" in dtype:
            return True
    return False
class QuarkLinearMethod(LinearMethodBase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment