Unverified commit eb7ef267, authored by Sayak Paul and committed by GitHub
Browse files

[quant] allow `components_to_quantize` to be a non-list for single components (#12234)



* allow non list components_to_quantize.

* up

* Apply suggestions from code review

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* [docs] components_to_quantize (#12287)

init
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent e1b7f1f2
...@@ -50,7 +50,7 @@ from diffusers.utils import export_to_video ...@@ -50,7 +50,7 @@ from diffusers.utils import export_to_video
pipeline_quant_config = PipelineQuantizationConfig( pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="torchao", quant_backend="torchao",
quant_kwargs={"quant_type": "int8wo"}, quant_kwargs={"quant_type": "int8wo"},
components_to_quantize=["transformer"] components_to_quantize="transformer"
) )
# fp8 layerwise weight-casting # fp8 layerwise weight-casting
......
...@@ -54,7 +54,7 @@ pipeline_quant_config = PipelineQuantizationConfig( ...@@ -54,7 +54,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
"bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16 "bnb_4bit_compute_dtype": torch.bfloat16
}, },
components_to_quantize=["transformer"] components_to_quantize="transformer"
) )
pipeline = HunyuanVideoPipeline.from_pretrained( pipeline = HunyuanVideoPipeline.from_pretrained(
...@@ -91,7 +91,7 @@ pipeline_quant_config = PipelineQuantizationConfig( ...@@ -91,7 +91,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
"bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16 "bnb_4bit_compute_dtype": torch.bfloat16
}, },
components_to_quantize=["transformer"] components_to_quantize="transformer"
) )
pipeline = HunyuanVideoPipeline.from_pretrained( pipeline = HunyuanVideoPipeline.from_pretrained(
...@@ -139,7 +139,7 @@ export_to_video(video, "output.mp4", fps=15) ...@@ -139,7 +139,7 @@ export_to_video(video, "output.mp4", fps=15)
"bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16 "bnb_4bit_compute_dtype": torch.bfloat16
}, },
components_to_quantize=["transformer"] components_to_quantize="transformer"
) )
pipeline = HunyuanVideoPipeline.from_pretrained( pipeline = HunyuanVideoPipeline.from_pretrained(
......
...@@ -34,7 +34,9 @@ Initialize [`~quantizers.PipelineQuantizationConfig`] with the following paramet ...@@ -34,7 +34,9 @@ Initialize [`~quantizers.PipelineQuantizationConfig`] with the following paramet
> [!TIP] > [!TIP]
> These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend. > These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.
- `components_to_quantize` specifies which components of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact. - `components_to_quantize` specifies which component(s) of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
`components_to_quantize` accepts either a list for multiple models or a string for a single model.
The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`. The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.
...@@ -62,6 +64,7 @@ pipe = DiffusionPipeline.from_pretrained( ...@@ -62,6 +64,7 @@ pipe = DiffusionPipeline.from_pretrained(
image = pipe("photo of a cute dog").images[0] image = pipe("photo of a cute dog").images[0]
``` ```
### Advanced quantization ### Advanced quantization
The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends. The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.
......
...@@ -98,7 +98,7 @@ pipeline_quant_config = PipelineQuantizationConfig( ...@@ -98,7 +98,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
"bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16 "bnb_4bit_compute_dtype": torch.bfloat16
}, },
components_to_quantize=["transformer"] components_to_quantize="transformer"
) )
pipeline = HunyuanVideoPipeline.from_pretrained( pipeline = HunyuanVideoPipeline.from_pretrained(
......
...@@ -48,12 +48,15 @@ class PipelineQuantizationConfig: ...@@ -48,12 +48,15 @@ class PipelineQuantizationConfig:
self, self,
quant_backend: str = None, quant_backend: str = None,
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None, quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
components_to_quantize: Optional[List[str]] = None, components_to_quantize: Optional[Union[List[str], str]] = None,
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
): ):
self.quant_backend = quant_backend self.quant_backend = quant_backend
# Initialize kwargs to be {} to set to the defaults. # Initialize kwargs to be {} to set to the defaults.
self.quant_kwargs = quant_kwargs or {} self.quant_kwargs = quant_kwargs or {}
if components_to_quantize:
if isinstance(components_to_quantize, str):
components_to_quantize = [components_to_quantize]
self.components_to_quantize = components_to_quantize self.components_to_quantize = components_to_quantize
self.quant_mapping = quant_mapping self.quant_mapping = quant_mapping
self.config_mapping = {} # book-keeping Example: `{module_name: quant_config}` self.config_mapping = {} # book-keeping Example: `{module_name: quant_config}`
......
...@@ -299,3 +299,19 @@ transformer BitsAndBytesConfig { ...@@ -299,3 +299,19 @@ transformer BitsAndBytesConfig {
data = json.loads(json_part) data = json.loads(json_part)
return data return data
def test_single_component_to_quantize(self):
    """Check that `components_to_quantize` accepts a bare string for one component.

    Builds a `PipelineQuantizationConfig` with `components_to_quantize` given
    as the string ``"transformer"`` instead of a list, loads the pipeline with
    it, and verifies that the named component carries a `quantization_config`
    on its config.
    """
    component_to_quantize = "transformer"
    quant_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
        components_to_quantize=component_to_quantize,
    )
    # A lone string is wrapped in a list when the config is constructed.
    self.assertEqual(quant_config.components_to_quantize, [component_to_quantize])

    pipe = DiffusionPipeline.from_pretrained(
        self.model_name,
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )

    # Look the component up directly instead of scanning all components:
    # a scan that only asserts on a name match would pass vacuously if the
    # pipeline had no component with that name.
    self.assertIn(component_to_quantize, pipe.components)
    component = pipe.components[component_to_quantize]
    self.assertTrue(hasattr(component.config, "quantization_config"))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment