"""Merge the model parallel partitions of a checkpoint into a single model."""

import os
import torch

from arguments import get_args
from megatron import mpu
from megatron.utils import ensure_directory_exists
from megatron.utils import get_checkpoint_name
from megatron.utils import get_checkpoint_tracker_filename
from megatron.utils import vocab_size_with_padding


def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    """Split ``tensor`` into ``num_partitions`` strided partitions.

    The tensor is cut along ``partition_dim`` into chunks whose size is
    ``tensor.size(partition_dim)`` divided by ``num_partitions`` and then by
    ``stride`` (both divisions go through ``mpu.utils.divide``).  Partition
    ``i`` is the concatenation of chunks ``i, i + num_partitions, ...``.

    Args:
        tensor: the full tensor to split.
        num_partitions: number of partitions to produce.
        partition_dim: dimension along which the tensor is split.
        stride: number of chunks each partition is assembled from.

    Returns:
        A list with ``num_partitions`` tensors.
    """
    per_partition_size = mpu.utils.divide(tensor.size(partition_dim),
                                          num_partitions)
    per_partition_per_stride_size = mpu.utils.divide(per_partition_size,
                                                     stride)
    partitions_list = torch.split(tensor,
                                  per_partition_per_stride_size,
                                  dim=partition_dim)

    # Every num_partitions-th chunk, starting at i, belongs to partition i.
    return [torch.cat(partitions_list[i::num_partitions], dim=partition_dim)
            for i in range(num_partitions)]


def merge_partitions(merged, partitions, partition_dim, stride):
    """Merge ``partitions`` into ``merged`` in place.

    With ``stride == 1`` the partitions are concatenated along
    ``partition_dim``.  Otherwise every partition is first split into chunks
    of ``per_partition_size / stride`` and the chunks of all partitions are
    interleaved before concatenation (the reverse of the grouping done in
    ``split_into_partitions``).

    If the concatenated size is different from ``merged.size(partition_dim)``
    a warning is printed and only the leading ``merged.size(partition_dim)``
    entries along ``partition_dim`` are copied into ``merged``.

    Args:
        merged: destination tensor; overwritten in place.
        partitions: non-empty sequence of tensors with identical size along
            ``partition_dim``.
        partition_dim: dimension along which the partitions are merged.
        stride: number of chunks each partition is made of.

    Raises:
        ValueError: if ``partitions`` is empty.
        AssertionError: if the partitions differ in size along
            ``partition_dim``, or the trimmed result does not match
            ``merged`` along ``partition_dim``.
    """
    # Number and size of each partition.
    num_partitions = len(partitions)
    if num_partitions == 0:
        # Without this guard the size stays None and the failure shows up
        # later as an opaque TypeError.
        raise ValueError('merge_partitions needs at least one partition')
    per_partition_size = partitions[0].size(partition_dim)
    for partition in partitions[1:]:
        assert per_partition_size == partition.size(partition_dim)

    def concat_partitions(partitions_):
        """Concatenate ``partitions_`` along ``partition_dim`` into ``merged``."""
        total_size = per_partition_size * num_partitions
        target_size = merged.size(partition_dim)
        with torch.no_grad():
            if total_size == target_size:
                torch.cat(partitions_, dim=partition_dim, out=merged)
                return
            print(' ***WARNING*** sizes do not match. Will cut '
                  'the merged partitions by {} along dimension {} '
                  'to reduce the size from {} to {} ...'.format(
                      total_size - target_size, partition_dim,
                      total_size, target_size))
            merged_ = torch.cat(partitions_, dim=partition_dim)
            # Keep only the first piece of length target_size.
            merged_ = torch.split(merged_, target_size, dim=partition_dim)[0]
            assert merged_.size(partition_dim) == merged.size(partition_dim)
            merged.data.copy_(merged_.data)

    # If stride is 1, then do simple concatenation.
    if stride == 1:
        concat_partitions(partitions)
        return

    # For non-unity strides, first split based on stride and then group.
    per_partition_per_stride_size = mpu.utils.divide(per_partition_size,
                                                     stride)
    # Chunk and build a list in which the chunks of all ranks are interleaved.
    chunks = None
    for rank, partition in enumerate(partitions):
        rank_chunks = torch.split(partition,
                                  per_partition_per_stride_size,
                                  dim=partition_dim)
        if chunks is None:
            chunks = [None] * (num_partitions * len(rank_chunks))
        chunks[rank::num_partitions] = rank_chunks

    # Concatenate.
    concat_partitions(chunks)


def get_model(model_type, args):
    """Build a half precision model of the requested type.

    ``args.vocab_size`` is replaced by the value returned from
    ``vocab_size_with_padding`` while the model is built and is restored
    afterwards, also when building the model fails.

    Args:
        model_type: ``'BERT'`` or ``'GPT'``.
        args: argument namespace handed to ``model_provider``; for
            ``'BERT'`` its ``tokentype_size`` is set to 2.

    Returns:
        The model returned by ``model_provider(args)`` after ``.half()``.

    Raises:
        ValueError: if ``model_type`` is not recognized.
    """
    if model_type == 'BERT':
        from pretrain_albert import model_provider
        args.tokentype_size = 2
    elif model_type == 'GPT':
        from pretrain_gpt2 import model_provider
    else:
        raise ValueError('unrecognized model type: {}'.format(model_type))

    orig_vocab_size = args.vocab_size
    args.vocab_size = vocab_size_with_padding(orig_vocab_size, args)
    try:
        model = model_provider(args)
        model = model.half()
    finally:
        # Never leak the padded vocabulary size to the caller.
        args.vocab_size = orig_vocab_size

    return model


def get_parallel_checkpoint_name(path):
    """Return the checkpoint name and iteration recorded in the tracker file.

    Args:
        path: checkpoint directory passed to
            ``get_checkpoint_tracker_filename`` and ``get_checkpoint_name``.

    Returns:
        A ``(checkpoint_name, iteration)`` tuple.

    Raises:
        ValueError: if the tracker file content is not an integer.
        AssertionError: if the recorded iteration is not positive.
    """
    tracker_filename = get_checkpoint_tracker_filename(path)
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
    iteration = int(metastring)
    assert iteration > 0
    checkpoint_name = get_checkpoint_name(path, iteration)

    return checkpoint_name, iteration


def test_split_merge():
    """Split a small tensor, merge it back and print the maximum error."""
    print('testing split and merge ...')

    #[QKV.ROW-COL]
    tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
                                [1.21, 1.22, 1.23, 1.24, 1.25],
                                [1.31, 1.32, 1.33, 1.34, 1.35],
                                [1.41, 1.42, 1.43, 1.44, 1.45],
                                [2.11, 2.12, 2.13, 2.14, 2.15],
                                [2.21, 2.22, 2.23, 2.24, 2.25],
                                [2.31, 2.32, 2.33, 2.34, 2.35],
                                [2.41, 2.42, 2.43, 2.44, 2.45],
                                [3.11, 3.12, 3.13, 3.14, 3.15],
                                [3.21, 3.22, 3.23, 3.24, 3.25],
                                [3.31, 3.32, 3.33, 3.34, 3.35],
                                [3.41, 3.42, 3.43, 3.44, 3.45]])

    num_partitions = 2
    partition_dim = 0
    stride = 3
    partitions = split_into_partitions(tensor, num_partitions,
                                       partition_dim, stride)

    merged = torch.zeros_like(tensor)
    merge_partitions(merged, partitions, partition_dim, stride)

    max_error = (merged - tensor).abs().max()
    print(' > max error (should be zero): {}'.format(max_error))


def main(model_type):
    """Merge all model parallel partitions found under ``args.load``.

    A full model is built with a model parallel world size of 1, one model
    per rank is built and loaded from its checkpoint, every parameter of the
    full model is filled from the partitions, and the result is saved under
    ``<args.load>/merged``.

    Args:
        model_type: model type forwarded to ``get_model``.

    Raises:
        RuntimeError: if a partition has fewer parameters than the merged
            model.
        AssertionError: if ``args.vocab_size`` is unset or parameter names
            of a partition and the merged model disagree.
    """
    # Args
    args = get_args()

    print('\n merging model parallel partitions ...')
    assert args.vocab_size is not None
    print(' > number of partitions: {}'.format(args.model_parallel_size))
    print(' > checkpoint path: {}'.format(args.load))
    print(' > model parameters:')
    print(' number of tokens ................ {} '.format(args.vocab_size))
    print(' number of layers ................ {}'.format(args.num_layers))
    print(' hidden sise ..................... {}'.format(args.hidden_size))
    print(' number of attention heads ....... {}'.format(
        args.num_attention_heads))
    print(' maximum position embeddings ..... {}'.format(
        args.max_position_embeddings))

    # Full model.
    print('> building the full model ...')
    mpu.initialize.set_model_parallel_world_size(1)
    mpu.initialize.set_model_parallel_rank(0)
    merged_model = get_model(model_type, args)

    # Build and load partitions.
    partitions = []
    iteration = 0
    mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
    for rank in range(args.model_parallel_size):
        mpu.initialize.set_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        print('> loading {} ...'.format(checkpoint_name))
        model_ = get_model(model_type, args)
        # NOTE(review): torch.load unpickles the file; only point this tool
        # at checkpoints from a trusted source.
        sd = torch.load(checkpoint_name, map_location='cpu')
        model_.load_state_dict(sd['model'])
        partitions.append(model_)

    # Parameter generators so we can loop through them simultaneously.
    partitions_params_gen = [partition.named_parameters()
                             for partition in partitions]
    for name, merged_param in merged_model.named_parameters():
        print(' > working on {} ...'.format(name))
        print(' merged type: {}, size: {}'.format(
            merged_param.dtype, list(merged_param.size())))

        # Get the matching parameter of every partition and check names.
        partitions_param = []
        for rank, partition_params_gen in enumerate(partitions_params_gen):
            try:
                partition_name, partition_param = next(partition_params_gen)
            except StopIteration:
                # A short partition must not silently end the merge and
                # produce a partially merged checkpoint.
                raise RuntimeError(
                    'partition {} has no parameter matching {}'.format(
                        rank, name)) from None
            assert partition_name == name
            partitions_param.append(partition_param)
            print(' partition {} type: {}, size: {}'.format(
                rank, partition_param.dtype, list(partition_param.size())))

        # For the non-parallel parameters, simply copy the rank 0 values.
        if not hasattr(merged_param, 'model_parallel'):
            print(' none-parallel parameter, simple copy from rank 0')
            with torch.no_grad():
                merged_param.data.copy_(partitions_param[0].data)
        # For parallel parameters, merge the values.
        else:
            print(' parallel parameter merge with stride {} along '
                  'dimention {}'.format(merged_param.stride,
                                        merged_param.partition_dim))
            merge_partitions(merged_param,
                             partitions_param,
                             merged_param.partition_dim,
                             merged_param.stride)

    # Save the model.
    mpu.initialize.set_model_parallel_rank(0)
    sd = {}
    sd['model'] = merged_model.state_dict_for_save_checkpoint()
    sd['iteration'] = iteration
    merged_path = os.path.join(args.load, 'merged')
    checkpoint_name = get_checkpoint_name(merged_path, iteration)
    ensure_directory_exists(checkpoint_name)
    print('> saving merged model to {}'.format(checkpoint_name))
    torch.save(sd, checkpoint_name)

    print('done :-)')


if __name__ == '__main__':

    main('BERT')