".github/vscode:/vscode.git/clone" did not exist on "4ca9b9cc29fefaa899cba67d61a8252ae9f16c0d"
Commit 57926635 authored by comfyanonymous

Switch text encoder to manual cast.

Use fp16 text encoder weights for CPU inference to lower memory usage.
parent 69033081
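
The memory claim in the commit message can be checked with a quick standalone calculation (illustrative only, not code from this commit): a parameter tensor stored in fp16 takes half the bytes of the same tensor in fp32, so keeping the CPU text encoder's weights in fp16 roughly halves their footprint.

    import torch

    # Hypothetical tensor shaped like a typical CLIP weight matrix (not from the repo).
    w_fp32 = torch.zeros(768, 768, dtype=torch.float32)
    w_fp16 = w_fp32.to(torch.float16)

    print(w_fp32.element_size() * w_fp32.nelement())  # 2359296 bytes
    print(w_fp16.element_size() * w_fp16.nelement())  # 1179648 bytes, half the memory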
@@ -503,6 +503,9 @@ def text_encoder_dtype(device=None):
     elif args.fp32_text_enc:
         return torch.float32
 
+    if is_device_cpu(device):
+        return torch.float16
+
     if should_use_fp16(device, prioritize_performance=False):
         return torch.float16
     else:
...
@@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs):
     else:
         raise ValueError(f"unsupported dimensions: {dims}")
 
+def cast_bias_weight(s, input):
+    bias = None
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype)
+    weight = s.weight.to(device=input.device, dtype=input.dtype)
+    return weight, bias
+
+class manual_cast:
+    class Linear(Linear):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+    class Conv2d(Conv2d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class Conv3d(Conv3d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class GroupNorm(GroupNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+    class LayerNorm(LayerNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
...
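
The manual_cast classes above appear to subclass ComfyUI's own op classes, so they are not standalone. The following sketch shows the same weight-casting pattern built directly on torch.nn.Linear (ManualCastLinear and the toy shapes are illustrative, not part of the commit): the weights stay in fp16 while the forward pass runs in whatever dtype the input arrives in.

    import torch

    def cast_bias_weight(s, input):
        # Cast the module's stored parameters to the device/dtype of the incoming tensor.
        bias = None
        if s.bias is not None:
            bias = s.bias.to(device=input.device, dtype=input.dtype)
        weight = s.weight.to(device=input.device, dtype=input.dtype)
        return weight, bias

    class ManualCastLinear(torch.nn.Linear):
        def forward(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

    layer = ManualCastLinear(8, 4).half()  # weights stored in fp16
    x = torch.randn(2, 8)                  # fp32 input on CPU
    y = layer(x)                           # computed in fp32, no dtype mismatch error
    print(y.dtype)                         # torch.float32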
@@ -78,7 +78,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         with open(textmodel_json_config) as f:
             config = json.load(f)
-        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
         self.num_layers = self.transformer.num_layers
         self.max_length = max_length
...
@@ -160,12 +160,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.dtype != torch.float32:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
         attention_mask = None
         if self.enable_attention_masks:
             attention_mask = torch.zeros_like(tokens)
...
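
The removal of the precision_scope/torch.autocast block follows from the hunk above: once the transformer is built with comfy.ops.manual_cast, each layer casts its own weights to the activation dtype, so no autocast region is needed around the encode call. A minimal sketch of that idea, with a toy module standing in for the CLIP transformer (TinyEncoder is illustrative, not from the repository):

    import torch

    class TinyEncoder(torch.nn.Module):
        # Stand-in for the text encoder: a single layer that casts its own weights.
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(16, 16)

        def forward(self, x):
            w = self.proj.weight.to(dtype=x.dtype)
            b = self.proj.bias.to(dtype=x.dtype)
            return torch.nn.functional.linear(x, w, b)

    enc = TinyEncoder().half()  # fp16 weights, as on CPU after this commit
    x = torch.randn(1, 16)      # fp32 activations
    out = enc(x)                # runs without any torch.autocast wrapper
    print(out.dtype)            # torch.float32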