Unverified commit 8dc8a99b, authored by BADAOUI Abdennacer, committed by GitHub
Browse files

[ROCm] Enable bitsandbytes quantization support on ROCm (#34688)


Signed-off-by: badaoui <abdennacerbadaoui0@gmail.com>
parent 2aab2bb5
...@@ -7,7 +7,7 @@ Compared to other quantization methods, BitsAndBytes eliminates the need for cal ...@@ -7,7 +7,7 @@ Compared to other quantization methods, BitsAndBytes eliminates the need for cal
Below are the steps to utilize BitsAndBytes with vLLM. Below are the steps to utilize BitsAndBytes with vLLM.
```bash ```bash
pip install bitsandbytes>=0.46.1 pip install bitsandbytes>=0.49.2
``` ```
vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint. vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint.
......
...@@ -33,7 +33,7 @@ transformers==4.57.5 ...@@ -33,7 +33,7 @@ transformers==4.57.5
tokenizers==0.22.0 tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test. schemathesis>=3.39.15 # Required for openai schema test.
# quantization # quantization
bitsandbytes>=0.46.1 bitsandbytes>=0.49.2
buildkite-test-collector==0.1.9 buildkite-test-collector==0.1.9
......
...@@ -102,3 +102,5 @@ terratorch==1.2.2 ...@@ -102,3 +102,5 @@ terratorch==1.2.2
segmentation-models-pytorch==0.5.0 segmentation-models-pytorch==0.5.0
# Required for Prithvi tests # Required for Prithvi tests
imagehash==4.3.2 imagehash==4.3.2
# Required for bitsandbytes quantization test
bitsandbytes==0.49.2
...@@ -41,7 +41,7 @@ transformers==4.57.5 ...@@ -41,7 +41,7 @@ transformers==4.57.5
tokenizers==0.22.0 tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test. schemathesis>=3.39.15 # Required for openai schema test.
# quantization # quantization
bitsandbytes==0.46.1 bitsandbytes==0.49.2
buildkite-test-collector==0.1.9 buildkite-test-collector==0.1.9
......
...@@ -66,7 +66,7 @@ backoff==2.2.1 ...@@ -66,7 +66,7 @@ backoff==2.2.1
# via # via
# -r requirements/test.in # -r requirements/test.in
# schemathesis # schemathesis
bitsandbytes==0.46.1 bitsandbytes==0.49.2
# via # via
# -r requirements/test.in # -r requirements/test.in
# lightning # lightning
...@@ -653,6 +653,7 @@ orjson==3.11.5 ...@@ -653,6 +653,7 @@ orjson==3.11.5
packaging==24.2 packaging==24.2
# via # via
# accelerate # accelerate
# bitsandbytes
# black # black
# datamodel-code-generator # datamodel-code-generator
# datasets # datasets
......
...@@ -6,8 +6,6 @@ from typing import Any ...@@ -6,8 +6,6 @@ from typing import Any
import pytest import pytest
from vllm.platforms import current_platform
from ..conftest import HfRunner, VllmRunner from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test, prep_prompts from ..utils import multi_gpu_test, prep_prompts
from .registry import HF_EXAMPLE_MODELS from .registry import HF_EXAMPLE_MODELS
...@@ -131,6 +129,7 @@ def test_distributed( ...@@ -131,6 +129,7 @@ def test_distributed(
"quantization": "bitsandbytes", "quantization": "bitsandbytes",
}, },
), ),
("unsloth/tinyllama-bnb-4bit", {}),
], ],
) )
@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("max_tokens", [32])
...@@ -143,12 +142,6 @@ def test_quantization( ...@@ -143,12 +142,6 @@ def test_quantization(
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
) -> None: ) -> None:
if (
current_platform.is_rocm()
and quantization_kwargs.get("quantization", "") == "bitsandbytes"
):
pytest.skip("bitsandbytes quantization is currently not supported in rocm.")
with vllm_runner( with vllm_runner(
model, model,
model_impl="auto", model_impl="auto",
......
...@@ -28,6 +28,24 @@ from vllm.platforms import current_platform ...@@ -28,6 +28,24 @@ from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
def _check_bitsandbytes_version():
    """Ensure a sufficiently recent bitsandbytes is importable.

    The minimum required version is 0.49.2 when the current platform is
    ROCm and 0.48.1 otherwise.

    Raises:
        ImportError: If bitsandbytes cannot be imported, or if the installed
            version is older than the platform-specific minimum.
    """
    min_version = "0.49.2" if current_platform.is_rocm() else "0.48.1"
    # The requirement is quoted in the suggested command: unquoted, a shell
    # treats `>` as an output redirection and pip would install an
    # unconstrained version.
    install_hint = (
        f"Please install bitsandbytes>={min_version} via "
        f'`pip install "bitsandbytes>={min_version}"` to use '
        "bitsandbytes quantizer."
    )

    # Keep the try body to the import alone so that the version-mismatch
    # error below is not caught and re-wrapped by this handler.
    try:
        import bitsandbytes
    except ImportError as err:
        raise ImportError(install_hint) from err

    installed_version = bitsandbytes.__version__
    if version.parse(installed_version) < version.parse(min_version):
        raise ImportError(
            f"bitsandbytes version is wrong (found {installed_version}). "
            f"{install_hint}"
        )
class BitsAndBytesConfig(QuantizationConfig): class BitsAndBytesConfig(QuantizationConfig):
"""Config class for BitsAndBytes Quantization. """Config class for BitsAndBytes Quantization.
...@@ -183,21 +201,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase): ...@@ -183,21 +201,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
""" """
def __init__(self, quant_config: BitsAndBytesConfig): def __init__(self, quant_config: BitsAndBytesConfig):
try: _check_bitsandbytes_version()
import bitsandbytes
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1."
)
except ImportError as err:
raise ImportError(
"Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer."
) from err
self.quant_config = quant_config self.quant_config = quant_config
def create_weights( def create_weights(
...@@ -442,20 +446,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase): ...@@ -442,20 +446,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
moe: FusedMoEConfig, moe: FusedMoEConfig,
): ):
super().__init__(moe) super().__init__(moe)
try: _check_bitsandbytes_version()
import bitsandbytes
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1."
)
except ImportError as err:
raise ImportError(
"Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer."
) from err
self.quant_config = quant_config self.quant_config = quant_config
def create_weights( def create_weights(
......
...@@ -244,10 +244,8 @@ class RocmPlatform(Platform): ...@@ -244,10 +244,8 @@ class RocmPlatform(Platform):
"mxfp4", "mxfp4",
"petit_nvfp4", "petit_nvfp4",
"torchao", "torchao",
"bitsandbytes",
] ]
# bitsandbytes not supported on gfx9 (warp size 64 limitation)
if not on_gfx9():
supported_quantization += ["bitsandbytes"]
@classmethod @classmethod
def import_kernels(cls) -> None: def import_kernels(cls) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment