"examples/vscode:/vscode.git/clone" did not exist on "ff3413be3267cf78643ee9d476bf532318cb8e1c"
Unverified Commit 6c0fa7b9, authored by Hongxin Liu, committed by GitHub

[llama] fix dataloader for hybrid parallel (#5358)

* [plugin] refactor prepare dataloader

* [plugin] update train script
parent 2dd01e3a
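In short: the project-local `setup_distributed_dataloader` helper is removed, and each booster plugin's `prepare_dataloader` gains a `distributed_sampler_cls` argument so training scripts can inject `StatefulDistributedSampler` while the plugin supplies the correct parallelism-aware ranks. A minimal sketch of the new call pattern; the plugin sizes, tokenizer path, and dataset path below are placeholders, not values from this commit:

```python
# Sketch of the new usage, assuming a HybridParallelPlugin; tp/pp sizes,
# paths, batch size, and max_length are illustrative placeholders.
from colossal_llama2.dataset.loader import (
    DataCollatorForSupervisedDataset,
    StatefulDistributedSampler,
    load_tokenized_dataset,
)
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import AutoTokenizer

plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
tokenizer = AutoTokenizer.from_pretrained("/path/to/tokenizer")
dataset = load_tokenized_dataset(dataset_paths=["/path/to/dataset"], mode="train")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=4096)

# The plugin now builds the sampler itself from its process-group mesh;
# passing the sampler class keeps resumable iteration.
dataloader = plugin.prepare_dataloader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    collate_fn=data_collator,
    distributed_sampler_cls=StatefulDistributedSampler,
)
```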
colossal_llama2/dataset/loader.py:

 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import numpy as np
 import os
-import random
 from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Dict, Iterator, List, Optional, Sequence, Union
+
 import torch
-from datasets import dataset_dict, load_from_disk
-from datasets import Dataset as HFDataset
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
+import torch.nn.functional as F
+from datasets import Dataset as HFDataset
+from datasets import dataset_dict, load_from_disk
+from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
 from transformers.tokenization_utils import PreTrainedTokenizer
-import torch.nn.functional as F

 DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 PathType = Union[str, os.PathLike]
@@ -171,49 +167,3 @@ class StatefulDistributedSampler(DistributedSampler):
     def set_start_index(self, start_index: int) -> None:
         self.start_index = start_index
-
-
-def setup_distributed_dataloader(
-    dataset: DatasetType,
-    batch_size: int = 1,
-    shuffle: bool = False,
-    seed: int = 1024,
-    drop_last: bool = False,
-    pin_memory: bool = False,
-    num_workers: int = 0,
-    collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
-    process_group: Optional[ProcessGroup] = None,
-    **kwargs,
-) -> DataLoader:
-    """
-    Setup dataloader for distributed training.
-    """
-    _kwargs = kwargs.copy()
-    process_group = process_group or _get_default_group()
-    sampler = StatefulDistributedSampler(
-        dataset=dataset,
-        num_replicas=process_group.size(),
-        rank=process_group.rank(),
-        shuffle=shuffle,
-        seed=seed,
-        drop_last=drop_last,
-    )
-
-    # Deterministic dataloader
-    def seed_worker(worker_id: int) -> None:
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-
-    return DataLoader(
-        dataset=dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        num_workers=num_workers,
-        collate_fn=collate_fn,
-        pin_memory=pin_memory,
-        drop_last=drop_last,
-        worker_init_fn=seed_worker,
-        **_kwargs,
-    )
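The removed helper's resumability came from `StatefulDistributedSampler`, which the commit keeps: `set_start_index` (retained context above) lets a resumed run skip samples a previous run already consumed. A tiny self-contained sketch, with a toy dataset and a hypothetical `consumed_samples` value restored from a checkpoint:

```python
# Toy resume flow using the sampler kept in loader.py. num_replicas/rank are
# passed explicitly so no process group is needed for this sketch.
from torch.utils.data import DataLoader

from colossal_llama2.dataset.loader import StatefulDistributedSampler

dataset = list(range(1024))   # stand-in for a tokenized dataset
consumed_samples = 256        # hypothetical counter from a checkpoint

sampler = StatefulDistributedSampler(dataset, num_replicas=8, rank=0, shuffle=True)
sampler.set_start_index(consumed_samples)  # skip what the previous run saw
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
```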
First training script:

@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -194,12 +193,13 @@ def main() -> None:
     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
Second training script (the same change at a different location):

@@ -16,7 +16,6 @@ from colossal_llama2.dataset.loader import (
     DataCollatorForSupervisedDataset,
     StatefulDistributedSampler,
     load_tokenized_dataset,
-    setup_distributed_dataloader,
 )
 from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
@@ -203,12 +202,13 @@ def main() -> None:
     dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-    dataloader = setup_distributed_dataloader(
+    dataloader = plugin.prepare_dataloader(
         dataset=dataset,
         batch_size=args.micro_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
+        distributed_sampler_cls=StatefulDistributedSampler,
     )
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
Plugin base class (DPPluginBase):

@@ -21,7 +21,16 @@ class DPPluginBase(Plugin):
         self.world_size = dist.get_world_size()

     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -45,7 +54,8 @@ class DPPluginBase(Plugin):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)

         # Deterministic dataloader
         def seed_worker(worker_id):
GeminiPlugin:

@@ -456,7 +456,16 @@ class GeminiPlugin(DPPluginBase):
         return ["cuda", "npu"]

     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -484,7 +493,8 @@ class GeminiPlugin(DPPluginBase):
         extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
         zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
         extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset,
             num_replicas=zero_world_size * extra_dp_world_size,
             rank=zero_rank * extra_dp_world_size + extra_dp_rank,
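The Gemini hunk flattens the ZeRO and extra data-parallel axes into one sampler rank via `zero_rank * extra_dp_world_size + extra_dp_rank`. A quick check with assumed sizes (2 ZeRO groups × 2 extra-DP groups) confirms every coordinate maps to a distinct rank in `0..3`:

```python
# Verify the flattened-rank formula covers 0..3 exactly once for a 2x2 mesh.
zero_world_size, extra_dp_world_size = 2, 2
ranks = [
    zero_rank * extra_dp_world_size + extra_dp_rank
    for zero_rank in range(zero_world_size)
    for extra_dp_rank in range(extra_dp_world_size)
]
assert sorted(ranks) == list(range(zero_world_size * extra_dp_world_size))
print(ranks)  # [0, 1, 2, 3]
```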
HybridParallelPlugin:

@@ -1205,7 +1205,16 @@ class HybridParallelPlugin(PipelinePluginBase):
         return outputs

     def prepare_dataloader(
-        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+        self,
+        dataset,
+        batch_size,
+        shuffle=False,
+        seed=1024,
+        drop_last=False,
+        pin_memory=False,
+        num_workers=0,
+        distributed_sampler_cls=None,
+        **kwargs,
     ):
         r"""
         Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -1229,7 +1238,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
-        sampler = DistributedSampler(
+        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+        sampler = distributed_sampler_cls(
             dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
         )
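Note that the hybrid plugin shards data only along its data-parallel axis: ranks that share a DP coordinate (i.e. differ only in tensor- or pipeline-parallel position) must read identical batches, so `num_replicas` is the DP size rather than the world size. A toy illustration with an assumed 2×2×2 layout and an assumed rank ordering (the real plugin derives coordinates from its `ProcessGroupMesh`):

```python
# Toy 8-GPU layout, dp=2, tp=2, pp=2 (assumed ordering, illustration only):
# only two distinct data shards exist, shared across tp/pp ranks.
dp, tp, pp = 2, 2, 2
for global_rank in range(dp * tp * pp):
    dp_rank = global_rank // (tp * pp)  # which data shard this rank reads
    print(f"global rank {global_rank} -> data shard {dp_rank}")
```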