# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import argparse import os import sys import torch # Add megatron to the path. sys.path.append( os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) ) def combine(input_files, module_prefixes, output_files): num_inputs_per_output = int(len(input_files) / len(output_files)) for output_idx, output_file in enumerate(output_files): combined_state_dict = None lb = output_idx * num_inputs_per_output ub = (output_idx + 1) * num_inputs_per_output current_input_files = input_files[lb:ub] current_module_prefixes = module_prefixes[lb:ub] for i, (input_file, module_prefix) in enumerate( zip(current_input_files, current_module_prefixes) ): # initialize the combined state dict using the first provided input file current_state_dict = torch.load(input_file) if i == 0: combined_state_dict = current_state_dict.copy() combined_state_dict["model"] = dict() # copy model state dict and prefix names with the given module keys. for k, v in current_state_dict["model"].items(): combined_state_dict["model"]["%s.%s" % (module_prefix, k)] = v output_dir = os.path.dirname(output_file) if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) torch.save(combined_state_dict, output_file) print("saved:", output_file) print("done.") if __name__ == "__main__": parser = argparse.ArgumentParser( description=""" Combine multiple state dicts into a single state dict. The combined state dict is first initialized by taking a copy of the first provided input state dict. To avoid conflicts in model parameter names, a prefix must be provided for each input file. Model parameter names will be renamed from to .. Example usage: python combine_state_dicts.py --input language_model.pt vision_model.pt --prefixes language_model vision_model --output multimodal.pt """, formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("--input", nargs="*", required=True, help="paths to input state dict files") parser.add_argument( "--prefixes", nargs="*", required=True, help="prefixes to use with each input model's parameters", ) parser.add_argument( "--output", nargs="*", required=True, help="path(s) to output state dict file" ) args = parser.parse_args() assert len(args.input) > 1, "must provide more than 1 input model to combine" assert len(args.input) == len(args.prefixes), "each input model must have a corresponding key" assert ( len(args.input) % len(args.output) == 0 ), "each output file must use the same number of input files" combine(args.input, args.prefixes, args.output)