# Copyright (c) 2024 westlake-repl
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
# This file has been modified by Junyi Chen.
#
# Original file was released under MIT, with the full license text
# available at https://choosealicense.com/licenses/mit/.
#
# This modified file is released under the same license.
import torch
import numpy as np
from torch.utils.data._utils.collate import default_collate
import re

try:
    from torch._six import string_classes
except ImportError:  # torch._six was removed in recent PyTorch releases
    string_classes = str
import collections

np_str_obj_array_pattern = re.compile(r"[SaUO]")

default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)


def customize_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size."""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: customize_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        return batch


def seq_eval_collate(batch):
    r"""Collates evaluation samples of (history, item_seq, item_target, time_seq).

    Variable-length histories are flattened into (user_index, item_id) pairs;
    the remaining fields are stacked along the batch dimension.
    """
    item_seq = []
    item_target = []
    time_seq = []
    history_i = []
    for item in batch:
        history_i.append(item[0])
        item_seq.append(item[1])
        item_target.append(item[2])
        time_seq.append(item[3])

    history_u = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_i)])
    history_i = torch.cat(history_i)

    item_seq = torch.tensor(item_seq)  # [batch, len]
    item_target = torch.tensor(item_target)  # [batch]
    time_seq = torch.tensor(time_seq)  # [batch]

    positive_u = torch.arange(item_seq.shape[0])  # [batch]

    # return item_seq, None, positive_u, item_target
    return item_seq, time_seq, (history_u, history_i), positive_u, item_target


def customize_rmpad_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Unlike :func:`customize_collate`, mapping fields whose keys contain
    ``_input_ids``, ``_cu_input_lens`` or ``_position_ids`` are concatenated
    along dim 0 instead of being stacked.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        output = {}
        for key in elem:
            if any(['_input_ids' in key, '_cu_input_lens' in key, '_position_ids' in key]):
                output[key] = torch.concat([d[key] for d in batch], dim=0)
            else:
                output[key] = customize_collate([d[key] for d in batch])
        return output
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        return batch
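

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module):
    # the field names "item_input_ids", "item_cu_input_lens" and "labels" are
    # hypothetical and chosen solely to exercise both branches of the Mapping
    # path in customize_rmpad_collate.
    sample_a = {
        "item_input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
        "item_cu_input_lens": torch.tensor([3, 3]),
        "labels": torch.tensor([7, 8]),
    }
    sample_b = {
        "item_input_ids": torch.tensor([[9, 10, 11]]),
        "item_cu_input_lens": torch.tensor([3]),
        "labels": torch.tensor([12, 13]),
    }
    out = customize_rmpad_collate([sample_a, sample_b])
    # "*_input_ids" and "*_cu_input_lens" fields are concatenated along dim 0;
    # "labels" falls back to customize_collate and is stacked into [batch, ...].
    print({k: v.shape for k, v in out.items()})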