import contextlib import gc import sys from functools import partial from pathlib import Path from typing import Dict, Literal, Optional, Tuple, Union from dataclasses import asdict import json import torch # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) from lit_gpt import Config from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load # from scripts.convert_hf_checkpoint import layer_template, load_param def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: split = layer_name.split(".") number = int(split[idx]) split[idx] = "{}" from_name = ".".join(split) return from_name, number def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: if hasattr(param, "_load_tensor"): # support tensors loaded via `lazy_load()` print(f"Loading {name!r} into RAM") param = param._load_tensor() if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: print(f"Converting {name!r} from {param.dtype} to {dtype}") param = param.to(dtype) return param def copy_weights_falcon( size: Literal["7b", "40b"], state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ): weight_map = { "transformer.wte.weight": "transformer.word_embeddings.weight", "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", "transformer.ln_f.bias": "transformer.ln_f.bias", "transformer.ln_f.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size if size == "7b": weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", } ) elif size == "40b": weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", } ) else: raise NotImplementedError for name, param in lit_weights.items(): if "transformer.h" in name: from_name, number = layer_template(name, 2) to_name = weight_map[from_name].format(number) else: to_name = weight_map[name] param = load_param(param, name, None) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_gpt_neox( state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "gpt_neox.embed_in.weight", "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", "transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias", "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", "lm_head.weight": "embed_out.weight", } for name, param in lit_weights.items(): if "transformer.h" in name: from_name, number = layer_template(name, 2) to_name = weight_map[from_name].format(number) else: to_name = weight_map[name] param = load_param(param, name, None) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_llama( config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ): weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.swiglu.w1.weight": "model.layers.{}.mlp.gate_proj.weight", "transformer.h.{}.mlp.swiglu.w2.weight": "model.layers.{}.mlp.up_proj.weight", "transformer.h.{}.mlp.swiglu.w3.weight": "model.layers.{}.mlp.down_proj.weight", "transformer.ln_f.weight": "model.norm.weight", "lm_head.weight": "lm_head.weight", } for name, param in lit_weights.items(): if name.endswith(".attn.attn.weight"): from_name, number = layer_template(name, 2) q = "model.layers.{}.self_attn.q_proj.weight".format(number) k = "model.layers.{}.self_attn.k_proj.weight".format(number) v = "model.layers.{}.self_attn.v_proj.weight".format(number) qkv = load_param(param, name,None) qp, kp, vp = tensor_split(qkv, config) for to_name, param in zip((q, k, v), (qp, kp, vp)): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param elif "transformer.h" in name: from_name, number = layer_template(name, 2) to_name = weight_map[from_name] if to_name is None: continue to_name = to_name.format(number) param = load_param(param, name,None) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param else: to_name = weight_map[name] param = load_param(param, name, None) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def tensor_split( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def kstart(start, blen, klen) -> int: """returns start index of keys in batch""" return start + (blen - (klen * 2)) def vstart(start, blen, klen) -> int: """returns start index of values in batch""" return start + blen - klen def vend(start, blen) -> int: """returns last index of values in batch""" return start + blen # num observations nobs = param.shape[0] # batch length blen = nobs // config.n_query_groups # key length in batch klen = config.head_size # value length in batch vlen = config.head_size # the starting index of each new batch starts = range(0, nobs, blen) # the indices to splice on splices = [(s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts] qc = () kc = () vc = () for splice in splices: qs, ks, vs, ve = splice qc += (param[qs:ks, :],) kc += (param[ks:vs, :],) vc += (param[vs:ve, :],) q = torch.cat(qc) k = torch.cat(kc) v = torch.cat(vc) return q, k, v def maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return lit_weights.get("model", lit_weights) def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: weight_names = {wk.split(".")[-1] for wk in lit_weights} # LoRA or QLoRA if any("lora" in wn for wn in weight_names): raise ValueError("Model weights must be merged using `lora.merge_lora_weights()` before conversion.") # adapter v2. adapter_bias will only be in adapter_v2 elif "adapter_bias" in weight_names: raise NotImplementedError("Converting models finetuned with adapter_v2 not yet supported.") # adapter. gating_factor is in adapter and adapter_v2 elif "gating_factor" in weight_names: raise NotImplementedError("Converting models finetuned with adapter not yet supported.") def get_tinyllama_init_hf_config() -> dict: return { "architectures": ["LlamaForCausalLM"], "bos_token_id": 1, "eos_token_id": 2, "hidden_act": "silu", "hidden_size": None, "initializer_range": 0.02, "intermediate_size": None, "max_position_embeddings": None, "model_type": "llama", "num_attention_heads": None, "num_hidden_layers": None, "num_key_value_heads": None, "pretraining_tp": 1, "rms_norm_eps": None, "rope_scaling": None, "tie_word_embeddings": False, "torch_dtype": "float32", "transformers_version": "4.31.0.dev0", "use_cache": True, "vocab_size": None, } def convert_config_lit_to_hf(lit_config_dict: dict) -> dict: lit_hf_mapping = { "block_size": "max_position_embeddings", "vocab_size": "vocab_size", "n_layer": "num_hidden_layers", "n_embd": "hidden_size", "n_head": "num_attention_heads", "n_query_groups": "num_key_value_heads", "intermediate_size": "intermediate_size", "norm_eps": "rms_norm_eps", } hf_config_dict = get_tinyllama_init_hf_config() for lit_key, hf_key in lit_hf_mapping.items(): hf_config_dict[hf_key] = lit_config_dict[lit_key] return hf_config_dict @torch.inference_mode() def convert_lit_checkpoint(*, checkpoint_name: str, out_dir: Path, model_name: str, model_only: bool = True) -> None: config = Config.from_name(model_name) if "falcon" in model_name: copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") elif config._mlp_class == "LLaMAMLP": copy_fn = partial(copy_weights_llama, config) else: copy_fn = copy_weights_gpt_neox # initialize a new empty state dict to hold our new weights sd = {} # checkpoint_name cannot be hardcoded because there exists different outputs such as # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth"") pth_file = out_dir / checkpoint_name bin_file = pth_file.with_suffix(".bin") with incremental_save(bin_file) as saver: with contextlib.ExitStack() as stack: lit_weights = stack.enter_context(lazy_load(pth_file)) lit_weights = maybe_unwrap_state_dict(lit_weights) check_conversion_supported(lit_weights) # Incremental save will trigger error copy_fn(sd, lit_weights, saver=None) gc.collect() saver.save(sd) # convert lit config file to hf-style if not model_only: print('Converting config file...') lit_config = asdict(config) hf_config = convert_config_lit_to_hf(lit_config) config_path = out_dir / "config.json" with open(config_path, "w") as f: json.dump(hf_config, f, indent=4) if __name__ == "__main__": from jsonargparse import CLI CLI(convert_lit_checkpoint, as_positional=False)