<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Distributed inference

Distributed inference splits the workload across multiple GPUs. It is a useful technique for fitting larger models in memory and for processing multiple prompts in parallel for higher throughput.

This guide will show you how to use [Accelerate](https://huggingface.co/docs/accelerate/index) and [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html) for distributed inference.

## Accelerate

Accelerate is a library designed to simplify inference and training on multiple accelerators by handling the setup, allowing users to focus on their PyTorch code.

Install Accelerate with the following command.

```bash
uv pip install accelerate
```

Initialize the [`accelerate.PartialState`] class in a Python file to create a distributed environment. The [`accelerate.PartialState`] class handles process management, device control and placement, and process coordination.

Move the [`DiffusionPipeline`] to [`accelerate.PartialState.device`] to assign a GPU to each process.

```py
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.float16
)
distributed_state = PartialState()
pipeline.to(distributed_state.device)
```

Use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts across the available processes.

```py
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    result = pipeline(prompt).images[0]
    result.save(f"result_{distributed_state.process_index}.png")
```
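
If there are more prompts than processes, [`~accelerate.PartialState.split_between_processes`] hands each process its own sub-list to iterate over. The sketch below reuses the `pipeline` and `distributed_state` from above; the four prompts are only placeholders.

```py
# a minimal sketch: four prompts split across the available processes
prompts = ["a dog", "a cat", "a bird", "a fish"]

with distributed_state.split_between_processes(prompts) as assigned_prompts:
    # each process receives its own chunk of the prompt list
    for i, prompt in enumerate(assigned_prompts):
        image = pipeline(prompt).images[0]
        image.save(f"result_{distributed_state.process_index}_{i}.png")
```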

Call `accelerate launch` to run the script and use the `--num_processes` argument to set the number of GPUs to use.

```bash
accelerate launch --num_processes=2 run_distributed.py
```

## PyTorch Distributed

PyTorch [DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) enables [data parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=data_parallelism), which replicates the same model on each device, to process different batches of data in parallel.

Import `torch.distributed` and `torch.multiprocessing` into a Python file to set up the distributed process group and to spawn the processes for inference on each GPU.

```py
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.float16,
)
```

Create a function for inference with [init_process_group](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group). This method creates a distributed environment with the backend type, the `rank` of the current process, and the `world_size` or number of processes participating (for example, 2 GPUs would be `world_size=2`).

Move the pipeline to `rank` and use `get_rank` to assign a GPU to each process. Each process handles a different prompt.

```py
def run_inference(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    pipeline.to(rank)

    if torch.distributed.get_rank() == 0:
        prompt = "a dog"
    elif torch.distributed.get_rank() == 1:
        prompt = "a cat"

    image = pipeline(prompt).images[0]
    image.save(f"./{prompt.replace(' ', '_')}.png")
```

Use [mp.spawn](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to create the number of processes defined in `world_size`.

```py
def main():
    world_size = 2
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
```

Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.

```bash
torchrun --nproc_per_node=2 run_distributed.py
```

## device_map

The `device_map` argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn't fit on a single GPU. You can use `device_map` to selectively load and unload the required model components at a given stage as shown in the example below (assumes two GPUs are available).

Set `device_map="balanced"` to evenly distribute the text encoders across all available GPUs. You can use the `max_memory` argument to set the maximum amount of memory to allocate on each GPU. Don't load any other pipeline components yet to avoid unnecessary memory usage.

```py
from diffusers import FluxPipeline
import torch

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    vae=None,
    device_map="balanced",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16
)
with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )
```
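
Optionally, sanity-check the returned embeddings before freeing the text encoders. The exact shapes depend on the prompt and `max_sequence_length`, so the printout below is only illustrative.

```py
# quick check that the prompts were encoded before the text encoders are removed
# (the exact shapes depend on the prompt and max_sequence_length)
print(prompt_embeds.shape, prompt_embeds.dtype)
print(pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype)
```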

After the text embeddings are computed, remove the text encoders and pipeline from the GPU to make space for the diffusion transformer.

```py
import gc 

def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline

flush()
```

Set `device_map="auto"` to automatically distribute the model across the two GPUs. This strategy fills the fastest device first and only falls back to slower devices like the CPU or hard drive if needed. The trade-off of placing model parameters on slower devices is higher inference latency.

```py
from diffusers import AutoModel
import torch 

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", 
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
```

> [!TIP]
> Run `pipeline.hf_device_map` to see how the various models are distributed across devices. This is useful for tracking model device placement. You can also call `hf_device_map` on the transformer model to see how it is distributed.
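
For example, a quick way to check the placement is shown below; the printed mapping is only illustrative and depends on your hardware.

```py
# inspect how the transformer's modules were assigned to devices
# (the exact mapping depends on your hardware; this output is only illustrative)
print(transformer.hf_device_map)
# e.g. {"": 0} if the whole model fits on one GPU, or a per-module mapping across devices
```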

Add the transformer model to the pipeline and set `output_type="latent"` to generate the latents.

```py
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)

print("Running denoising.")
height, width = 768, 1360
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=3.5,
    height=height,
    width=width,
    output_type="latent",
).images
```

Remove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.

```py
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

# free the pipeline and transformer before loading the VAE
del pipeline
del transformer
flush()

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

with torch.no_grad():
    print("Running decoding.")
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")
    image[0].save("split_transformer.png")
```

## Resources

- Take a look at this [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for a minimal example of distributed inference with Accelerate.
- For more details, check out Accelerate's [Distributed inference](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
- The `device_map` argument assigns models or an entire pipeline to devices. Refer to the [device placement](../using-diffusers/loading#device-placement) docs for more information.