Unverified commit 187c2a06, authored by Yuchao Dai, committed by GitHub
Browse files

Fix E1136 (#563)

parent 229080b9
...@@ -6,6 +6,7 @@ import re ...@@ -6,6 +6,7 @@ import re
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from typing import Dict, List
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -810,7 +811,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): ...@@ -810,7 +811,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
return state_dict return state_dict
def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config): def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
"""Convert the list of sharded state_dict of a GPT model with tensor parallel to """Convert the list of sharded state_dict of a GPT model with tensor parallel to
the state_dict of a standard GPT model. the state_dict of a standard GPT model.
......
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
import re import re
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
from typing import Union from typing import Dict, List, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -17,8 +17,8 @@ from einops import rearrange ...@@ -17,8 +17,8 @@ from einops import rearrange
def remap_state_dict_meta_llama( def remap_state_dict_meta_llama(
state_dict: dict[str, torch.Tensor], config: GPT2Config state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""Convert the state_dict in Meta format to standard GPT format. """Convert the state_dict in Meta format to standard GPT format.
This function modifies state_dict in place. This function modifies state_dict in place.
...@@ -113,8 +113,8 @@ def remap_state_dict_meta_llama( ...@@ -113,8 +113,8 @@ def remap_state_dict_meta_llama(
def remap_state_dict_hf_llama( def remap_state_dict_hf_llama(
state_dict: dict[str, torch.Tensor], config: GPT2Config state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""Convert the state_dict in Hugging Face format to standard GPT format. """Convert the state_dict in Hugging Face format to standard GPT format.
This function modifies state_dict in place. This function modifies state_dict in place.
...@@ -217,8 +217,8 @@ def remap_state_dict_hf_llama( ...@@ -217,8 +217,8 @@ def remap_state_dict_hf_llama(
def inv_remap_state_dict_hf_llama( def inv_remap_state_dict_hf_llama(
state_dict: dict[str, torch.Tensor], config: GPT2Config state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""Convert the state_dict in standard GPT format to Hugging Face format. """Convert the state_dict in standard GPT format to Hugging Face format.
This function is meant to be the inverse of remap_state_dict_hf_llama, up to a This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
...@@ -382,7 +382,7 @@ def config_from_checkpoint( ...@@ -382,7 +382,7 @@ def config_from_checkpoint(
def state_dicts_from_checkpoint( def state_dicts_from_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str checkpoint_path: Union[str, os.PathLike], model_name: str
) -> list[dict]: ) -> List[dict]:
# Need to sort, otherwise we mess up the ordering and the weights are wrong # Need to sort, otherwise we mess up the ordering and the weights are wrong
return [ return [
torch.load(path, map_location="cpu") torch.load(path, map_location="cpu")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment