"...resnet50_tensorflow.git" did not exist on "424fe9f6d544e0e7d4f25103b5088caf68d0ff34"
Unverified Commit 32290d87 authored by Stas Bekman, committed by GitHub

[Deepspeed] various fixes (#12058)

* replace deprecated config

* sub_group_size was too big

* complete deprecation removal
parent fd690283
@@ -238,17 +238,20 @@ with DeepSpeed is to have at least the following configuration in the configurat
 {
     "zero_optimization": {
         "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
         "allgather_partitions": true,
         "allgather_bucket_size": 2e8,
         "reduce_scatter": true,
         "reduce_bucket_size": 2e8,
         "overlap_comm": true,
-        "contiguous_gradients": true,
-        "cpu_offload": true
+        "contiguous_gradients": true
     }
 }
 
-which enables ``cpu_offload`` and some other important features. You may experiment with the buffer sizes, you will
+which enables optimizer offload and some other important features. You may experiment with the buffer sizes, you will
 find more details in the discussion below.
 
 For a practical usage example of this type of deployment, please, see this `post
@@ -352,7 +355,7 @@ cell with:
         },
         "overlap_comm": true,
         "contiguous_gradients": true,
-        "sub_group_size": 1e14,
+        "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
         "stage3_prefetch_bucket_size": "auto",
         "stage3_param_persistence_threshold": "auto",
@@ -463,13 +466,16 @@ precision training if ``--fp16`` is passed:
     "zero_optimization": {
         "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
         "allgather_partitions": true,
         "allgather_bucket_size": 2e8,
         "overlap_comm": true,
         "reduce_scatter": true,
         "reduce_bucket_size": 2e8,
-        "contiguous_gradients": true,
-        "cpu_offload": true
+        "contiguous_gradients": true
     },
     "gradient_accumulation_steps": "auto",
@@ -582,19 +588,22 @@ The following is an example configuration for ZeRO stage 2:
 {
     "zero_optimization": {
         "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
         "allgather_partitions": true,
         "allgather_bucket_size": 5e8,
         "overlap_comm": true,
         "reduce_scatter": true,
         "reduce_bucket_size": 5e8,
-        "contiguous_gradients": true,
-        "cpu_offload": true
+        "contiguous_gradients": true
     }
 }
 
 **Performance tuning:**
 
-- enabling ``cpu_offload`` should reduce GPU RAM usage (it requires ``"stage": 2``)
+- enabling ``offload_optimizer`` should reduce GPU RAM usage (it requires ``"stage": 2``)
 - ``"overlap_comm": true`` trades off increased GPU RAM usage to lower all-reduce latency. ``overlap_comm`` uses 4.5x
   the ``allgather_bucket_size`` and ``reduce_bucket_size`` values. So if they are set to 5e8, this requires a 9GB
   footprint (``5e8 x 2Bytes x 2 x 4.5``). Therefore, if you have a GPU with 8GB or less RAM, to avoid getting
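As a quick sanity check of the 9GB figure in the bullet above, here is a minimal back-of-the-envelope calculation in Python; the bucket sizes, the 2-byte (fp16) element size and the 4.5x multiplier are taken straight from the text, while the exact accounting inside DeepSpeed may differ::

    # rough estimate of the extra GPU memory overlap_comm needs for fp16 buckets
    allgather_bucket_size = 5e8   # elements
    reduce_bucket_size = 5e8      # elements
    bytes_per_element = 2         # fp16

    footprint = (allgather_bucket_size + reduce_bucket_size) * bytes_per_element * 4.5
    print(f"{footprint / 1e9:.1f} GB")  # -> 9.0 GB, matching the figure quoted above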
@@ -628,7 +637,7 @@ The following is an example configuration for ZeRO stage 3:
         },
         "overlap_comm": true,
         "contiguous_gradients": true,
-        "sub_group_size": 1e14,
+        "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
         "stage3_prefetch_bucket_size": "auto",
         "stage3_param_persistence_threshold": "auto",
@@ -649,7 +658,6 @@ and its typically accessed much faster than normal CPU memory.
 **Performance tuning:**
 
-- ``sub_group_size``: ``1e14``
 - ``stage3_max_live_parameters``: ``1e9``
 - ``stage3_max_reuse_distance``: ``1e9``
@@ -680,8 +688,11 @@ flexible.
 If you're migrating from ZeRO-2 configuration note that ``allgather_partitions``, ``allgather_bucket_size`` and
 ``reduce_scatter`` configuration parameters are not used in ZeRO-3. If you keep these in the config file they will just
-be ignored. Make sure to remove ``cpu_offload`` though, since it has been deprecated in ZeRO-3.
+be ignored.
+
+- ``sub_group_size``: ``1e9``
+  This one does impact GPU memory usage. But no docs at the moment on Deepspeed side to explain the tuning.
 
 .. _deepspeed-nvme:
@@ -725,7 +736,7 @@ The following configuration example enables NVMe to offload both optimizer state
         }
         "overlap_comm": true,
         "contiguous_gradients": true,
-        "sub_group_size": 1e14,
+        "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
         "stage3_prefetch_bucket_size": "auto",
         "stage3_param_persistence_threshold": "auto",
@@ -766,9 +777,9 @@ It's possible to adjust ZeRO-3 configuration to make it perform closer to ZeRO-2
 - set ``stage3_param_persistence_threshold`` to a very large number - larger than the largest parameter, e.g., ``6 *
   hidden_size * hidden_size``. This will keep the parameters on the GPUs.
-- turn off ``cpu_offload_params`` since ZeRO-2 doesn't have that option.
+- turn off ``offload_params`` since ZeRO-2 doesn't have that option.
 
-The performance will likely improve significantly with just ``cpu_offload_params`` turned off, even if you don't change
+The performance will likely improve significantly with just ``offload_params`` turned off, even if you don't change
 ``stage3_param_persistence_threshold``. Of course, these changes will impact the size of the model you can train. So
 these help you to trade scalability for speed depending on your needs.
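To make the bullets above concrete, here is a minimal sketch of a ZeRO-3 ``zero_optimization`` section tuned in that direction, written as a Python dict; ``hidden_size`` is a hypothetical placeholder for your model's actual hidden dimension, and simply omitting any offload section is how keeping everything on the GPU is expressed here::

    hidden_size = 1024  # hypothetical value - substitute your model's hidden dimension

    zero3_like_zero2 = {
        "zero_optimization": {
            "stage": 3,
            # larger than the largest parameter, so parameters stay resident on the GPUs
            "stage3_param_persistence_threshold": 6 * hidden_size * hidden_size,
            # no offload sections: parameters and optimizer states remain on the GPU
        }
    }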
@@ -814,13 +825,16 @@ Here is a full ZeRO-2 auto-configuration file ``ds_config_zero2.json``:
     "zero_optimization": {
         "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
         "allgather_partitions": true,
         "allgather_bucket_size": 2e8,
         "overlap_comm": true,
         "reduce_scatter": true,
         "reduce_bucket_size": 2e8,
-        "contiguous_gradients": true,
-        "cpu_offload": true
+        "contiguous_gradients": true
     },
     "gradient_accumulation_steps": "auto",
@@ -868,13 +882,16 @@ values look like, but we highly recommend using the one with multiple ``auto`` s
     "zero_optimization": {
         "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
         "allgather_partitions": true,
         "allgather_bucket_size": 2e8,
         "overlap_comm": true,
         "reduce_scatter": true,
         "reduce_bucket_size": 2e8,
-        "contiguous_gradients": true,
-        "cpu_offload": true
+        "contiguous_gradients": true
     },
     "steps_per_print": 2000,
@@ -934,7 +951,7 @@ Here is a full ZeRO-3 auto-configuration file ``ds_config_zero3.json``:
         },
         "overlap_comm": true,
         "contiguous_gradients": true,
-        "sub_group_size": 1e14,
+        "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
         "stage3_prefetch_bucket_size": "auto",
         "stage3_param_persistence_threshold": "auto",
@@ -997,7 +1014,7 @@ values look like, but we highly recommend using the one with multiple ``auto`` s
         },
         "overlap_comm": true,
         "contiguous_gradients": true,
-        "sub_group_size": 1e14,
+        "sub_group_size": 1e9,
         "reduce_bucket_size": 1e6,
         "stage3_prefetch_bucket_size": 0.94e6,
         "stage3_param_persistence_threshold": 1e4,
@@ -1014,8 +1031,8 @@ values look like, but we highly recommend using the one with multiple ``auto`` s
 Optimizer and Scheduler
 =======================================================================================================================
 
-As long as you don't enable ``cpu_offload`` you can mix and match DeepSpeed and HuggingFace schedulers and optimizers,
-with the exception of using the combination of HuggingFace scheduler and DeepSpeed optimizer:
+As long as you don't enable ``offload_optimizer`` you can mix and match DeepSpeed and HuggingFace schedulers and
+optimizers, with the exception of using the combination of HuggingFace scheduler and DeepSpeed optimizer:
 
 +--------------+--------------+--------------+
 | Combos       | HF Scheduler | DS Scheduler |
@@ -1025,7 +1042,7 @@ with the exception of using the combination of HuggingFace scheduler and DeepSpe
 | DS Optimizer | No           | Yes          |
 +--------------+--------------+--------------+
 
-If ``cpu_offload`` is enabled you must use both DeepSpeed scheduler and DeepSpeed optimizer.
+If ``offload_optimizer`` is enabled you must use both DeepSpeed scheduler and DeepSpeed optimizer.
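In practice the choice of optimizer comes down to whether the DeepSpeed config contains an ``optimizer`` section; the integration tests further down in this diff rely on exactly that by deleting the section to force the HF Trainer defaults. A minimal sketch along the same lines, assuming a config file named ``ds_config_zero2.json``::

    import json

    # use the HF Trainer optimizer together with the DeepSpeed scheduler
    # (an allowed combination per the table above)
    with open("ds_config_zero2.json") as f:
        ds_config = json.load(f)

    # no "optimizer" section -> the Trainer creates its default (HF) optimizer
    ds_config.pop("optimizer", None)

    # the resulting dict can then be passed to the Trainer via
    # TrainingArguments(..., deepspeed=ds_config), as the tests below do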
@@ -1546,8 +1563,8 @@ Troubleshooting
 If the ``deepspeed`` process gets killed at launch time without a traceback, that usually means that the program tried
 to allocate more CPU memory than your system has or your process is allowed to allocate and the OS kernel killed that
 process. This is because your configuration file most likely has either ``offload_optimizer`` or ``offload_param`` or
-both configured to offload to ``cpu`` (or under ZeRO-2 ``cpu_offload`` is enabled). If you have NVMe, experiment with
-offloading to NVMe if you're running under ZeRO-3.
+both configured to offload to ``cpu``. If you have NVMe, experiment with offloading to NVMe if you're running under
+ZeRO-3.
 
 Work is being done to enable estimating how much memory is needed for a specific model: `PR
 <https://github.com/microsoft/DeepSpeed/pull/965>`__.
@@ -76,9 +76,7 @@ class HfDeepSpeedConfig:
         # offload
         self.offload = False
         config_zero = config.get("zero_optimization", {})
-        if self.is_zero2():
-            self.offload = self.is_true(config_zero, "cpu_offload")
-        elif self.is_zero3():
+        if self.is_zero2() or self.is_zero3():
             offload_devices = ["cpu", "nvme"]
             if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
                 self.offload = True
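For anyone updating an old config by hand, the deprecated ZeRO-2 flag maps onto the new nested section roughly as follows; this is a sketch distilled from the changes in this commit, and the ``pin_memory`` value is simply what the updated examples use, not something the old flag implied::

    # before: deprecated ZeRO-2 style
    old_zero_optimization = {
        "stage": 2,
        "cpu_offload": True,
    }

    # after: what the docs, test configs and detection code in this commit switch to
    new_zero_optimization = {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",       # the updated tests use "none" to disable offload
            "pin_memory": True,
        },
    }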
@@ -29,13 +29,16 @@
     "zero_optimization": {
         "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
         "allgather_partitions": true,
         "allgather_bucket_size": 2e8,
         "overlap_comm": true,
         "reduce_scatter": true,
         "reduce_bucket_size": 2e8,
-        "contiguous_gradients": true,
-        "cpu_offload": true
+        "contiguous_gradients": true
     },
     "gradient_accumulation_steps": "auto",
@@ -39,7 +39,7 @@
         },
         "overlap_comm": true,
         "contiguous_gradients": true,
-        "sub_group_size": 1e14,
+        "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
         "stage3_prefetch_bucket_size": "auto",
         "stage3_param_persistence_threshold": "auto",
@@ -269,7 +269,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             ds_config_zero2_dict = self.get_config_dict(ZERO2)
             del ds_config_zero2_dict["optimizer"]  # force default HF Trainer optimizer
             del ds_config_zero2_dict["scheduler"]  # force default HF Trainer scheduler
-            ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = False
+            ds_config_zero2_dict["zero_optimization"]["offload_optimizer"]["device"] = "none"
             ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
             trainer = get_regression_trainer(a=a, local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
             trainer.train()
@@ -281,7 +281,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         with mockenv_context(**self.dist_env_1_gpu):
             ds_config_zero2_dict = self.get_config_dict(ZERO2)
             del ds_config_zero2_dict["optimizer"]  # force default HF Trainer optimizer
-            ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = False
+            ds_config_zero2_dict["zero_optimization"]["offload_optimizer"]["device"] = "none"
             ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
             trainer = get_regression_trainer(a=a, local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
             trainer.train()
@@ -293,7 +293,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         with mockenv_context(**self.dist_env_1_gpu):
             ds_config_zero2_dict = self.get_config_dict(ZERO2)
             del ds_config_zero2_dict["scheduler"]  # force default HF Trainer scheduler
-            ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = False
+            ds_config_zero2_dict["zero_optimization"]["offload_optimizer"]["device"] = "none"
             ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
             trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
             with self.assertRaises(Exception) as context:
@@ -326,9 +326,6 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         ds_config_dict = self.get_config_dict(stage)
         del ds_config_dict["optimizer"]  # force default HF Trainer optimizer
         # force cpu offload
-        if stage == "stage2":
-            ds_config_dict["zero_optimization"]["cpu_offload"] = True
-        elif stage == "stage3":
-            ds_config_dict["zero_optimization"]["offload_optimizer"]["device"] = "cpu"
+        ds_config_dict["zero_optimization"]["offload_optimizer"]["device"] = "cpu"
         with mockenv_context(**self.dist_env_1_gpu):
             trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_dict)