Unverified Commit 1d1608b8 authored by Yang Yong (雍洋)'s avatar Yang Yong (雍洋) Committed by GitHub
Browse files

Update lightx2v_platform (#559)

parent 74eeb429
...@@ -18,8 +18,8 @@ class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate): ...@@ -18,8 +18,8 @@ class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
Kernel: mlu Kernel: mlu
""" """
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
    """Matmul weight with int8 per-channel weights and dynamic int8 per-channel
    activations for the MLU kernel backend.

    All parameters are forwarded unchanged to MMWeightQuantTemplate.__init__;
    this subclass only selects the MLU-specific load / activation-quant hooks.
    """
    super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
    # Per-channel symmetric int8 loader for the quantized weight + scale.
    self.load_func = self.load_int8_perchannel_sym
    # NOTE(review): unlike the quant-template default (True), the MLU kernel
    # appears to consume the weight as stored — no transpose on load.
    self.weight_need_transpose = False
    # Dynamic per-channel symmetric int8 activation quantization ("tmo" variant).
    self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo
......
...@@ -2,12 +2,15 @@ from abc import ABCMeta, abstractmethod ...@@ -2,12 +2,15 @@ from abc import ABCMeta, abstractmethod
import torch import torch
from lightx2v_platform.base.global_var import AI_DEVICE
class MMWeightTemplate(metaclass=ABCMeta): class MMWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
    """Record matmul-weight configuration; no tensors are loaded here.

    Args:
        weight_name: key of the weight tensor in the checkpoint.
        bias_name: key of the bias tensor (or None-like when absent).
        create_cuda_buffer: whether to pre-allocate a device-side buffer.
        create_cpu_buffer: whether to pre-allocate a host-side buffer.
        lazy_load: defer tensor loading until first use.
        lazy_load_file: file handle/path used when lazy_load is set.
        is_post_adapter: marks weights belonging to a post-adapter.
    """
    # Plain attribute capture — subclasses decide how/when tensors are loaded.
    self.weight_name = weight_name
    self.bias_name = bias_name
    self.create_cuda_buffer = create_cuda_buffer
    self.create_cpu_buffer = create_cpu_buffer
    self.lazy_load = lazy_load
    self.lazy_load_file = lazy_load_file
    self.is_post_adapter = is_post_adapter
...@@ -25,11 +28,11 @@ class MMWeightTemplate(metaclass=ABCMeta): ...@@ -25,11 +28,11 @@ class MMWeightTemplate(metaclass=ABCMeta):
self.config = config self.config = config
def to_cuda(self, non_blocking=False):
    """Move the pinned host tensors onto the active device.

    Despite the historical name, the target is whatever device AI_DEVICE
    names (cuda / mlu / ...), keeping callers device-agnostic.

    Args:
        non_blocking: forwarded to Tensor.to for async host-to-device copies.
    """
    self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
    # Quantization scale only exists on quantized weights.
    if hasattr(self, "pin_weight_scale"):
        self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
    # Bias may be missing entirely or explicitly None; skip both cases.
    pinned_bias = getattr(self, "pin_bias", None)
    if pinned_bias is not None:
        self.bias = pinned_bias.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False): def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"): if hasattr(self, "pin_weight"):
...@@ -47,8 +50,8 @@ class MMWeightTemplate(metaclass=ABCMeta): ...@@ -47,8 +50,8 @@ class MMWeightTemplate(metaclass=ABCMeta):
class MMWeightQuantTemplate(MMWeightTemplate): class MMWeightQuantTemplate(MMWeightTemplate):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False): def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter) super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale" self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
self.load_func = None self.load_func = None
self.weight_need_transpose = True self.weight_need_transpose = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.