"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "827d6d6ef071029cfe82838a18dab046b5813976"
Unverified commit 0b7b07ef, authored by ANURAG BHANDARI, committed by GitHub
Browse files

added type hints for Yolos Pytorch model (#19545)



* added type hints for Yolos Pytorch model

* make fixup

* Update src/transformers/models/yolos/convert_yolos_to_pytorch.py

* Update src/transformers/models/yolos/convert_yolos_to_pytorch.py

* Update src/transformers/models/yolos/convert_yolos_to_pytorch.py
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 3b3024da
...@@ -32,7 +32,7 @@ logging.set_verbosity_info() ...@@ -32,7 +32,7 @@ logging.set_verbosity_info()
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def get_yolos_config(yolos_name): def get_yolos_config(yolos_name: str) -> YolosConfig:
config = YolosConfig() config = YolosConfig()
# size of the architecture # size of the architecture
...@@ -68,7 +68,7 @@ def get_yolos_config(yolos_name): ...@@ -68,7 +68,7 @@ def get_yolos_config(yolos_name):
# we split up the matrix of each encoder layer into queries, keys and values # we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, base_model=False): def read_in_q_k_v(state_dict: dict, config: YolosConfig, base_model: bool = False):
for i in range(config.num_hidden_layers): for i in range(config.num_hidden_layers):
# read in weights + bias of input projection layer (in timm, this is a single matrix + bias) # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
...@@ -86,7 +86,7 @@ def read_in_q_k_v(state_dict, config, base_model=False): ...@@ -86,7 +86,7 @@ def read_in_q_k_v(state_dict, config, base_model=False):
state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
def rename_key(name): def rename_key(name: str) -> str:
if "backbone" in name: if "backbone" in name:
name = name.replace("backbone", "vit") name = name.replace("backbone", "vit")
if "cls_token" in name: if "cls_token" in name:
...@@ -123,7 +123,7 @@ def rename_key(name): ...@@ -123,7 +123,7 @@ def rename_key(name):
return name return name
def convert_state_dict(orig_state_dict, model): def convert_state_dict(orig_state_dict: dict, model: YolosForObjectDetection) -> dict:
for key in orig_state_dict.copy().keys(): for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key) val = orig_state_dict.pop(key)
...@@ -148,14 +148,16 @@ def convert_state_dict(orig_state_dict, model): ...@@ -148,14 +148,16 @@ def convert_state_dict(orig_state_dict, model):
# We will verify our results on an image of cute cats # We will verify our results on an image of cute cats
def prepare_img(): def prepare_img() -> torch.Tensor:
url = "http://images.cocodataset.org/val2017/000000039769.jpg" url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw) im = Image.open(requests.get(url, stream=True).raw)
return im return im
@torch.no_grad() @torch.no_grad()
def convert_yolos_checkpoint(yolos_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub=False): def convert_yolos_checkpoint(
yolos_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
""" """
Copy/paste/tweak model's weights to our YOLOS structure. Copy/paste/tweak model's weights to our YOLOS structure.
""" """
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment