Unverified Commit feb2b768 authored by Jerry Zhang's avatar Jerry Zhang Committed by GitHub
Browse files

Add integration with gemlite weight only quant (#2528)

parent d95a5f5b
...@@ -21,7 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", ...@@ -21,7 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
"orjson", "outlines>=0.0.44,<0.1.0", "orjson", "outlines>=0.0.44,<0.1.0",
"packaging", "pillow", "prometheus-client>=0.20.0", "packaging", "pillow", "prometheus-client>=0.20.0",
"psutil", "pydantic", "python-multipart", "psutil", "pydantic", "python-multipart",
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", "pyzmq>=25.1.2", "torchao>=0.7.0", "gemlite", "uvicorn", "uvloop",
"xgrammar>=0.1.6"] "xgrammar>=0.1.6"]
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6"] srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6"]
......
...@@ -322,6 +322,18 @@ def throughput_test( ...@@ -322,6 +322,18 @@ def throughput_test(
) )
time.sleep(0.5) time.sleep(0.5)
# Best-effort: persist gemlite's autotuned kernel configs so later runs can
# reuse them.  gemlite is optional (and `pwd` is POSIX-only), so any
# ImportError here simply skips the step.
try:
    import os
    import pwd

    from gemlite.core import GemLiteLinearTriton

    # NOTE(review): pw_gecos is the GECOS "full name" field and may be empty
    # or contain spaces/commas; pw_name (the login name) is the usual choice
    # for per-user file names — confirm intent.  Must stay consistent with
    # the GemLiteLinearTriton.load_config() reader.
    user_tag = pwd.getpwuid(os.getuid()).pw_gecos
    GemLiteLinearTriton.cache_config(f"/tmp/{user_tag}_gemlite.json")
except ImportError:
    pass
logging.info("\nBenchmark...") logging.info("\nBenchmark...")
result = throughput_test_once( result = throughput_test_once(
backend_name=bench_args.backend, backend_name=bench_args.backend,
......
...@@ -385,6 +385,19 @@ def latency_test( ...@@ -385,6 +385,19 @@ def latency_test(
8, # shorter decoding to speed up the warmup 8, # shorter decoding to speed up the warmup
server_args.device, server_args.device,
) )
# Save gemlite's autotune results for reuse by subsequent runs.  This is
# purely opportunistic: when gemlite (or the POSIX-only `pwd` module) is not
# installed, the ImportError is swallowed and benchmarking proceeds normally.
try:
    import os
    import pwd

    from gemlite.core import GemLiteLinearTriton

    # NOTE(review): pw_gecos may be empty or contain spaces/commas; pw_name
    # is the usual per-user key — verify, and keep in sync with the matching
    # load_config() call site.
    cache_file = f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
    GemLiteLinearTriton.cache_config(cache_file)
except ImportError:
    pass
rank_print("Benchmark ...") rank_print("Benchmark ...")
# Run the sweep # Run the sweep
......
...@@ -47,6 +47,41 @@ def apply_torchao_config_to_model( ...@@ -47,6 +47,41 @@ def apply_torchao_config_to_model(
256, 256,
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn) quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
elif "gemlite" in torchao_config:
# gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
# gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
import os
import pwd
import gemlite
from gemlite.core import GemLiteLinearTriton, set_autotune
try:
from torchao.quantization import gemlite_uintx_weight_only
except:
print(
f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
)
return model
_quant_args = torchao_config.split("-")
bit_width = int(_quant_args[-2])
group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
try:
packing_bitwidth = int(_quant_args[-3])
except:
# if only 2 inputs found, use default value
packing_bitwidth = 32
quantize_(
model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
)
# try to load gemlite kernel config
GemLiteLinearTriton.load_config(
f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
)
elif "fp8wo" in torchao_config: elif "fp8wo" in torchao_config:
# this requires newer hardware # this requires newer hardware
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment