Commit a4b666ca authored by gushiqiao's avatar gushiqiao
Browse files

Fix bugs

parent 06f9891a
...@@ -1022,4 +1022,4 @@ if __name__ == "__main__": ...@@ -1022,4 +1022,4 @@ if __name__ == "__main__":
model_path = args.model_path model_path = args.model_path
model_cls = args.model_cls model_cls = args.model_cls
main() main()
\ No newline at end of file
...@@ -1022,4 +1022,4 @@ if __name__ == "__main__": ...@@ -1022,4 +1022,4 @@ if __name__ == "__main__":
model_path = args.model_path model_path = args.model_path
model_cls = args.model_cls model_cls = args.model_cls
main() main()
\ No newline at end of file
...@@ -360,4 +360,4 @@ class MemoryBuffer: ...@@ -360,4 +360,4 @@ class MemoryBuffer:
self.insertion_index = 0 self.insertion_index = 0
self.used_mem = 0 self.used_mem = 0
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
\ No newline at end of file
...@@ -503,4 +503,4 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -503,4 +503,4 @@ class WanTransformerInfer(BaseTransformerInfer):
if self.clean_cuda_cache: if self.clean_cuda_cache:
del y, c_gate_msa del y, c_gate_msa
torch.cuda.empty_cache() torch.cuda.empty_cache()
return x return x
\ No newline at end of file
...@@ -169,4 +169,4 @@ def sinusoidal_embedding_1d(dim, position): ...@@ -169,4 +169,4 @@ def sinusoidal_embedding_1d(dim, position):
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
if GET_DTYPE() == "BF16": if GET_DTYPE() == "BF16":
x = x.to(torch.bfloat16) x = x.to(torch.bfloat16)
return x return x
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment