import dataclasses
from typing import Dict, List, Union

import numpy as np
import torch

from nanotron import distributed as dist
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer


@dataclasses.dataclass
class NanosetDataCollatorForCLM:
    """
    Data collator used for causal language modeling with Nanosets dataset.

    - input_pp_rank: Discards last input id token
    - output_pp_rank: Discards first label id token
    - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
    """

    sequence_length: int
    input_pp_rank: int
    output_pp_rank: int
    parallel_context: ParallelContext

    def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
        """Collate samples into a shifted (inputs, labels) batch for the current pipeline-parallel rank.

        Args:
            examples: One dict per sample. On the input/output pipeline ranks each dict must contain exactly
                the key ``"input_ids"``, holding ``sequence_length + 1`` token ids (a ``torch.Tensor`` or anything
                ``torch.as_tensor`` accepts, such as a ``np.ndarray``); samples are stacked row-wise with
                ``torch.vstack``. On every other pipeline rank each dict must be empty.

        Returns:
            A dict with keys ``"input_ids"``, ``"input_mask"``, ``"label_ids"`` and ``"label_mask"``.
            On the input rank, ``input_ids`` (tokens ``[:-1]``) and an all-True bool ``input_mask`` of shape
            ``(batch, sequence_length)`` are tensors. On the output rank, ``label_ids`` (tokens ``[1:]``) and an
            all-True bool ``label_mask`` are tensors. Every entry not produced on this rank is a ``TensorPointer``
            to the rank that owns it.

        Raises:
            AssertionError: If a non-data rank receives non-empty examples, if a sample has keys other than
                ``"input_ids"``, or if samples are not ``sequence_length + 1`` tokens long.
            ValueError: If the produced inputs or labels do not have length ``sequence_length``.
        """
        # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
        current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
        if current_pp_rank not in [self.input_pp_rank, self.output_pp_rank]:
            assert all(len(example) == 0 for example in examples)
            return {
                "input_ids": TensorPointer(group_rank=self.input_pp_rank),
                "input_mask": TensorPointer(group_rank=self.input_pp_rank),
                "label_ids": TensorPointer(group_rank=self.output_pp_rank),
                "label_mask": TensorPointer(group_rank=self.output_pp_rank),
            }

        # Make sure we load only what's necessary, ie we only load a `input_ids` column.
        assert all(list(example.keys()) == ["input_ids"] for example in examples)

        # `torch.vstack` only accepts tensors; `torch.as_tensor` returns tensors unchanged and converts
        # `np.ndarray` samples (the type advertised in the signature) instead of failing with a TypeError.
        input_ids = torch.vstack([torch.as_tensor(example["input_ids"]) for example in examples])  # (b, s)
        batch_size, expanded_input_length = input_ids.shape

        # Default every entry to a pointer; the owning rank(s) overwrite their entries with tensors below.
        # When input_pp_rank == output_pp_rank, both pairs are replaced.
        result: Dict[str, Union[torch.Tensor, TensorPointer]] = {}
        result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
        result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
        result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
        result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)

        # One extra token per sample is needed so inputs and labels can both be `sequence_length` long.
        assert (
            expanded_input_length == self.sequence_length + 1
        ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"

        # Process inputs: last token is the label
        if current_pp_rank == self.input_pp_rank:
            result["input_ids"] = input_ids[:, :-1]
            result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

        # Process labels: shift them to the left
        if current_pp_rank == self.output_pp_rank:
            result["label_ids"] = input_ids[:, 1:]
            result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

        # Defensive checks that still run when the assert above is stripped by `python -O`.
        if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
            raise ValueError(
                f"`input_ids` are incorrectly preprocessed. `input_ids` length is {result['input_ids'].shape[-1]}, but should be"
                f" {self.sequence_length}."
            )
        if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
            raise ValueError(
                f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
                f" {self.sequence_length}."
            )

        return result