Unverified Commit 21aa5de0 authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

[gemini] hotfix NaN loss while using Gemini + tensor_parallel (#5150)

* fix

aaa

fix

fix

fix

* fix

* fix

* test ci

* fix ci

fix
parent b3971044
import gc import gc
import logging import logging
import os import os
import random
from pathlib import Path from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple from typing import Callable, Iterator, List, Optional, Tuple
import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
...@@ -11,6 +13,7 @@ from torch.distributed.distributed_c10d import _get_default_group ...@@ -11,6 +13,7 @@ from torch.distributed.distributed_c10d import _get_default_group
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import ( from colossalai.checkpoint_io.utils import (
...@@ -448,6 +451,57 @@ class GeminiPlugin(DPPluginBase): ...@@ -448,6 +451,57 @@ class GeminiPlugin(DPPluginBase):
def supported_devices(self) -> List[str]: def supported_devices(self) -> List[str]:
return ["cuda", "npu"] return ["cuda", "npu"]
def prepare_dataloader(
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling, defaults to 1024.
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
zero_world_size = self.pg_mesh.size(ZERO_AXIS)
extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
sampler = DistributedSampler(
dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle
)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs,
)
def configure( def configure(
self, self,
......
...@@ -72,6 +72,7 @@ def main(): ...@@ -72,6 +72,7 @@ def main():
parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("--mbs", type=int, default=1) parser.add_argument("--mbs", type=int, default=1)
parser.add_argument("--zero", type=int, default=0) parser.add_argument("--zero", type=int, default=0)
...@@ -93,9 +94,11 @@ def main(): ...@@ -93,9 +94,11 @@ def main():
shard_param_frac=args.shard_param_frac, shard_param_frac=args.shard_param_frac,
offload_optim_frac=args.offload_optim_frac, offload_optim_frac=args.offload_optim_frac,
offload_param_frac=args.offload_param_frac, offload_param_frac=args.offload_param_frac,
tp_size=args.tp,
extra_dp_size=args.extra_dp,
) )
elif args.plugin == "gemini_auto": elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio) plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp)
elif args.plugin == "fsdp": elif args.plugin == "fsdp":
if use_empty_init: if use_empty_init:
plugin = TorchFSDPPlugin( plugin = TorchFSDPPlugin(
......
...@@ -61,7 +61,7 @@ loss_fn = lambda x: x.loss ...@@ -61,7 +61,7 @@ loss_fn = lambda x: x.loss
config = transformers.GPTJConfig( config = transformers.GPTJConfig(
n_layer=2, n_layer=2,
n_head=16, n_head=4,
vocab_size=50258, vocab_size=50258,
attn_pdrop=0, attn_pdrop=0,
embd_pdrop=0, embd_pdrop=0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment