Unverified Commit 73fdc8c5 authored by Stas Bekman, committed by GitHub

[deepspeed zero3] need `generate(synced_gpus=True, ...)` (#22242)



* [deepspeed zero3] need generate(synced_gpus=True, ...)

* fix

* rework per Sylvain's suggestion

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8b05ace0
...@@ -2268,6 +2268,14 @@ rank1:
This was a very basic example and you will want to adapt it to your needs.
### `generate` nuances
When using multiple GPUs with ZeRO Stage-3, you must synchronize the GPUs by calling `generate(..., synced_gpus=True)`. If this is not done and one GPU finishes generating before the others, the whole system will hang, because the remaining GPUs can no longer receive the shard of weights from the GPU that stopped generating.
Starting from `transformers>=4.28`, if `synced_gpus` isn't explicitly specified, it is set to `True` automatically when these conditions are detected. You can still override the value of `synced_gpus` if you need to.
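The automatic default described above can be sketched as a small helper. This is a simplified, illustrative stand-in for the logic in this commit: `resolve_synced_gpus`, `zero3_enabled`, and `world_size` are hypothetical names, not part of the transformers API. In the real code the inputs come from `is_deepspeed_zero3_enabled()` and `torch.distributed.get_world_size()`.

```python
# Sketch of the `synced_gpus` default resolution introduced in
# transformers>=4.28 (illustrative helper, not transformers API).
def resolve_synced_gpus(synced_gpus, zero3_enabled, world_size):
    """Return the effective `synced_gpus` value.

    An explicit user-provided value always wins; otherwise the flag is
    True only under ZeRO Stage-3 with more than one GPU.
    """
    if synced_gpus is not None:
        return synced_gpus
    return zero3_enabled and world_size > 1
```

Under this sketch, a plain `model.generate(**inputs)` call in a ZeRO-3 run on several GPUs behaves as if `synced_gpus=True` had been passed, while a single-GPU or non-ZeRO-3 run keeps the previous `False` default.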
## Testing Deepspeed Integration
If you submit a PR that involves the DeepSpeed integration, please note that our CircleCI PR CI setup has no GPUs, so tests requiring GPUs only run on a separate nightly CI. Therefore, a green CI report on your PR doesn't mean the DeepSpeed tests pass.
...
...@@ -24,6 +24,7 @@ import torch
import torch.distributed as dist
from torch import nn
from ..deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
...@@ -1114,7 +1115,7 @@ class GenerationMixin:
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
...@@ -1160,8 +1161,11 @@ class GenerationMixin:
    on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
    for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
    Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*):
    Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
    to `True` in a DeepSpeed ZeRO Stage 3 multi-GPU environment to avoid hanging if one GPU finishes
    generating before the others. Otherwise it'll be set to `False`.
kwargs:
    Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
    forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
...@@ -1187,6 +1191,13 @@ class GenerationMixin:
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
if synced_gpus is None:
    if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
        synced_gpus = True
    else:
        synced_gpus = False
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
...