"examples/vscode:/vscode.git/clone" did not exist on "070df689e627d07f28c8087ec85a4299c73145d9"
Unverified commit 1c790c08, authored by Wenhao Chen and committed by GitHub
Browse files

[fix] remove unnecessary dp_size assert (#5351)

* fix: remove unnecessary assert

* test: add more 3d plugin tests

* fix: add warning
parent ffffc32d
import ctypes
import random
import warnings
from contextlib import contextmanager
from functools import partial
from types import MethodType
......@@ -1134,7 +1135,12 @@ class HybridParallelPlugin(PipelinePluginBase):
tp_process_group=self.tp_group,
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
if self.dp_size == 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you are not intended to use cpu_offload, please consider set zero_stage=0."
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer,
......
......@@ -118,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
@parameterize(
"test_args",
[
{
"batch_size": 8,
"num_steps": 4,
"tp": 2,
"pp": 2,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 4,
"zero": 1,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{
"batch_size": 8,
"num_steps": 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment