Unverified Commit 1c790c08 authored by Wenhao Chen's avatar Wenhao Chen Committed by GitHub
Browse files

[fix] remove unnecessary dp_size assert (#5351)

* fix: remove unnecessary assert

* test: add more 3d plugin tests

* fix: add warning
parent ffffc32d
import ctypes import ctypes
import random import random
import warnings
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from types import MethodType from types import MethodType
...@@ -1134,7 +1135,12 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -1134,7 +1135,12 @@ class HybridParallelPlugin(PipelinePluginBase):
tp_process_group=self.tp_group, tp_process_group=self.tp_group,
) )
else: else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." if self.dp_size == 1:
warnings.warn(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you are not intended to use cpu_offload, please consider set zero_stage=0."
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer( optimizer = HybridParallelZeroOptimizer(
optimizer, optimizer,
......
...@@ -118,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): ...@@ -118,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
@parameterize( @parameterize(
"test_args", "test_args",
[ [
{
"batch_size": 8,
"num_steps": 4,
"tp": 2,
"pp": 2,
"pp_style": "1f1b",
"num_model_chunks": 1,
"num_microbatches": 4,
"zero": 1,
"precision": "fp16",
"initial_scale": 1,
"max_length": 512,
"gradient_accumulation_step": 2,
},
{ {
"batch_size": 8, "batch_size": 8,
"num_steps": 4, "num_steps": 4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment