# dataset.py (allamo) — revisions v1.0 / v1.0.3, committed by chenzk
import gc
import glob
import joblib
import numpy as np
import os
import torch
import torch.nn.functional as F
from allamo.configuration import AllamoConfiguration
from allamo.logging import logger

class AllamoDataset:
    """ In-Memory map-style dataset

    Holds the samples of a single dataset file at a time in ``self.data``,
    already restricted to this process's rank. ``load_next_dataset`` advances
    to the next file that has not been processed yet.
    """

    def __init__(self, config: AllamoConfiguration, train_split=True, rank=None, world_size=None):
        """
        Args:
            config: training configuration (data paths, block size, padding/ignore ids, ...).
            train_split: select training files when True, validation files otherwise.
            rank: index of this process; ``None`` is treated as 0 (single process).
            world_size: number of processes; ``None`` is treated as 1 (single process).

        Raises:
            Exception: if ``train_split`` is True and no training files are found.
        """
        # None used to crash later in `world_size * sample_size` / `world_size > 1`
        self.rank = 0 if rank is None else rank
        self.world_size = 1 if world_size is None else world_size
        self.data_dir = config.data_dir
        self.block_size = config.block_size
        # one extra token so that a sample yields block_size inputs plus shifted targets
        self.sample_size = config.block_size + 1
        self.ignore_index = config.ignore_index
        self.pad_token_id = config.pad_token_id
        self.weighted_loss = config.weighted_loss
        self.training_type = config.training_type
        self.data = None
        self.data_in_alm_format = False
        self.dataset_files = self.get_dataset_files(config, train_split)
        self.processed_files = []
        # The resume counter only tracks *training* files; applying it to the
        # validation split would silently skip validation files.
        if train_split and config.dataset_train_processed_files_count > 0:
            self.processed_files = self.dataset_files[:config.dataset_train_processed_files_count]

    @staticmethod
    def _parse_file_list(files):
        """Split a comma-separated file list, stripping whitespace and dropping empty entries."""
        return [name.strip() for name in files.split(',') if name.strip()]

    def get_dataset_files(self, config, train_split):
        """
        Resolve the sorted list of dataset files for the requested split.

        Explicit comma-separated lists (``dataset_train_files`` /
        ``dataset_validation_files``) take precedence. Otherwise, files inside
        ``data_dir/dataset`` with a supported extension whose name starts with the
        split's prefix are used.

        Returns:
            Sorted list of file paths (possibly empty for the validation split).

        Raises:
            Exception: if no training files are found.
        """
        dataset_files = []
        if train_split and config.dataset_train_files:
            dataset_files = self._parse_file_list(config.dataset_train_files)
        elif not train_split and config.dataset_validation_files:
            dataset_files = self._parse_file_list(config.dataset_validation_files)
        elif config.dataset:
            dataset_dir = os.path.join(config.data_dir, config.dataset)
            prefix = config.dataset_train_file_prefix if train_split else config.dataset_validation_file_prefix
            for dataset_file in glob.glob(os.path.join(dataset_dir, "*.*")):
                if self.is_file_type_supported(dataset_file) and os.path.basename(dataset_file).startswith(prefix):
                    dataset_files.append(dataset_file)
            logger.info(f"Found {len(dataset_files)} files in {dataset_dir} with prefix '{prefix}'")
        if dataset_files:
            return sorted(dataset_files)
        elif train_split:
            raise Exception('Training dataset files not found!')
        else:
            return []

    def is_file_type_supported(self, dataset_file):
        """Return True for the extensions this loader understands (.bin, .pt, .alm)."""
        return dataset_file.endswith(('.bin', '.pt', '.alm'))

    def load_next_dataset(self):
        """
        Drop the current data and load the next not-yet-processed file.

        Returns:
            True if a file with usable data was loaded, False when none is left.
        """
        self.data = None
        gc.collect()
        for ds_file in self.dataset_files:
            if ds_file not in self.processed_files:
                if self.load_dataset_file(ds_file):
                    return True
        return False

    def _load_bin_tokens(self, path):
        """Read a raw uint16 token file into a signed tensor without overflowing large ids."""
        tokens = np.fromfile(path, dtype=np.uint16)
        # int16 halves memory compared to int32, but it cannot represent ids >= 32768
        # (they would wrap to negative values), so widen only when needed.
        if tokens.size and tokens.max() > np.iinfo(np.int16).max:
            return torch.from_numpy(tokens.astype(np.int32))
        return torch.from_numpy(tokens.astype(np.int16))

    def _split_continuous_tokens(self, tokens, load_dataset_file):
        """Cut a flat token tensor into this rank's samples; return None if the file is too small."""
        step_size = self.world_size * self.sample_size
        if step_size > len(tokens):
            logger.warning(
                f"Dataset file {load_dataset_file} does not have enough data and will be ignored. "
                f"Expected at least {step_size} tokens but found only {len(tokens)}"
            )
            return None
        samples = self.align_and_transform_continuous_data_to_samples(tokens, step_size)
        return self.limit_samples_to_rank(samples)

    def load_dataset_file(self, load_dataset_file):
        """
        Load one dataset file into ``self.data``.

        The file is recorded as processed *before* loading. A file that turns out
        to be too small or in an unsupported format is therefore not retried.

        Returns:
            True if the file produced data for this rank, False otherwise.
        """
        self.processed_files.append(load_dataset_file)
        new_data = None
        if load_dataset_file.endswith('.bin'):
            assert self.training_type == 'pre', 'NumPy format is supported only for pre-training'
            new_data = self._split_continuous_tokens(self._load_bin_tokens(load_dataset_file), load_dataset_file)
        elif load_dataset_file.endswith('.pt'):
            assert self.training_type != 'dpo', 'DPO training only supports the ALM format'
            new_data = torch.load(load_dataset_file, map_location='cpu', weights_only=True)
            if isinstance(new_data, torch.Tensor):
                new_data = self._split_continuous_tokens(new_data, load_dataset_file)
            else:
                new_data = self.align_and_limit_to_rank(new_data, load_dataset_file)
                if new_data:
                    self.pad_or_truncate_to_block_size(new_data)
        elif load_dataset_file.endswith('.alm'):
            # NOTE(review): joblib.load unpickles arbitrary objects - only load trusted files.
            new_data = joblib.load(load_dataset_file)
            new_data = self.align_and_limit_to_rank(new_data, load_dataset_file)

        if not new_data:
            return False
        self.data = new_data
        self.data_in_alm_format = load_dataset_file.endswith('.alm')
        logger.info(f"New dataset file {load_dataset_file} loaded. Processed files: {len(self.processed_files)}")
        gc.collect()
        return True

    def align_and_limit_to_rank(self, new_data, load_dataset_file):
        """
        Pad a list of samples to a multiple of world_size and keep this rank's share.

        Returns:
            The rank's samples, or None if the data is not a list or has fewer
            samples than world_size.
        """
        if not isinstance(new_data, list):
            logger.warning(f"Unsupported format of {load_dataset_file}!")
            return None
        if self.world_size > len(new_data):
            logger.warning(
                f"Dataset file {load_dataset_file} does not have enough data and will be ignored. "
                f"Expected at least {self.world_size} samples but found only {len(new_data)}"
            )
            return None
        new_data = self.align_data_to_step_size(new_data, self.world_size)
        return self.limit_samples_to_rank(new_data)

    def align_data_to_step_size(self, data, step_size):
        """
        Grow ``data`` to a multiple of ``step_size`` by re-appending its leading items.

        Lists are extended in place; tensors are concatenated into a new tensor.
        """
        target_length = ((len(data) + step_size - 1) // step_size) * step_size
        padding_length = target_length - len(data)
        if padding_length > 0:
            pre_size = len(data)
            if isinstance(data, list):
                data.extend(data[:padding_length])
            else:
                # FIXME: this operation is highly inefficient - it duplicates data in memory
                data = torch.concat((data, data[:padding_length]))
            logger.info(f"Data aligned. Pre-alignment size: {pre_size}, "
                        f"post-alignment size: {len(data)}, "
                        f"padding added: {padding_length}")
        return data

    def align_and_transform_continuous_data_to_samples(self, data, step_size):
        """
        Cut a flat token tensor into ``sample_size`` chunks whose count is a multiple of world_size.

        When the length is not a multiple of ``step_size``, the last partial step
        is completed with tokens taken from the start of ``data``.
        """
        target_length = ((len(data) + step_size - 1) // step_size) * step_size
        padding_length = target_length - len(data)
        if padding_length > 0:
            pre_size = len(data)
            result = [data[i:i + self.sample_size] for i in range(0, (target_length - step_size), self.sample_size)]
            # the final step: leftover tail + wrap-around tokens from the beginning
            data = torch.concat((data[(target_length - step_size):], data[:padding_length]))
            result.extend([data[i:i + self.sample_size] for i in range(0, len(data), self.sample_size)])
            logger.info(f"Continuous data aligned and transformed to {len(result)} samples. "
                        f"Pre-alignment size: {pre_size}, "
                        f"post-alignment size: {target_length}, "
                        f"padding added: {padding_length}")
            return result
        else:
            return [data[i:i + self.sample_size] for i in range(0, len(data), self.sample_size)]

    def _fit_to_block_size(self, tensor, fill_value):
        """Truncate ``tensor`` to block_size, or right-pad it with ``fill_value`` when padding is enabled."""
        limit = self.sample_size - 1  # == block_size
        if len(tensor) > limit:
            return tensor[:limit]
        if self.pad_token_id >= 0 and len(tensor) < limit:
            return torch.cat([tensor, torch.full((limit - len(tensor),), fill_value)], dim=0)
        return tensor

    def _normalize_dict_sample(self, sample):
        """Convert a dict sample's fields to tensors and fit them to block_size (in place)."""
        if 'input_ids' not in sample:
            raise Exception(f"'input_ids' field not found in sample! Available keys: {', '.join(sample.keys())}")
        if isinstance(sample['input_ids'], np.ndarray):
            sample['input_ids'] = torch.from_numpy(sample['input_ids'])
        if 'target_ids' not in sample:
            # Derive next-token targets and drop the last input token so that
            # input_ids[i] is always paired with target_ids[i] == token i+1.
            sample['target_ids'] = sample['input_ids'][1:]
            sample['input_ids'] = sample['input_ids'][:-1]
        elif isinstance(sample['target_ids'], np.ndarray):
            sample['target_ids'] = torch.from_numpy(sample['target_ids'])

        if self.weighted_loss:
            if 'target_weights' not in sample:
                sample['target_weights'] = torch.where(sample['target_ids'] == self.ignore_index, 0, 1)
            elif isinstance(sample['target_weights'], np.ndarray):
                sample['target_weights'] = torch.from_numpy(sample['target_weights'])
        elif 'target_weights' in sample:
            del sample['target_weights']

        # inputs are padded with the pad token (a valid embedding id), targets with
        # ignore_index and weights with 0 so padded positions never contribute to the loss
        sample['input_ids'] = self._fit_to_block_size(sample['input_ids'], self.pad_token_id)
        sample['target_ids'] = self._fit_to_block_size(sample['target_ids'], self.ignore_index)
        if self.weighted_loss:
            sample['target_weights'] = self._fit_to_block_size(sample['target_weights'], 0)

        assert len(sample['input_ids']) == len(sample['target_ids'])
        if self.weighted_loss:
            assert len(sample['input_ids']) == len(sample['target_weights'])

    def _fit_raw_sample(self, sample):
        """
        Fit a raw token tensor to sample_size.

        With padding enabled, the result is split into an input/target/weights dict.
        Otherwise the (possibly truncated) tensor is returned.
        """
        if len(sample) > self.sample_size:
            sample = sample[:self.sample_size]
        if self.pad_token_id < 0:
            return sample
        if len(sample) < self.sample_size:
            padding = self.sample_size - len(sample)
            sample = torch.cat([sample, torch.full((padding,), self.ignore_index)], dim=0)
        input_ids = sample[:-1]
        target_ids = sample[1:]
        target_weights = torch.where(target_ids == self.ignore_index, 0, 1)
        # padding was written as ignore_index; inputs must carry the pad token instead
        input_ids = input_ids.masked_fill(input_ids == self.ignore_index, self.pad_token_id)
        return {'input_ids': input_ids, 'target_ids': target_ids, 'target_weights': target_weights}

    def pad_or_truncate_to_block_size(self, data):
        """
        Adds padding to instructions to maintain a consistent input shape, avoiding recompilations.
        This method ensures all instructions have a uniform length matching the block size.
        By doing so, it prevents the need for frequent recompilations that occur due to
        dynamic input shapes, enhancing computational efficiency and stability.

        ``data`` is modified in place. Dict samples are normalized field by field.
        Raw token tensors are turned into input/target/weights dicts when padding is enabled.
        """
        for idx, sample in enumerate(data):
            if isinstance(sample, dict):
                self._normalize_dict_sample(sample)
            else:
                data[idx] = self._fit_raw_sample(sample)

    def limit_samples_to_rank(self, samples):
        """Return every world_size-th sample starting at this rank (all samples for one process)."""
        return list(samples[self.rank::self.world_size]) if self.world_size > 1 else samples

    def has_data(self):
        """Return True when the currently loaded file holds at least one sample."""
        return bool(self.data)

    def _pad_to_block_size(self, tensor, value):
        """Right-pad a 1-D tensor to block_size with ``value``; longer tensors are returned unchanged."""
        if len(tensor) < self.block_size:
            return F.pad(tensor, (0, self.block_size - len(tensor)), value=value)
        return tensor

    def prepare_alm_dpo_sample(self, sample):
        """
        Convert a DPO ALM sample (NumPy arrays) into tensors, padding to block_size when enabled.

        Reference log-probabilities are copied only when both chosen and rejected values exist.
        """
        result = {
            'chosen_input_ids': torch.from_numpy(sample['chosen_input_ids']),
            'chosen_target_ids': torch.from_numpy(sample['chosen_target_ids']),
            'rejected_input_ids': torch.from_numpy(sample['rejected_input_ids']),
            'rejected_target_ids': torch.from_numpy(sample['rejected_target_ids'])
        }
        if "reference_chosen_logps" in sample and "reference_rejected_logps" in sample:
            result["reference_chosen_logps"] = torch.tensor(sample['reference_chosen_logps'])
            result["reference_rejected_logps"] = torch.tensor(sample['reference_rejected_logps'])

        if self.pad_token_id >= 0:
            for prefix in ('chosen', 'rejected'):
                result[f'{prefix}_input_ids'] = self._pad_to_block_size(result[f'{prefix}_input_ids'], self.pad_token_id)
                result[f'{prefix}_target_ids'] = self._pad_to_block_size(result[f'{prefix}_target_ids'], self.ignore_index)

        return result

    def prepare_alm_sample(self, sample):
        """
        Assumes input sample contains at least 'input_ids' and 'target_ids' fields.
        When the weighted loss is active, 'target_weights' field is required.
        When samples are packed, it is assumed that a list of sequence lengths will be available
        in the "seq_lens" field. This information will be used to create the attention mask.
        If pad_token_id is set in the configuration, it is assumed that the sample list
        did not have padding and samples are of length up to block_size.
        """
        if self.training_type == 'dpo':
            return self.prepare_alm_dpo_sample(sample)

        result = {
            'input_ids': torch.from_numpy(sample['input_ids']),
            'target_ids': torch.from_numpy(sample['target_ids'])
        }
        if self.weighted_loss:
            if 'target_weights' in sample:
                result['target_weights'] = torch.from_numpy(sample['target_weights'])
            else:
                result['target_weights'] = torch.where(result['target_ids'] == self.ignore_index, 0, 1)

        if self.pad_token_id >= 0:
            result['input_ids'] = self._pad_to_block_size(result['input_ids'], self.pad_token_id)
            result['target_ids'] = self._pad_to_block_size(result['target_ids'], self.ignore_index)
            if 'target_weights' in result:
                result['target_weights'] = self._pad_to_block_size(result['target_weights'], 0)

        if "seq_lens" in sample:
            total_seq_len = 0
            block_attn_masks = []
            sample_input_pos = []
            for seq_len in sample["seq_lens"]:
                # positions restart at 0 for every packed sequence
                sample_input_pos.extend(range(seq_len))
                total_seq_len += seq_len
                # lower triangular matrix = causal mask within this sequence only
                block_attn_masks.append(torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)))

            if total_seq_len < len(result['input_ids']):
                # trailing padding continues the last sequence's positions and attends only to itself;
                # an empty seq_lens list means the whole sample is padding, starting at position 0
                new_pos = sample_input_pos[-1] + 1 if sample_input_pos else 0
                num_pad = len(result['input_ids']) - total_seq_len
                sample_input_pos.extend(range(new_pos, new_pos + num_pad))
                block_attn_masks.append(torch.eye(num_pad, num_pad, dtype=torch.bool))
            result['input_pos'] = torch.tensor(sample_input_pos)
            result['attn_mask'] = torch.block_diag(*block_attn_masks)
        return result

    def __len__(self):
        """ Size of currently loaded dataset file """
        return len(self.data) if self.data else 0

    def __getitem__(self, idx):
        """
        Return one sample (int index) or a list of samples (slice).

        ALM samples are converted to tensors lazily here. An out-of-range
        integer index returns None.
        """
        result = None
        if isinstance(idx, slice):
            result = self.data[idx]
            if self.data_in_alm_format:
                result = [self.prepare_alm_sample(s) for s in result]
        elif idx < len(self):
            result = self.data[idx]
            if self.data_in_alm_format:
                result = self.prepare_alm_sample(result)
        return result