"examples/vscode:/vscode.git/clone" did not exist on "1096f88e2b8d67d65ed21ee23c6aed2e3e9756d6"
Commit be3468dd authored by comfyanonymous's avatar comfyanonymous
Browse files

Less useless downcasting.

parent ca82ade7
...@@ -84,12 +84,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): ...@@ -84,12 +84,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.inner_name = inner_name self.inner_name = inner_name
if dtype is not None: if dtype is not None:
self.transformer.to(dtype)
inner_model = getattr(self.transformer, self.inner_name) inner_model = getattr(self.transformer, self.inner_name)
if hasattr(inner_model, "embeddings"): if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32) embeddings_bak = inner_model.embeddings.to(torch.float32)
inner_model.embeddings = None
self.transformer.to(dtype)
inner_model.embeddings = embeddings_bak
else: else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True)
self.transformer.to(dtype)
self.transformer.set_input_embeddings(previous_inputs)
self.max_length = max_length self.max_length = max_length
if freeze: if freeze:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment