Unverified commit 1d1608b8, authored by Yang Yong (雍洋), committed by GitHub

Update lightx2v_platform (#559)

parent 74eeb429
@@ -18,8 +18,8 @@ class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
     Kernel: mlu
     """
 
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_int8_perchannel_sym
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo
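The hunk above threads the new `create_cpu_buffer` flag from the MLU subclass into the shared template. A minimal usage sketch; only the class name and signature come from the diff, the tensor names are invented for illustration:

```python
# Hypothetical example of constructing the MLU int8 weight wrapper after
# this change. Only the signature is taken from the diff; the weight and
# bias names below are made up.
mm_weight = MMWeightWint8channelAint8channeldynamicMlu(
    weight_name="blocks.0.attn.qkv.weight",  # invented example name
    bias_name="blocks.0.attn.qkv.bias",      # invented example name
    create_cuda_buffer=False,
    create_cpu_buffer=True,  # new flag introduced by this commit
)
```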
@@ -2,12 +2,15 @@
 from abc import ABCMeta, abstractmethod
 
 import torch
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
 class MMWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
         self.weight_name = weight_name
         self.bias_name = bias_name
         self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.is_post_adapter = is_post_adapter
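The newly imported `AI_DEVICE` global is what lets the template stop hard-coding CUDA. Its definition is not part of this diff, so the following is only a plausible sketch of how such a module-level device handle could be resolved:

```python
import torch

# Sketch only: lightx2v_platform.base.global_var is not shown in this
# diff, so this is one plausible way a global AI_DEVICE could be chosen.
# The real module may also probe non-CUDA backends such as MLU or NPU.
AI_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```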
@@ -25,11 +28,11 @@ class MMWeightTemplate(metaclass=ABCMeta):
         self.config = config
 
     def to_cuda(self, non_blocking=False):
-        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
         if hasattr(self, "pin_weight_scale"):
-            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
+            self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
         if hasattr(self, "pin_bias") and self.pin_bias is not None:
-            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
+            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
 
     def to_cpu(self, non_blocking=False):
         if hasattr(self, "pin_weight"):
@@ -47,8 +50,8 @@ class MMWeightTemplate(metaclass=ABCMeta):
 
 
 class MMWeightQuantTemplate(MMWeightTemplate):
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
         self.load_func = None
         self.weight_need_transpose = True
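`weight_scale_name` is derived by rewriting the `.weight` suffix of the weight's parameter name; note that `str.removesuffix` requires Python 3.9+. A tiny worked example (the parameter name is hypothetical):

```python
# str.removesuffix (Python 3.9+) strips ".weight" only if present.
weight_name = "blocks.0.ffn.w1.weight"  # invented example name
weight_scale_name = weight_name.removesuffix(".weight") + ".weight_scale"
print(weight_scale_name)  # -> blocks.0.ffn.w1.weight_scale
```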