Commit 1c065c06 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Merge pull request #77 from ModelTC/dev_fix

Fix.
parents b811c2be c926b08a
...@@ -94,9 +94,7 @@ class LNWeight(LNWeightTemplate): ...@@ -94,9 +94,7 @@ class LNWeight(LNWeightTemplate):
self.bias = None self.bias = None
def apply(self, input_tensor): def apply(self, input_tensor):
if self.weight is None or self.weight.dtype == torch.bfloat16: if GET_DTYPE() != "BF16":
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
else:
input_tensor = torch.nn.functional.layer_norm( input_tensor = torch.nn.functional.layer_norm(
input_tensor.float(), input_tensor.float(),
(input_tensor.shape[-1],), (input_tensor.shape[-1],),
...@@ -104,4 +102,6 @@ class LNWeight(LNWeightTemplate): ...@@ -104,4 +102,6 @@ class LNWeight(LNWeightTemplate):
self.bias, self.bias,
self.eps, self.eps,
).to(torch.bfloat16) ).to(torch.bfloat16)
else:
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
return input_tensor return input_tensor
\ No newline at end of file
...@@ -143,9 +143,8 @@ class WanModel: ...@@ -143,9 +143,8 @@ class WanModel:
def _init_weights(self, weight_dict=None): def _init_weights(self, weight_dict=None):
use_bf16 = GET_DTYPE() == "BF16" use_bf16 = GET_DTYPE() == "BF16"
# Some layers run with float32 to achieve high accuracy # Some layers run with float32 to achieve high accuracy
skip_bf16 = {"norm", "embedding", "modulation", "time"} skip_bf16 = {"norm", "embedding", "modulation", "time", "img_emb.proj.0", "img_emb.proj.4"}
if weight_dict is None: if weight_dict is None:
if not self.dit_quantized or self.weight_auto_quant: if not self.dit_quantized or self.weight_auto_quant:
self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16) self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment