"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "877a88c57e5f25cec3c9b3748bd0525fceec4908"
Unverified commit 6f74ef55, authored by hlky, committed via GitHub

Fix `torch_dtype` in Kolors text encoder with `transformers` v4.49 (#10816)

* Fix `torch_dtype` in Kolors text encoder with `transformers` v4.49

* Default `torch_dtype` to `torch.float32` and warn when the passed value is not a `torch.dtype`
parent 9c7e2051
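
Across the loaders changed below, the patch replaces the `None` default for `torch_dtype` with `torch.float32` and adds a guard that falls back to `torch.float32`, with a warning, whenever the caller passes something that is not a `torch.dtype` (for example the string `"float16"`). A minimal sketch of the guard as a standalone helper (`resolve_torch_dtype` is a hypothetical name; the sketch warns before overwriting the value, whereas the diff assigns the default first):

```python
import logging

import torch

logger = logging.getLogger(__name__)  # diffusers uses its own logging wrapper; stdlib logging here


def resolve_torch_dtype(torch_dtype=torch.float32):
    """Return a valid torch.dtype, defaulting to float32 for invalid input."""
    if not isinstance(torch_dtype, torch.dtype):
        # Warn with the original value, then fall back to float32.
        logger.warning(
            f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
        )
        torch_dtype = torch.float32
    return torch_dtype


assert resolve_torch_dtype(torch.float16) is torch.float16
assert resolve_torch_dtype("float16") is torch.float32  # warns, then falls back
```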
```diff
@@ -92,9 +92,13 @@ class CheckpointMergerPipeline(DiffusionPipeline):
         token = kwargs.pop("token", None)
         variant = kwargs.pop("variant", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         device_map = kwargs.pop("device_map", None)
+        if not isinstance(torch_dtype, torch.dtype):
+            torch_dtype = torch.float32
+            print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
         alpha = kwargs.pop("alpha", 0.5)
         interp = kwargs.pop("interp", None)
```
```diff
@@ -360,11 +360,17 @@ class FromSingleFileMixin:
         cache_dir = kwargs.pop("cache_dir", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         disable_mmap = kwargs.pop("disable_mmap", False)
         is_legacy_loading = False
+        if not isinstance(torch_dtype, torch.dtype):
+            torch_dtype = torch.float32
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
         # We shouldn't allow configuring individual models components through a Pipeline creation method
         # These model kwargs should be deprecated
         scaling_factor = kwargs.get("scaling_factor", None)
```
```diff
@@ -240,11 +240,17 @@ class FromOriginalModelMixin:
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
         config_revision = kwargs.pop("config_revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        if not isinstance(torch_dtype, torch.dtype):
+            torch_dtype = torch.float32
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
         if isinstance(pretrained_model_link_or_path_or_dict, dict):
             checkpoint = pretrained_model_link_or_path_or_dict
         else:
```
```diff
@@ -866,7 +866,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
@@ -879,6 +879,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        if not isinstance(torch_dtype, torch.dtype):
+            torch_dtype = torch.float32
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
```
```diff
@@ -685,7 +685,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         custom_revision = kwargs.pop("custom_revision", None)
         provider = kwargs.pop("provider", None)
@@ -702,6 +702,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
+        if not isinstance(torch_dtype, torch.dtype):
+            torch_dtype = torch.float32
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
         if low_cpu_mem_usage and not is_accelerate_available():
             low_cpu_mem_usage = False
             logger.warning(
@@ -1826,7 +1832,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         """
         original_config = dict(pipeline.config)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         # derive the pipeline class to instantiate
         custom_pipeline = kwargs.pop("custom_pipeline", None)
```
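
From the user's side, the effect of the `DiffusionPipeline` change is that an invalid `torch_dtype` no longer propagates silently. A hedged example (the model id is illustrative; any diffusers-format checkpoint behaves the same way):

```python
import torch
from diffusers import DiffusionPipeline

# A proper torch.dtype: components are loaded in float16 as requested.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# A string is not a torch.dtype: the pipeline now warns and loads in
# torch.float32 instead of passing the bad value down to each component.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype="float16"
)
```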
```diff
@@ -89,7 +89,9 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
         components = {
```
```diff
@@ -93,7 +93,9 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
         components = {
```
```diff
@@ -98,7 +98,9 @@ class KolorsPAGPipelineFastTests(
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
         components = {
```
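
The three Kolors fast tests now pin the ChatGLM text encoder's dtype to `torch.bfloat16` explicitly instead of relying on the `from_pretrained` default, which shifted with `transformers` v4.49. The same explicit-dtype habit applies when loading the full pipeline; a sketch (repo id and fp16 variant as published on the Hub, assumed available):

```python
import torch
from diffusers import KolorsPipeline

# Pin the dtype explicitly rather than depending on library defaults.
pipe = KolorsPipeline.from_pretrained(
    "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```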