Commit aec90a0d authored by helloyongyang's avatar helloyongyang
Browse files

update some variable names

parent c98d486d
......@@ -59,7 +59,7 @@ class Message(BaseModel):
@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
    """Handle a local video-generation request.

    Post-commit signature (the unused ``request: Request`` parameter was
    removed).  The blocking pipeline is pushed to a worker thread so the
    event loop stays responsive.
    """
    global runner
    # NOTE(review): runner is a module-level singleton — assumes it was
    # initialized at startup; confirm against app lifespan hooks.
    runner.set_inputs(message)
    # run_pipeline is synchronous/blocking; off-load it from the event loop.
    await asyncio.to_thread(runner.run_pipeline)
......
......@@ -87,9 +87,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight_need_transpose = True
self.act_quant_func = None
"""
weight load functions
"""
# =========================
# weight load functions
# =========================
def load(self, weight_dict):
self.load_func(weight_dict)
......@@ -140,9 +140,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
"""
act quant kernels
"""
# =========================
# act quant kernels
# =========================
def act_quant_fp8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
......
......@@ -15,17 +15,17 @@ class HunyuanPostWeights:
self.final_layer_adaLN_modulation_1,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
    """Offload every managed matmul weight to CPU memory.

    Entries in ``weight_list`` that are not ``MMWeightTemplate`` wrappers
    are left untouched.
    """
    for weight in self.weight_list:
        if isinstance(weight, MMWeightTemplate):
            weight.to_cpu()
def to_cuda(self):
    """Move every managed matmul weight back to the GPU.

    Mirrors :meth:`to_cpu`; non-``MMWeightTemplate`` entries are skipped.
    """
    for weight in self.weight_list:
        if isinstance(weight, MMWeightTemplate):
            weight.to_cuda()
......@@ -97,17 +97,17 @@ class HunyuanPreWeights:
self.guidance_in_mlp_2,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
    """Offload all supported weight wrappers to CPU memory.

    Handles matmul, layer-norm and conv3d weight templates; any other
    entry in ``weight_list`` is skipped.
    """
    for weight in self.weight_list:
        # Single tuple isinstance replaces the original `or` chain.
        if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
            weight.to_cpu()
def to_cuda(self):
    """Move all supported weight wrappers back to the GPU.

    Mirrors :meth:`to_cpu` for matmul, layer-norm and conv3d templates.
    """
    for weight in self.weight_list:
        # Single tuple isinstance replaces the original `or` chain.
        if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
            weight.to_cuda()
......@@ -79,30 +79,30 @@ class HunyuanTransformerDoubleBlock:
self.txt_mlp_fc2,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
    """Offload this double block's matmul/RMS-norm weights to CPU (blocking)."""
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cpu()
def to_cuda(self):
    """Move this double block's matmul/RMS-norm weights to the GPU (blocking)."""
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cuda()
def to_cpu_sync(self):
    """Offload weights to CPU using non-blocking (async) copies.

    NOTE(review): despite the ``_sync`` suffix this passes
    ``non_blocking=True`` — the caller is presumably expected to
    synchronize the stream afterwards; confirm naming intent.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cpu(non_blocking=True)
def to_cuda_sync(self):
    """Move weights to the GPU using non-blocking (async) copies.

    NOTE(review): ``_sync`` suffix vs ``non_blocking=True`` — see
    :meth:`to_cpu_sync`; confirm naming intent.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cuda(non_blocking=True)
class HunyuanTransformerSingleBlock:
......@@ -131,27 +131,27 @@ class HunyuanTransformerSingleBlock:
self.modulation,
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
    """Offload this single block's matmul/RMS-norm weights to CPU (blocking)."""
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cpu()
def to_cuda(self):
    """Move this single block's matmul/RMS-norm weights to the GPU (blocking)."""
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cuda()
def to_cpu_sync(self):
    """Offload weights to CPU using non-blocking (async) copies.

    NOTE(review): ``_sync`` name vs ``non_blocking=True`` — confirm
    naming intent with the offload scheduler.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cpu(non_blocking=True)
def to_cuda_sync(self):
    """Move weights to the GPU using non-blocking (async) copies.

    NOTE(review): ``_sync`` name vs ``non_blocking=True`` — confirm
    naming intent with the offload scheduler.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cuda(non_blocking=True)
......@@ -12,22 +12,22 @@ class WanPostWeights:
self.weight_list = [self.head]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
for weight in self.weight_list:
if isinstance(weight, MMWeightTemplate):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
if self.config["cpu_offload"]:
mm_weight.to_cpu()
weight.to_cpu()
self.head_modulation = self.head_modulation.cpu()
def to_cpu(self):
    """Offload the head weights and the modulation tensor to CPU memory."""
    for weight in self.weight_list:
        if isinstance(weight, MMWeightTemplate):
            weight.to_cpu()
    # head_modulation is a plain tensor (not a template), moved explicitly.
    self.head_modulation = self.head_modulation.cpu()
def to_cuda(self):
    """Move the head weights and the modulation tensor back to the GPU."""
    for weight in self.weight_list:
        if isinstance(weight, MMWeightTemplate):
            weight.to_cuda()
    # head_modulation is a plain tensor (not a template), moved explicitly.
    self.head_modulation = self.head_modulation.cuda()
......@@ -40,19 +40,19 @@ class WanPreWeights:
self.weight_list.append(self.proj_3)
self.weight_list.append(self.proj_4)
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
if self.config["cpu_offload"]:
mm_weight.to_cpu()
weight.to_cpu()
def to_cpu(self):
    """Offload all supported pre-block weight wrappers to CPU memory.

    Handles matmul, layer-norm and conv3d templates; other entries in
    ``weight_list`` are skipped.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
            weight.to_cpu()
def to_cuda(self):
    """Move all supported pre-block weight wrappers back to the GPU.

    Mirrors :meth:`to_cpu`.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
            weight.to_cuda()
......@@ -81,31 +81,31 @@ class WanTransformerAttentionBlock:
self.weight_list.append(self.cross_attn_v_img)
self.weight_list.append(self.cross_attn_norm_k_img)
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
for weight in self.weight_list:
if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
weight.set_config(self.config["mm_config"])
weight.load(weight_dict)
def to_cpu(self):
    """Offload attention-block weights and the modulation tensor to CPU."""
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
            weight.to_cpu()
    # modulation is a plain tensor (not a template), moved explicitly.
    self.modulation = self.modulation.cpu()
def to_cuda(self):
    """Move attention-block weights and the modulation tensor to the GPU."""
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
            weight.to_cuda()
    # modulation is a plain tensor (not a template), moved explicitly.
    self.modulation = self.modulation.cuda()
def to_cpu_sync(self):
    """Offload weights and modulation to CPU with non-blocking copies.

    NOTE(review): ``_sync`` suffix vs ``non_blocking=True`` — the caller
    is presumably expected to synchronize the stream; confirm intent.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
            weight.to_cpu(non_blocking=True)
    self.modulation = self.modulation.to("cpu", non_blocking=True)
def to_cuda_sync(self):
    """Move weights and modulation to the GPU with non-blocking copies.

    NOTE(review): ``_sync`` suffix vs ``non_blocking=True`` — see
    :meth:`to_cpu_sync`; confirm intent.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
            weight.to_cuda(non_blocking=True)
    self.modulation = self.modulation.cuda(non_blocking=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment