Unverified commit fd685cfd, authored by Younes Belkada, committed by GitHub
Browse files

[`Quantization`] Add str to enum conversion for AWQ (#27320)



* add str to enum conversion

* fixup

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 184f60dc
......@@ -44,6 +44,16 @@ class AWQLinearVersion(str, Enum):
GEMM = "gemm"
GEMV = "gemv"
@staticmethod
def from_str(version: str):
    """Convert a version name into the matching ``AWQLinearVersion`` member.

    The comparison is case-insensitive: ``version`` is lower-cased before
    it is matched against the known names ``"gemm"`` and ``"gemv"``.

    Args:
        version: Name of the AWQ linear version, in any letter case.

    Returns:
        ``AWQLinearVersion.GEMM`` or ``AWQLinearVersion.GEMV``.

    Raises:
        ValueError: If the lower-cased name is neither ``"gemm"`` nor
            ``"gemv"``. The message reports the lower-cased name.
    """
    version = version.lower()
    # Reject unknown names up front so the successful path stays flat.
    if version not in ("gemm", "gemv"):
        raise ValueError(f"Unknown AWQLinearVersion {version}")
    return AWQLinearVersion.GEMM if version == "gemm" else AWQLinearVersion.GEMV
class AwqBackendPackingMethod(str, Enum):
AUTOAWQ = "autoawq"
......@@ -566,6 +576,7 @@ class AwqConfig(QuantizationConfigMixin):
f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
)
self.version = AWQLinearVersion.from_str(self.version)
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
raise ValueError(
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
......
......@@ -47,6 +47,13 @@ class AwqConfigTest(unittest.TestCase):
with self.assertRaises(ValueError):
AwqConfig(bits=4, backend="")
# These should work fine
_ = AwqConfig(bits=4, version="GEMM")
_ = AwqConfig(bits=4, version="gemm")
with self.assertRaises(ValueError):
AwqConfig(bits=4, backend="unexisting-backend")
# LLMAWQ does not work on a T4
with self.assertRaises(ValueError):
AwqConfig(bits=4, backend="llm-awq")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment