Unverified commit b50498fa, authored by Yang Yong (雍洋), committed by GitHub

Add lightx2v_platform (#541)

parent 31da6925
# --- package __init__: re-export the registered attention backends ---
from .flash_attn import *
from .sage_attn import *
# --- flash_attn.py: MLU flash-attention backend ---
import math

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None


@PLATFORM_ATTN_WEIGHT_REGISTER("mlu_flash_attn")
class MluFlashAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}
        assert tmo is not None, "torch_mlu_ops is not installed."

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        if len(q.shape) == 3:
            # packed layout (total_tokens, heads, head_dim): add a batch dim
            bs = 1
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
        else:
            raise ValueError(f"expected a 3D or 4D q, got shape {tuple(q.shape)}")
        softmax_scale = 1 / math.sqrt(q.shape[-1])
        x = tmo.flash_attention(
            q=q,
            k=k,
            v=v,
            cu_seq_lens_q=cu_seqlens_q,
            cu_seq_lens_kv=cu_seqlens_kv,
            max_seq_len_q=max_seqlen_q,
            max_seq_len_kv=max_seqlen_kv,
            softmax_scale=softmax_scale,
            return_lse=False,
            out_dtype=q.dtype,
            is_causal=False,
            out=None,
            alibi_slope=None,
            attn_bias=None,
        )
        # flatten the head dim: (bs * seq, heads * head_dim)
        x = x.reshape(bs * max_seqlen_q, -1)
        return x
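A minimal usage sketch for the backend above, assuming an MLU device with torch_mlu_ops installed. The packed (total_tokens, heads, head_dim) layout and the "mlu" device string are assumptions about the runtime, not part of this commit:

    import torch
    from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

    attn = PLATFORM_ATTN_WEIGHT_REGISTER["mlu_flash_attn"]()  # registered above

    seq_len, heads, head_dim = 128, 8, 64
    q = torch.randn(seq_len, heads, head_dim, dtype=torch.bfloat16, device="mlu")
    k, v = torch.randn_like(q), torch.randn_like(q)
    cu = torch.tensor([0, seq_len], dtype=torch.int32, device="mlu")  # cumulative lengths

    out = attn.apply(q, k, v, cu_seqlens_q=cu, cu_seqlens_kv=cu, max_seqlen_q=seq_len, max_seqlen_kv=seq_len)
    # apply() flattens the head dim: out has shape (seq_len, heads * head_dim)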
# --- sage_attn.py: MLU SageAttention backend ---
import math

import torch

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None


@PLATFORM_ATTN_WEIGHT_REGISTER("mlu_sage_attn")
class MluSageAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}
        assert tmo is not None, "torch_mlu_ops is not installed."

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        if len(q.shape) == 3:
            bs = 1
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
        else:
            raise ValueError(f"expected a 3D or 4D q, got shape {tuple(q.shape)}")
        softmax_scale = 1 / math.sqrt(q.shape[-1])
        # note: cu_seq_lens_* are deliberately passed as None here, unlike the
        # flash path above; only max_seq_len_* are forwarded to the kernel
        x = tmo.sage_attn(
            q=q,
            k=k,
            v=v,
            cu_seq_lens_q=None,
            cu_seq_lens_kv=None,
            max_seq_len_q=max_seqlen_q,
            max_seq_len_kv=max_seqlen_kv,
            is_causal=False,
            compute_dtype=torch.bfloat16,
            softmax_scale=softmax_scale,
        )
        x = x.reshape(bs * max_seqlen_q, -1)
        return x
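The Sage backend is called the same way as the flash backend; only the registry key differs, and the cu_seqlens arguments are accepted but ignored (they are passed to the kernel as None). Reusing the tensors from the flash sketch:

    attn = PLATFORM_ATTN_WEIGHT_REGISTER["mlu_sage_attn"]()
    out = attn.apply(q, k, v, max_seqlen_q=seq_len, max_seqlen_kv=seq_len)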
# --- lightx2v_platform/ops/attn/template.py: abstract base for attention backends ---
from abc import ABCMeta, abstractmethod


class AttnWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.config = {}

    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self, non_blocking=False):
        pass

    def to_cuda(self, non_blocking=False):
        pass

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        return {}
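To add a backend, subclass the template and register it. Below is a hypothetical CPU fallback built on torch's scaled_dot_product_attention; the "torch_sdpa" key and the class are illustrative, not part of this commit:

    import torch
    from lightx2v_platform.ops.attn.template import AttnWeightTemplate
    from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER


    @PLATFORM_ATTN_WEIGHT_REGISTER("torch_sdpa")  # hypothetical key
    class TorchSdpaAttnWeight(AttnWeightTemplate):
        def __init__(self):
            # the MLU backends above also skip super().__init__ and set config directly
            self.config = {}

        def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
            # (seq, heads, dim) -> (1, heads, seq, dim) for torch's SDPA
            q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            # back to (seq, heads * head_dim), matching the MLU backends
            return x.squeeze(0).transpose(0, 1).reshape(max_seqlen_q, -1)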
# --- MLU int8 matmul weight (registered as "int8-tmo") ---
from lightx2v_platform.ops.mm.template import MMWeightQuantTemplate
from lightx2v_platform.registry_factory import PLATFORM_MM_WEIGHT_REGISTER

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None


@PLATFORM_MM_WEIGHT_REGISTER("int8-tmo")
class MMWeightWint8channelAint8channeldynamicMlu(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Mlu

    Quant MM:
        Weight: int8, per-channel, symmetric
        Act:    int8, per-channel, symmetric, dynamic
        Kernel: mlu
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_tmo

    def act_quant_int8_perchannel_sym_tmo(self, x):
        # dynamic activation quantization on device
        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
        return input_tensor_quant, input_tensor_scale

    def apply(self, input_tensor):
        dtype = input_tensor.dtype
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = tmo.scaled_matmul(
            input_tensor_quant,
            self.weight.contiguous(),
            input_tensor_scale,
            self.weight_scale.squeeze(-1),
            bias=self.bias,
            output_dtype=dtype,
            use_hp_active=True,
        )
        return output_tensor
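For reference, here is the math the fused kernels are expected to implement, written in plain torch. This is an assumption about tmo.scaled_quantize / tmo.scaled_matmul, which are not documented in this commit: symmetric per-row int8 quantization of both operands, an int8 matmul, then rescaling by the two scale vectors.

    import torch


    def quant_int8_sym(t):
        # symmetric per-row quantization: one fp32 scale per row
        scale = t.abs().amax(dim=-1, keepdim=True).float() / 127.0
        q = (t.float() / scale).round().clamp(-127, 127).to(torch.int8)
        return q, scale


    x = torch.randn(4, 64, dtype=torch.bfloat16)   # activations (tokens, in)
    w = torch.randn(32, 64, dtype=torch.bfloat16)  # weight (out, in)

    xq, xs = quant_int8_sym(x)  # per-token scales, shape (4, 1)
    wq, ws = quant_int8_sym(w)  # per-out-channel scales, shape (32, 1)

    y = (xq.float() @ wq.float().t()) * (xs @ ws.t())  # dequantized, (4, 32)
    ref = x.float() @ w.float().t()
    print((y - ref).abs().max())  # small quantization error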
# --- MLU int8 quantized nn.Linear replacement ---
import torch
import torch.nn as nn

try:
    import torch_mlu_ops as tmo
except ImportError:
    tmo = None


class MluQuantLinearInt8(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        input_tensor_quant, input_tensor_scale = tmo.scaled_quantize(x)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        dtype = input_tensor.dtype
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = tmo.scaled_matmul(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale.squeeze(-1),
            output_dtype=dtype,
        )
        return output_tensor.unsqueeze(0)

    def _apply(self, fn):
        # custom _apply: replace a buffer only when fn() would put it on a
        # different device, instead of letting nn.Module rewrap everything
        for module in self.children():
            module._apply(fn)

        def maybe_cast(t):
            if t is not None and t.device != fn(t).device:
                return fn(t)
            return t

        self.weight = maybe_cast(self.weight)
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self
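Construction sketch for the module above. Building and loading state works on CPU, while forward() requires an MLU device plus torch_mlu_ops; the random weights here are placeholders:

    import torch

    layer = MluQuantLinearInt8(in_features=64, out_features=32, bias=True)
    layer.load_state_dict({
        "weight": torch.randint(-127, 128, (32, 64), dtype=torch.int8),
        "weight_scale": torch.rand(32, 1, dtype=torch.float32),
        "bias": torch.zeros(32, dtype=torch.bfloat16),
    })
    # layer.to("mlu") routes through the custom _apply() above, which moves
    # only the buffers whose device actually changes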
# --- lightx2v_platform/ops/mm/template.py: abstract bases for matmul weights ---
from abc import ABCMeta, abstractmethod

# NOTE: GET_DTYPE() is used below but was not imported in the original file;
# this import path is an assumption (lightx2v exposes it from its envs utils).
from lightx2v.utils.envs import GET_DTYPE


class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.create_cuda_buffer = create_cuda_buffer
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self):
        pass

    def set_config(self, config=None):
        # avoid a mutable default argument
        self.config = config if config is not None else {}

    def to_cuda(self, non_blocking=False):
        # upload from the pinned host buffers
        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_weight_scale"):
            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight"):
            # offload through the pinned buffers
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)


class MMWeightQuantTemplate(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
        # e.g. "blocks.0.ffn.0.weight" -> "blocks.0.ffn.0.weight_scale"
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.infer_dtype = GET_DTYPE()
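A minimal concrete subclass to show the contract the templates define. This is illustrative only; the real loaders, pinned-buffer setup, and load_int8_perchannel_sym live elsewhere in the tree:

    import torch


    class MMWeightBf16(MMWeightTemplate):
        # unquantized sketch: load() fills self.weight/self.bias from the
        # checkpoint dict, apply() runs the matmul
        def load(self, weight_dict):
            self.weight = weight_dict[self.weight_name]
            self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

        def apply(self, input_tensor):
            return torch.nn.functional.linear(input_tensor, self.weight, self.bias)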
# --- lightx2v_platform/registry_factory.py: name -> callable registries ---
class Register(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # all lookups go through _dict; the dict base is kept for compatibility
        self._dict = {}

    def __call__(self, target_or_name):
        # usable both as @register and as @register("name")
        if callable(target_or_name):
            return self.register(target_or_name)
        else:
            return lambda x: self.register(x, key=target_or_name)

    def register(self, target, key=None):
        if not callable(target):
            raise Exception(f"Error: {target} must be callable!")
        if key is None:
            key = target.__name__
        if key in self._dict:
            raise Exception(f"{key} already exists.")
        self[key] = target
        return target

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def __str__(self):
        return str(self._dict)

    def keys(self):
        return self._dict.keys()

    def values(self):
        return self._dict.values()

    def items(self):
        return self._dict.items()

    def get(self, key, default=None):
        return self._dict.get(key, default)

    def merge(self, other_register):
        for key, value in other_register.items():
            if key in self._dict:
                raise Exception(f"{key} already exists in target register.")
            self[key] = value


PLATFORM_DEVICE_REGISTER = Register()
PLATFORM_ATTN_WEIGHT_REGISTER = Register()
PLATFORM_MM_WEIGHT_REGISTER = Register()
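Register is a plain name-to-callable table with decorator sugar; it backs all three PLATFORM_*_REGISTER singletons above. A self-contained sketch (the keys and functions here are made up):

    reg = Register()


    @reg("relu")  # explicit key
    def _relu(x):
        return max(x, 0.0)


    @reg  # key defaults to the callable's __name__
    def identity(x):
        return x


    print("relu" in reg)            # True
    print(reg.get("identity")(3))   # 3
    print(list(reg.keys()))         # ['relu', 'identity']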
# --- platform bootstrap: pick the device backend from the environment ---
import os

from lightx2v_platform import *


def set_ai_device():
    platform = os.getenv("PLATFORM", "cuda")
    init_ai_device(platform)
    # imported after init so the value reflects the selected platform
    from lightx2v_platform.base.global_var import AI_DEVICE

    check_ai_device(AI_DEVICE)


set_ai_device()

from lightx2v_platform.ops import *  # noqa: E402
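Note that this bootstrap selects the backend from the PLATFORM variable (defaulting to "cuda"), while the standalone check script below reads AI_DEVICE; either way, the variable must be set before the module is first imported, since set_ai_device() runs at import time. For example:

    import os

    os.environ["PLATFORM"] = "mlu"  # equivalent to: PLATFORM=mlu python your_script.py
    # ... then import the bootstrap module above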
# --- standalone check: initialize and verify the configured device ---
import os

from lightx2v_platform import *

init_ai_device(os.getenv("AI_DEVICE", "cuda"))
from lightx2v_platform.base.global_var import AI_DEVICE  # noqa: E402

if __name__ == "__main__":
    print(f"AI_DEVICE : {AI_DEVICE}")
    is_available = check_ai_device(AI_DEVICE)
    print(f"Device available: {is_available}")