Unverified Commit 4ad7a1f5 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Chore] create a utility for calculating the expected number of shards. (#8692)

create a utility for calculating the expected number of shards.
parent 1f81fbe2
...@@ -55,6 +55,15 @@ from diffusers.utils.testing_utils import ( ...@@ -55,6 +55,15 @@ from diffusers.utils.testing_utils import (
from ..others.test_utils import TOKEN, USER, is_staging_test from ..others.test_utils import TOKEN, USER, is_staging_test
def caculate_expected_num_shards(index_map_path):
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
return expected_num_shards
# Will be run via run_test_in_subprocess # Will be run via run_test_in_subprocess
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
error = None error = None
...@@ -888,12 +897,7 @@ class ModelTesterMixin: ...@@ -888,12 +897,7 @@ class ModelTesterMixin:
# Now check if the right number of shards exists. First, let's get the number of shards. # Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it # Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it. # instead of hardcoding it.
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f: expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards) self.assertTrue(actual_num_shards == expected_num_shards)
...@@ -924,12 +928,7 @@ class ModelTesterMixin: ...@@ -924,12 +928,7 @@ class ModelTesterMixin:
# Now check if the right number of shards exists. First, let's get the number of shards. # Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it # Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it. # instead of hardcoding it.
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f: expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards) self.assertTrue(actual_num_shards == expected_num_shards)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment