Unverified Commit ba2ba901 authored by Isotr0py, committed by GitHub

Add cuda kernel support for GGUF inference (#11869)



* add gguf kernel support
Signed-off-by: Isotr0py <2037008807@qq.com>

* fix
Signed-off-by: Isotr0py <2037008807@qq.com>

* optimize
Signed-off-by: Isotr0py <2037008807@qq.com>

* update

* update

* update

* update

* update

---------
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
parent fa4c0e5e
@@ -333,7 +333,7 @@ jobs:
    additional_deps: ["peft"]
  - backend: "gguf"
    test_location: "gguf"
    additional_deps: ["peft", "kernels"]
  - backend: "torchao"
    test_location: "torchao"
    additional_deps: []
...
@@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
## Using Optimized CUDA Kernels with GGUF
Optimized CUDA kernels can accelerate GGUF-quantized model inference by approximately 10%. This requires a CUDA GPU with compute capability 7.0 or higher (as reported by `torch.cuda.get_device_capability`) and the `kernels` library:
```shell
pip install -U kernels
```
Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use the optimized kernels whenever they are available. Note that the CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, which can cause subtle visual variations in generated images. To disable them, set `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
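
For example, the flag can also be set from Python before any GGUF-quantized weights are loaded. The snippet below is a minimal sketch; the checkpoint path is illustrative, and the rest of the pipeline setup from the example above is unchanged:

```python
import os

# Enable the optimized CUDA kernels before loading any GGUF-quantized weights.
os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "true"

import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Illustrative GGUF checkpoint; any single-file GGUF checkpoint works the same way.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```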
## Supported Quantization Types

- BF16
...
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from contextlib import nullcontext

import gguf
import torch
import torch.nn as nn

from ...utils import is_accelerate_available, is_kernels_available


if is_accelerate_available():

@@ -29,6 +29,82 @@ if is_accelerate_available():
    from accelerate.hooks import add_hook_to_module, remove_hook_from_module

can_use_cuda_kernels = (
    os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 7
)

if can_use_cuda_kernels and is_kernels_available():
    from kernels import get_kernel

    ops = get_kernel("Isotr0py/ggml")
else:
    ops = None

UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
STANDARD_QUANT_TYPES = {
    gguf.GGMLQuantizationType.Q4_0,
    gguf.GGMLQuantizationType.Q4_1,
    gguf.GGMLQuantizationType.Q5_0,
    gguf.GGMLQuantizationType.Q5_1,
    gguf.GGMLQuantizationType.Q8_0,
    gguf.GGMLQuantizationType.Q8_1,
}
KQUANT_TYPES = {
    gguf.GGMLQuantizationType.Q2_K,
    gguf.GGMLQuantizationType.Q3_K,
    gguf.GGMLQuantizationType.Q4_K,
    gguf.GGMLQuantizationType.Q5_K,
    gguf.GGMLQuantizationType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
    gguf.GGMLQuantizationType.IQ1_M,
    gguf.GGMLQuantizationType.IQ1_S,
    gguf.GGMLQuantizationType.IQ2_XXS,
    gguf.GGMLQuantizationType.IQ2_XS,
    gguf.GGMLQuantizationType.IQ2_S,
    gguf.GGMLQuantizationType.IQ3_XXS,
    gguf.GGMLQuantizationType.IQ3_S,
    gguf.GGMLQuantizationType.IQ4_XS,
    gguf.GGMLQuantizationType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES

def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
    # There is no need to call any kernel for fp16/bf16.
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T
    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementations are designed for
    # contiguous batching and are inefficient with diffusers' batching,
    # so they are disabled for now.
    # elif qweight_type in MMVQ_QUANT_TYPES:
    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
    # elif qweight_type in MMQ_QUANT_TYPES:
    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
    # If there is no available MMQ kernel, fall back to dequantization.
    if qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
        y = x @ weight.to(x.dtype).T
    else:
        # Raise an error if the quantization type is not supported.
        # Might be useful if llama.cpp adds a new quantization type.
        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
        qweight_type = gguf.GGMLQuantizationType(qweight_type)
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
    return y.as_tensor()


# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
    r"""

@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
    ) -> None:
        super().__init__(in_features, out_features, bias, device)
        self.compute_dtype = compute_dtype
        self.device = device

    def forward(self, inputs: torch.Tensor):
        if ops is not None and self.weight.is_cuda and inputs.is_cuda:
            return self.forward_cuda(inputs)
        return self.forward_native(inputs)

    def forward_native(self, inputs: torch.Tensor):
        weight = dequantize_gguf_tensor(self.weight)
        weight = weight.to(self.compute_dtype)
        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

        output = torch.nn.functional.linear(inputs, weight, bias)
        return output

    def forward_cuda(self, inputs: torch.Tensor):
        quant_type = self.weight.quant_type
        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
        if self.bias is not None:
            output += self.bias.to(self.compute_dtype)
        return output
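
The dequantize fallback in `_fused_mul_mat_gguf` above recovers the unpacked weight width from GGUF's per-type block layout via `shape[1] // type_size * block_size`. A minimal sketch of that shape relationship, using only `gguf` and NumPy (the Q4_K type and the 512x512 shape are illustrative choices, not part of this diff):

```python
import gguf
import numpy as np

# Illustrative choices: any supported quant type works, as long as in_features
# is divisible by that type's block size.
qtype = gguf.GGMLQuantizationType.Q4_K
out_features, in_features = 512, 512

block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
float_weight = np.random.randn(out_features, in_features).astype(np.float32)
qweight = gguf.quants.quantize(float_weight, qtype)

# Each row packs in_features // block_size blocks of type_size bytes each...
assert qweight.shape == (out_features, in_features // block_size * type_size)

# ...so the dequantized width is recovered exactly as in _fused_mul_mat_gguf.
assert qweight.shape[1] // type_size * block_size == in_features
```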

@@ -81,6 +81,7 @@ from .import_utils import (
    is_invisible_watermark_available,
    is_k_diffusion_available,
    is_k_diffusion_version,
    is_kernels_available,
    is_librosa_available,
    is_matplotlib_available,
    is_nltk_available,
...

@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")

@@ -277,6 +278,10 @@ def is_accelerate_available():
    return _accelerate_available


def is_kernels_available():
    return _kernels_available


def is_k_diffusion_available():
    return _k_diffusion_available
...

@@ -36,6 +36,7 @@ from .import_utils import (
    is_compel_available,
    is_flax_available,
    is_gguf_available,
    is_kernels_available,
    is_note_seq_available,
    is_onnx_available,
    is_opencv_available,

@@ -634,6 +635,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
    return decorator


def require_kernels_version_greater_or_equal(kernels_version):
    def decorator(test_case):
        correct_kernels_version = is_kernels_available() and version.parse(
            version.parse(importlib.metadata.version("kernels")).base_version
        ) >= version.parse(kernels_version)
        return unittest.skipUnless(
            correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
        )(test_case)

    return decorator


def deprecate_after_peft_backend(test_case):
    """
    Decorator marking a test that will be skipped after PEFT backend
...

@@ -30,8 +30,10 @@ from diffusers.utils.testing_utils import (
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_accelerator,
    require_big_accelerator,
    require_gguf_version_greater_or_equal,
    require_kernels_version_greater_or_equal,
    require_peft_backend,
    require_torch_version_greater,
    torch_device,

@@ -41,11 +43,66 @@ from ..test_torch_compile_utils import QuantCompileTests


if is_gguf_available():
    import gguf

    from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter


enable_full_determinism()


@nightly
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
@require_kernels_version_greater_or_equal("0.9.0")
class GGUFCudaKernelsTests(unittest.TestCase):
    def setUp(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def test_cuda_kernels_vs_native(self):
        if torch_device != "cuda":
            self.skipTest("CUDA kernels test requires CUDA device")

        from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels

        if not can_use_cuda_kernels:
            self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")

        test_quant_types = ["Q4_0", "Q4_K"]
        test_shape = (1, 64, 512)  # batch, seq_len, hidden_dim
        compute_dtype = torch.bfloat16

        for quant_type in test_quant_types:
            qtype = getattr(gguf.GGMLQuantizationType, quant_type)
            in_features, out_features = 512, 512

            torch.manual_seed(42)
            float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
            quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
            weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
            weight = GGUFParameter(weight_data, quant_type=qtype)
            x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)

            linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
            linear.weight = weight
            linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
            linear = linear.to(torch_device)

            with torch.no_grad():
                output_native = linear.forward_native(x)
                output_cuda = linear.forward_cuda(x)

            assert torch.allclose(output_native, output_cuda, 1e-2), (
                f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
            )


@nightly
@require_big_accelerator
@require_accelerate
...