OpenDAS / ColossalAI · Commits

Commit 21aa5de0 (unverified)
Authored by flybird11111 on Dec 08, 2023 · Committed by GitHub on Dec 08, 2023
[gemini] hotfix NaN loss while using Gemini + tensor_parallel (#5150)

Squashed fixup commits: fix, fix, test ci, fix ci.
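In substance, the fix gives GeminiPlugin its own prepare_dataloader override: the DistributedSampler it builds shards data only across the ZeRO and extra data-parallel axes of the plugin's process-group mesh, rather than across the global world size, so tensor-parallel peers receive identical batches instead of disjoint shards. A minimal sketch of the sampler-rank arithmetic the override uses; the axis sizes below are illustrative assumptions, not values from the commit:

    # Sketch only: how (zero_rank, extra_dp_rank) flattens into the single
    # rank that torch.utils.data.DistributedSampler expects.
    zero_world_size = 2      # assumed size of ZERO_AXIS
    extra_dp_world_size = 2  # assumed size of DP_AXIS (extra data parallel)

    num_replicas = zero_world_size * extra_dp_world_size  # 4 data shards
    for zero_rank in range(zero_world_size):
        for extra_dp_rank in range(extra_dp_world_size):
            rank = zero_rank * extra_dp_world_size + extra_dp_rank
            print(f"zero={zero_rank} extra_dp={extra_dp_rank} -> sampler rank {rank} of {num_replicas}")
    # Tensor-parallel coordinates never enter this mapping, so all TP peers of a
    # given (zero, extra_dp) point land on the same sampler rank and batch.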
Parent: b3971044
Showing 3 changed files with 59 additions and 2 deletions (+59 −2)
colossalai/booster/plugin/gemini_plugin.py   +54 −0
examples/language/llama2/benchmark.py         +4 −1
tests/kit/model_zoo/transformers/gptj.py      +1 −1
colossalai/booster/plugin/gemini_plugin.py
 import gc
 import logging
 import os
+import random
 from pathlib import Path
 from typing import Callable, Iterator, List, Optional, Tuple

+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -11,6 +13,7 @@ from torch.distributed.distributed_c10d import _get_default_group
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler

 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
...
@@ -449,6 +452,57 @@ class GeminiPlugin(DPPluginBase):
     def supported_devices(self) -> List[str]:
         return ["cuda", "npu"]

+    def prepare_dataloader(
+        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+    ):
+        r"""
+        Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+
+        Args:
+            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+            seed (int, optional): Random worker seed for sampling, defaults to 1024.
+            add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
+            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+                is not divisible by the batch size. If False and the size of dataset is not divisible by
+                the batch size, then the last batch will be smaller, defaults to False.
+            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
+                    `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
+
+        Returns:
+            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+        """
+        _kwargs = kwargs.copy()
+        zero_world_size = self.pg_mesh.size(ZERO_AXIS)
+        extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
+        zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
+        extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
+        sampler = DistributedSampler(
+            dataset,
+            num_replicas=zero_world_size * extra_dp_world_size,
+            rank=zero_rank * extra_dp_world_size + extra_dp_rank,
+            shuffle=shuffle,
+        )
+
+        # Deterministic dataloader
+        def seed_worker(worker_id):
+            worker_seed = seed
+            np.random.seed(worker_seed)
+            torch.manual_seed(worker_seed)
+            random.seed(worker_seed)
+
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            worker_init_fn=seed_worker,
+            drop_last=drop_last,
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            **_kwargs,
+        )
+
     def configure(
         self,
         model: nn.Module,
...
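A hedged usage sketch of the new override. The toy dataset, the world-size layout, and the launch context are assumptions; GeminiPlugin's tp_size/extra_dp_size keywords and prepare_dataloader's signature come from this diff. The plugin builds its process-group mesh at construction time, so this only runs inside an initialized distributed job:

    import torch
    from torch.utils.data import Dataset

    import colossalai
    from colossalai.booster.plugin import GeminiPlugin

    class ToyDataset(Dataset):  # hypothetical dataset, for illustration only
        def __len__(self):
            return 1024
        def __getitem__(self, idx):
            return torch.tensor([float(idx)])

    colossalai.launch_from_torch(config={})  # assumes torchrun started 8 processes
    # Assumed layout: 8 ranks = tp_size 2 x extra_dp_size 2 x (derived) ZeRO 2.
    plugin = GeminiPlugin(tp_size=2, extra_dp_size=2)
    dataloader = plugin.prepare_dataloader(ToyDataset(), batch_size=8, shuffle=True)
    # seed_worker pins every loader worker to the same fixed seed, so shuffling
    # stays reproducible and identical across tensor-parallel peers.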
examples/language/llama2/benchmark.py
@@ -72,6 +72,7 @@ def main():
     parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
     parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
     parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
     parser.add_argument("--mbs", type=int, default=1)
     parser.add_argument("--zero", type=int, default=0)
...
@@ -93,9 +94,11 @@ def main():
             shard_param_frac=args.shard_param_frac,
             offload_optim_frac=args.offload_optim_frac,
             offload_param_frac=args.offload_param_frac,
+            tp_size=args.tp,
+            extra_dp_size=args.extra_dp,
         )
     elif args.plugin == "gemini_auto":
-        plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio)
+        plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp)
     elif args.plugin == "fsdp":
         if use_empty_init:
             plugin = TorchFSDPPlugin(
...
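For completeness, a hedged sketch of how the benchmark's new flags compose. Only --tp, --extra_dp, and the GeminiPlugin keywords come from this diff; the launcher line, world size, and the divisibility rule for deriving the ZeRO size are assumptions:

    # Assumed launch (not part of this diff):
    #   colossalai run --nproc_per_node 8 benchmark.py --plugin gemini --tp 2 --extra_dp 2
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
    args = parser.parse_args(["--tp", "2", "--extra_dp", "2"])  # example values

    world_size = 8  # assumed process count
    assert world_size % (args.tp * args.extra_dp) == 0, "tp * extra_dp should divide the world size"
    zero_size = world_size // (args.tp * args.extra_dp)  # presumably how the ZeRO axis is derived
    print(f"tp={args.tp} extra_dp={args.extra_dp} -> ZeRO axis of size {zero_size}")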
tests/kit/model_zoo/transformers/gptj.py
@@ -61,7 +61,7 @@ loss_fn = lambda x: x.loss
 config = transformers.GPTJConfig(
     n_layer=2,
-    n_head=16,
+    n_head=4,
     vocab_size=50258,
     attn_pdrop=0,
     embd_pdrop=0,
...
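The test-config change (n_head 16 to 4) has no stated rationale in the commit; a plausible reading is keeping the attention-head split compatible with the tensor-parallel degrees the CI exercises. A hedged divisibility check, where the tp sizes are assumptions and the config fields mirror the diff:

    import transformers

    config = transformers.GPTJConfig(n_layer=2, n_head=4, vocab_size=50258, attn_pdrop=0, embd_pdrop=0)
    for tp_size in (1, 2, 4):  # assumed tensor-parallel degrees used in tests
        # Splitting attention across TP ranks requires the head count to divide evenly.
        assert config.n_head % tp_size == 0, f"n_head={config.n_head} vs tp={tp_size}"
    assert config.n_embd % config.n_head == 0  # each head gets an integral slice of the hidden size
    print("tiny GPT-J config splits cleanly across the assumed TP sizes")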