import numpy as np
import torch

# Label value ignored by torch.nn.CrossEntropyLoss by default.
IGNORE_INDEX = -100


def _pad_token_features(features, pad_id):
    """Right-pad ``input_ids``/``labels`` of every feature IN PLACE to the
    longest sequence in the batch, and derive ``attention_mask``.

    NOTE: mutates the feature dicts, matching the original collator behavior.
    Assumes each feature holds 1-D LongTensors — TODO confirm against caller.
    """
    max_item_length = max(feat['input_ids'].shape[0] for feat in features)
    for feat in features:
        padded_ids = torch.full((max_item_length,), pad_id, dtype=torch.long)
        padded_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
        feat['input_ids'] = padded_ids

        # Labels are padded with IGNORE_INDEX so the loss skips pad positions.
        padded_labels = torch.full((max_item_length,), IGNORE_INDEX, dtype=torch.long)
        padded_labels[:feat['labels'].shape[0]] = feat['labels']
        feat['labels'] = padded_labels

        # Mask is True wherever the token is not the pad id.
        feat['attention_mask'] = feat['input_ids'].ne(pad_id)


def _collect_labels(first, features, batch):
    """Special handling for classification-style 'label'/'label_ids' keys.

    Ensures the resulting ``batch['labels']`` tensor has the correct dtype
    (long for ints, float otherwise), inspecting the first feature only.
    """
    if 'label' in first and first['label'] is not None:
        label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
        dtype = torch.long if isinstance(label, int) else torch.float
        batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
    elif 'label_ids' in first and first['label_ids'] is not None:
        if isinstance(first['label_ids'], torch.Tensor):
            batch['labels'] = torch.stack([f['label_ids'] for f in features])
        else:
            dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
            batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)


def pad_data_collator(features, pad_id=0):
    """Collate a list of feature dicts into a batch, right-padding sequences.

    ``input_ids`` are padded with ``pad_id`` and ``labels`` with
    ``IGNORE_INDEX`` up to the longest sequence in the batch; every other
    non-string, non-None key is stacked along a new batch dimension.

    Args:
        features: list of dicts, each with 1-D ``input_ids``/``labels``
            LongTensors plus arbitrary extra keys.
        pad_id: token id used to pad ``input_ids`` (default 0).

    Returns:
        dict mapping key -> batched tensor (``attention_mask`` included).
    """
    first = features[0]
    batch = {}
    _pad_token_features(features, pad_id)
    _collect_labels(first, features, batch)

    # Handling of all other possible keys: use the first element to figure
    # out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            elif isinstance(v, np.ndarray):
                batch[k] = torch.tensor(np.stack([f[k] for f in features]))
            else:
                batch[k] = torch.tensor([f[k] for f in features])
    return batch


def concat_pad_data_collator(features, pad_id=0):
    """Like :func:`pad_data_collator`, but ``pixel_values`` and
    ``image_flags`` are CONCATENATED along dim 0 instead of stacked —
    samples may contribute a variable number of image tiles.

    Args:
        features: list of dicts, each with 1-D ``input_ids``/``labels``
            LongTensors plus arbitrary extra keys.
        pad_id: token id used to pad ``input_ids`` (default 0).

    Returns:
        dict mapping key -> batched tensor (``attention_mask`` included).
    """
    first = features[0]
    batch = {}
    _pad_token_features(features, pad_id)
    _collect_labels(first, features, batch)

    # Handling of all other possible keys: use the first element to figure
    # out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \
                v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            elif isinstance(v, np.ndarray):
                batch[k] = torch.tensor(np.stack([f[k] for f in features]))
            else:
                batch[k] = torch.tensor([f[k] for f in features])
        if k in ('pixel_values', 'image_flags'):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.cat([f[k] for f in features])
            elif isinstance(v, np.ndarray):
                # BUG FIX: original called torch.concat(np.stack([...])) —
                # torch.concat rejects numpy arrays, and np.stack requires
                # equal shapes, which defeats variable-length concatenation.
                batch[k] = torch.from_numpy(np.concatenate([f[k] for f in features]))
            else:
                batch[k] = torch.cat([f[k] for f in features])
    return batch