Unverified Commit 77950c48 authored by Sam Shleifer, committed by GitHub

[wip/s2s] DistributedSortishSampler (#7056)

parent 51448673
@@ -3,7 +3,6 @@ import glob
 import logging
 import os
 import time
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -67,6 +66,8 @@ class SummarizationModule(BaseTransformer):
     default_val_metric = "rouge2"
 
     def __init__(self, hparams, **kwargs):
+        if hparams.sortish_sampler and hparams.gpus > 1:
+            hparams.replace_sampler_ddp = False
         super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
         use_task_specific_params(self.model, "summarization")
         save_git_info(self.hparams.output_dir)
@@ -93,9 +94,6 @@ class SummarizationModule(BaseTransformer):
             "val": self.hparams.val_max_target_length,
             "test": self.hparams.test_max_target_length,
         }
-        if self.hparams.sortish_sampler and self.hparams.gpus > 1:
-            self.hparams.sortish_sampler = False
-            warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
         assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
         assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
@@ -257,8 +255,7 @@ class SummarizationModule(BaseTransformer):
         dataset = self.get_dataset(type_path)
         sampler = None
         if self.hparams.sortish_sampler and type_path == "train":
-            assert self.hparams.gpus <= 1  # this should never break because of the assertion in __init__
-            sampler = dataset.make_sortish_sampler(batch_size)
+            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
             shuffle = False
         dataloader = DataLoader(
...
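A note on the changes above: with PyTorch Lightning's default replace_sampler_ddp=True, any custom sampler returned by get_dataloader would be silently swapped for a plain DistributedSampler under multi-GPU training, discarding the sortish ordering; the new __init__ lines turn that off so the sampler survives. A minimal illustrative sketch follows (the explicit Trainer call is an assumption about how the flag is consumed downstream, not part of this diff):

# Illustrative only: Lightning keeps a user-provided sampler when the flag is off.
from pytorch_lightning import Trainer

trainer = Trainer(gpus=2, distributed_backend="ddp", replace_sampler_ddp=False)
# With replace_sampler_ddp=False, the DataLoader built in get_dataloader() keeps
# the (Distributed)SortishSampler instead of being re-wrapped by Lightning.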
@@ -149,9 +149,9 @@ class TestSummarizationDistiller(unittest.TestCase):
             no_teacher=True,
             freeze_encoder=True,
             gpus=2,
-            sortish_sampler=False,
+            sortish_sampler=True,
         )
-        self._test_distiller_cli(updates)
+        self._test_distiller_cli(updates, check_contents=False)
 
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
...
 import itertools
 import json
 import linecache
+import math
 import os
 import pickle
 from logging import getLogger
@@ -10,6 +11,7 @@ from typing import Callable, Dict, Iterable, List, Union
 import git
 import numpy as np
 import torch
+import torch.distributed as dist
 from rouge_score import rouge_scorer, scoring
 from sacrebleu import corpus_bleu
 from torch import nn
@@ -111,8 +113,11 @@ class AbstractSeq2SeqDataset(Dataset):
     def get_char_lens(data_file):
         return [len(x) for x in Path(data_file).open().readlines()]
 
-    def make_sortish_sampler(self, batch_size):
-        return SortishSampler(self.src_lens, batch_size)
+    def make_sortish_sampler(self, batch_size, distributed=False):
+        if distributed:
+            return DistributedSortishSampler(self, batch_size)
+        else:
+            return SortishSampler(self.src_lens, batch_size)
 
     def __getitem__(self, item):
         raise NotImplementedError("You must implement this")
@@ -191,24 +196,77 @@ class SortishSampler(Sampler):
     def __init__(self, data, batch_size):
         self.data, self.bs = data, batch_size
 
-    def key(self, i):
-        return self.data[i]
-
     def __len__(self) -> int:
         return len(self.data)
 
     def __iter__(self):
-        idxs = np.random.permutation(len(self.data))
-        sz = self.bs * 50
-        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
-        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
-        sz = self.bs
-        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
-        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
-        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
-        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
-        sort_idx = np.concatenate((ck_idx[0], sort_idx))
-        return iter(sort_idx)
+        return iter(sortish_sampler_indices(self.data, self.bs))
+
+
+def sortish_sampler_indices(data: List, bs: int) -> np.array:
+    "Go through the text data by order of src length with a bit of randomness. From fastai repo."
+
+    def key_fn(i):
+        return data[i]
+
+    idxs = np.random.permutation(len(data))
+    sz = bs * 50
+    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
+    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
+    sz = bs
+    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
+    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
+    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
+    sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
+    sort_idx = np.concatenate((ck_idx[0], sort_idx))
+    return sort_idx
+
+
+class DistributedSortishSampler(Sampler):
+    """Copied from torch DistributedSampler"""
+
+    def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        self.batch_size = batch_size
+
+    def __iter__(self) -> Iterable:
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        available_indices = self.get_indices_for_rank()  # indices[self.rank: self.total_size: self.num_replicas]
+        sortish_data = [self.dataset.src_lens[i] for i in available_indices]
+        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
+        indices = [available_indices[i] for i in sortish_indices]
+        assert len(indices) == self.num_samples
+        return iter(indices)
+
+    def get_indices_for_rank(self) -> np.array:
+        indices = list(range(len(self.dataset)))
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+        # subsample
+        available_indices = indices[self.rank : self.total_size : self.num_replicas]
+        return available_indices
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
 
 logger = getLogger(__name__)
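For intuition about the refactored sampler logic: sortish_sampler_indices returns a permutation in which each consecutive batch-size chunk points at examples of similar source length (so per-batch padding is small), the first batch holds the longest examples, and the chunk order is otherwise random. A small sketch, assuming the function above can be imported from the example's utils module (the import path is an assumption):

# Sketch only; the sizes are chosen so no ragged chunks occur.
import numpy as np
from utils import sortish_sampler_indices  # hypothetical import path

src_lens = list(np.random.randint(1, 512, size=256))  # toy source lengths
order = sortish_sampler_indices(src_lens, bs=8)
batches = [order[i : i + 8] for i in range(0, len(order), 8)]
# Lengths inside each batch are close together; the first batch holds the longest examples.
print([sorted(src_lens[i] for i in b) for b in batches[:3]])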
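Finally, a sketch of how DistributedSortishSampler would be wired up on each rank. The dataset object, its src_lens/collate_fn attributes, and the explicit process-group setup are assumptions for illustration, not part of this diff:

# Sketch only; assumes an initialised process group and a Seq2SeqDataset-like `dataset`
# exposing __len__, src_lens and collate_fn.
import torch.distributed as dist
from torch.utils.data import DataLoader

dist.init_process_group(backend="nccl")  # normally handled by the launcher / Lightning
sampler = DistributedSortishSampler(dataset, batch_size=8)  # rank and world size read from torch.distributed
loader = DataLoader(dataset, batch_size=8, sampler=sampler, collate_fn=dataset.collate_fn)

for epoch in range(3):
    sampler.set_epoch(epoch)  # kept for API parity with torch's DistributedSampler
    for batch in loader:
        ...  # each rank iterates over a disjoint, length-bucketed slice of the data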