Commit aec90a0d authored by helloyongyang's avatar helloyongyang
Browse files

update some variable names

parent c98d486d
...@@ -59,7 +59,7 @@ class Message(BaseModel): ...@@ -59,7 +59,7 @@ class Message(BaseModel):
@app.post("/v1/local/video/generate") @app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message, request: Request): async def v1_local_video_generate(message: Message):
global runner global runner
runner.set_inputs(message) runner.set_inputs(message)
await asyncio.to_thread(runner.run_pipeline) await asyncio.to_thread(runner.run_pipeline)
......
...@@ -87,9 +87,9 @@ class MMWeightQuantTemplate(MMWeightTemplate): ...@@ -87,9 +87,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight_need_transpose = True self.weight_need_transpose = True
self.act_quant_func = None self.act_quant_func = None
""" # =========================
weight load functions # weight load functions
""" # =========================
def load(self, weight_dict): def load(self, weight_dict):
self.load_func(weight_dict) self.load_func(weight_dict)
...@@ -140,9 +140,9 @@ class MMWeightQuantTemplate(MMWeightTemplate): ...@@ -140,9 +140,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
""" # =========================
act quant kernels # act quant kernels
""" # =========================
def act_quant_fp8_perchannel_sym_vllm(self, x): def act_quant_fp8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True) input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
......
...@@ -15,17 +15,17 @@ class HunyuanPostWeights: ...@@ -15,17 +15,17 @@ class HunyuanPostWeights:
self.final_layer_adaLN_modulation_1, self.final_layer_adaLN_modulation_1,
] ]
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate): if isinstance(weight, MMWeightTemplate):
mm_weight.set_config(self.config["mm_config"]) weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) weight.load(weight_dict)
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate): if isinstance(weight, MMWeightTemplate):
mm_weight.to_cpu() weight.to_cpu()
def to_cuda(self): def to_cuda(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate): if isinstance(weight, MMWeightTemplate):
mm_weight.to_cuda() weight.to_cuda()
...@@ -97,17 +97,17 @@ class HunyuanPreWeights: ...@@ -97,17 +97,17 @@ class HunyuanPreWeights:
self.guidance_in_mlp_2, self.guidance_in_mlp_2,
] ]
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate): if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
mm_weight.set_config(self.config["mm_config"]) weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) weight.load(weight_dict)
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate): if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
mm_weight.to_cpu() weight.to_cpu()
def to_cuda(self): def to_cuda(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate): if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
mm_weight.to_cuda() weight.to_cuda()
...@@ -79,30 +79,30 @@ class HunyuanTransformerDoubleBlock: ...@@ -79,30 +79,30 @@ class HunyuanTransformerDoubleBlock:
self.txt_mlp_fc2, self.txt_mlp_fc2,
] ]
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"]) weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) weight.load(weight_dict)
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cpu() weight.to_cpu()
def to_cuda(self): def to_cuda(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda() weight.to_cuda()
def to_cpu_sync(self): def to_cpu_sync(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cpu(non_blocking=True) weight.to_cpu(non_blocking=True)
def to_cuda_sync(self): def to_cuda_sync(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda(non_blocking=True) weight.to_cuda(non_blocking=True)
class HunyuanTransformerSingleBlock: class HunyuanTransformerSingleBlock:
...@@ -131,27 +131,27 @@ class HunyuanTransformerSingleBlock: ...@@ -131,27 +131,27 @@ class HunyuanTransformerSingleBlock:
self.modulation, self.modulation,
] ]
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"]) weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) weight.load(weight_dict)
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cpu() weight.to_cpu()
def to_cuda(self): def to_cuda(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda() weight.to_cuda()
def to_cpu_sync(self): def to_cpu_sync(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cpu(non_blocking=True) weight.to_cpu(non_blocking=True)
def to_cuda_sync(self): def to_cuda_sync(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda(non_blocking=True) weight.to_cuda(non_blocking=True)
...@@ -12,22 +12,22 @@ class WanPostWeights: ...@@ -12,22 +12,22 @@ class WanPostWeights:
self.weight_list = [self.head] self.weight_list = [self.head]
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate): if isinstance(weight, MMWeightTemplate):
mm_weight.set_config(self.config["mm_config"]) weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) weight.load(weight_dict)
if self.config["cpu_offload"]: if self.config["cpu_offload"]:
mm_weight.to_cpu() weight.to_cpu()
self.head_modulation = self.head_modulation.cpu() self.head_modulation = self.head_modulation.cpu()
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate): if isinstance(weight, MMWeightTemplate):
mm_weight.to_cpu() weight.to_cpu()
self.head_modulation = self.head_modulation.cpu() self.head_modulation = self.head_modulation.cpu()
def to_cuda(self): def to_cuda(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate): if isinstance(weight, MMWeightTemplate):
mm_weight.to_cuda() weight.to_cuda()
self.head_modulation = self.head_modulation.cuda() self.head_modulation = self.head_modulation.cuda()
...@@ -40,19 +40,19 @@ class WanPreWeights: ...@@ -40,19 +40,19 @@ class WanPreWeights:
self.weight_list.append(self.proj_3) self.weight_list.append(self.proj_3)
self.weight_list.append(self.proj_4) self.weight_list.append(self.proj_4)
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)): if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
mm_weight.set_config(self.config["mm_config"]) weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) weight.load(weight_dict)
if self.config["cpu_offload"]: if self.config["cpu_offload"]:
mm_weight.to_cpu() weight.to_cpu()
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)): if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
mm_weight.to_cpu() weight.to_cpu()
def to_cuda(self): def to_cuda(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)): if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
mm_weight.to_cuda() weight.to_cuda()
...@@ -81,31 +81,31 @@ class WanTransformerAttentionBlock: ...@@ -81,31 +81,31 @@ class WanTransformerAttentionBlock:
self.weight_list.append(self.cross_attn_v_img) self.weight_list.append(self.cross_attn_v_img)
self.weight_list.append(self.cross_attn_norm_k_img) self.weight_list.append(self.cross_attn_norm_k_img)
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"]) weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) weight.load(weight_dict)
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cpu() weight.to_cpu()
self.modulation = self.modulation.cpu() self.modulation = self.modulation.cpu()
def to_cuda(self): def to_cuda(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda() weight.to_cuda()
self.modulation = self.modulation.cuda() self.modulation = self.modulation.cuda()
def to_cpu_sync(self): def to_cpu_sync(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cpu(non_blocking=True) weight.to_cpu(non_blocking=True)
self.modulation = self.modulation.to("cpu", non_blocking=True) self.modulation = self.modulation.to("cpu", non_blocking=True)
def to_cuda_sync(self): def to_cuda_sync(self):
for mm_weight in self.weight_list: for weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)): if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda(non_blocking=True) weight.to_cuda(non_blocking=True)
self.modulation = self.modulation.cuda(non_blocking=True) self.modulation = self.modulation.cuda(non_blocking=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment