Unverified commit 187c2a06 authored by Yuchao Dai, committed by GitHub

Fix E1136 (#563)

parent 229080b9
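For context: pylint's E1136 (unsubscriptable-object) fires on these files because, before Python 3.9 and PEP 585, the built-in `list` and `dict` classes do not support subscripting. Annotations such as `list[dict[str, torch.Tensor]]` therefore raise a `TypeError` when the module is imported on Python 3.8 and are flagged by pylint, while the `typing` generics work on all supported versions. A minimal sketch of the failure mode and the fix this commit applies (the `scale` function and its body are illustrative, not taken from the patched files):

```python
from typing import Dict, List

import torch

# On Python 3.8, the commented-out annotation below raises at import time
# ("TypeError: 'type' object is not subscriptable") and pylint reports
# E1136; subscripting builtins only became valid with PEP 585 in 3.9:
#
#     def scale(xs: list[dict[str, torch.Tensor]]) -> None: ...
#
# The typing generics are subscriptable on every supported version:
def scale(xs: List[Dict[str, torch.Tensor]]) -> None:
    for d in xs:
        for t in d.values():
            t.mul_(2.0)  # illustrative in-place op

scale([{"w": torch.ones(2)}])
```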
@@ -6,6 +6,7 @@ import re
 from collections import OrderedDict, namedtuple
 from collections.abc import Sequence
 from functools import partial
+from typing import Dict, List

 import torch
 import torch.nn as nn
@@ -810,7 +811,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     return state_dict


-def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
     """Convert the list of sharded state_dict of a GPT model with tensor parallel to
     the state_dict of a standard GPT model.
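The first file's change annotates `combine_state_dicts_tp`, which merges per-rank tensor-parallel shards back into a single state dict. A hedged usage sketch, assuming the function lives in `flash_attn.models.gpt` (the diff does not show the file name) and using hypothetical checkpoint paths and an illustrative config:

```python
import torch
from transformers import GPT2Config

# Assumed import path; not shown in the diff.
from flash_attn.models.gpt import combine_state_dicts_tp

world_size = 2  # number of tensor-parallel ranks (illustrative)

# Hypothetical per-rank checkpoint files, one per shard.
state_dicts = [
    torch.load(f"checkpoint_rank{rank}.pt", map_location="cpu")
    for rank in range(world_size)
]

config = GPT2Config()  # illustrative; use the config the shards were saved with
full_state_dict = combine_state_dicts_tp(state_dicts, config)
```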
@@ -6,7 +6,7 @@ import os
 import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Union
+from typing import Dict, List, Union

 import torch
 import torch.nn.functional as F
@@ -17,8 +17,8 @@ from einops import rearrange
 def remap_state_dict_meta_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Meta format to standard GPT format.
     This function modifies state_dict in place.
@@ -113,8 +113,8 @@ def remap_state_dict_meta_llama(
 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.
     This function modifies state_dict in place.
@@ -217,8 +217,8 @@ def remap_state_dict_hf_llama(
 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.
     This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
@@ -382,7 +382,7 @@ def config_from_checkpoint(
 def state_dicts_from_checkpoint(
     checkpoint_path: Union[str, os.PathLike], model_name: str
-) -> list[dict]:
+) -> List[dict]:
     # Need to sort, otherwise we mess up the ordering and the weights are wrong
     return [
         torch.load(path, map_location="cpu")
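An alternative some codebases use instead of swapping in `typing.List`/`Dict` (not what this commit does) is PEP 563's postponed evaluation of annotations, which keeps the PEP 585 spelling but stops the interpreter from ever subscripting the builtins at runtime. Note that older pylint releases may still emit E1136 for it, which makes the `typing` generics the more conservative fix. A minimal sketch:

```python
from __future__ import annotations  # PEP 563: annotations are stored as strings

import torch

# The annotation below is never evaluated at runtime, so this is safe
# even on Python 3.7/3.8; static type checkers still understand it.
def remap(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    return state_dict
```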