Commit f05e915f authored by weishb's avatar weishb
Browse files

首次提交

parent 297bf637
from .full_attn import *
from .windowed_attn import *
from .modules import *
from typing import *
import torch
from .. import VarLenTensor
from .. import config
__all__ = [
'sparse_scaled_dot_product_attention',
]
@overload
def sparse_scaled_dot_product_attention(qkv: VarLenTensor) -> VarLenTensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
qkv (VarLenTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: VarLenTensor, kv: Union[VarLenTensor, torch.Tensor]) -> VarLenTensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (VarLenTensor): A [N, *, H, C] sparse tensor containing Qs.
kv (VarLenTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: VarLenTensor) -> torch.Tensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
kv (VarLenTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: VarLenTensor, k: VarLenTensor, v: VarLenTensor) -> VarLenTensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs.
k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks.
v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs.
Note:
k and v are assumed to have the same coordinate map.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: VarLenTensor, k: torch.Tensor, v: torch.Tensor) -> VarLenTensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Qs.
k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, k: VarLenTensor, v: VarLenTensor) -> torch.Tensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
k (VarLenTensor): A [N, *, H, Ci] sparse tensor containing Ks.
v (VarLenTensor): A [N, *, H, Co] sparse tensor containing Vs.
"""
...
def sparse_scaled_dot_product_attention(*args, **kwargs):
arg_names_dict = {
1: ['qkv'],
2: ['q', 'kv'],
3: ['q', 'k', 'v']
}
num_all_args = len(args) + len(kwargs)
assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
for key in arg_names_dict[num_all_args][len(args):]:
assert key in kwargs, f"Missing argument {key}"
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs['qkv']
assert isinstance(qkv, VarLenTensor), f"qkv must be a VarLenTensor, got {type(qkv)}"
assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
device = qkv.device
s = qkv
q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
kv_seqlen = q_seqlen
qkv = qkv.feats # [T, 3, H, C]
elif num_all_args == 2:
q = args[0] if len(args) > 0 else kwargs['q']
kv = args[1] if len(args) > 1 else kwargs['kv']
assert isinstance(q, VarLenTensor) and isinstance(kv, (VarLenTensor, torch.Tensor)) or \
isinstance(q, torch.Tensor) and isinstance(kv, VarLenTensor), \
f"Invalid types, got {type(q)} and {type(kv)}"
assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
device = q.device
if isinstance(q, VarLenTensor):
assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
s = q
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
q = q.feats # [T_Q, H, C]
else:
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
s = None
N, L, H, C = q.shape
q_seqlen = [L] * N
q = q.reshape(N * L, H, C) # [T_Q, H, C]
if isinstance(kv, VarLenTensor):
assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
kv = kv.feats # [T_KV, 2, H, C]
else:
assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
N, L, _, H, C = kv.shape
kv_seqlen = [L] * N
kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
elif num_all_args == 3:
q = args[0] if len(args) > 0 else kwargs['q']
k = args[1] if len(args) > 1 else kwargs['k']
v = args[2] if len(args) > 2 else kwargs['v']
assert isinstance(q, VarLenTensor) and isinstance(k, (VarLenTensor, torch.Tensor)) and type(k) == type(v) or \
isinstance(q, torch.Tensor) and isinstance(k, VarLenTensor) and isinstance(v, VarLenTensor), \
f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
device = q.device
if isinstance(q, VarLenTensor):
assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
s = q
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
q = q.feats # [T_Q, H, Ci]
else:
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
s = None
N, L, H, CI = q.shape
q_seqlen = [L] * N
q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
if isinstance(k, VarLenTensor):
assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
k = k.feats # [T_KV, H, Ci]
v = v.feats # [T_KV, H, Co]
else:
assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
N, L, H, CI, CO = *k.shape, v.shape[-1]
kv_seqlen = [L] * N
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
if config.ATTN == 'xformers':
if 'xops' not in globals():
import xformers.ops as xops
if num_all_args == 1:
q, k, v = qkv.unbind(dim=1)
elif num_all_args == 2:
k, v = kv.unbind(dim=1)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
out = xops.memory_efficient_attention(q, k, v, mask)[0]
elif config.ATTN == 'flash_attn':
if 'flash_attn' not in globals():
import flash_attn
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
if num_all_args in [2, 3]:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
if num_all_args == 1:
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
elif num_all_args == 2:
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif num_all_args == 3:
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif config.ATTN == 'flash_attn_3':
if 'flash_attn_3' not in globals():
import flash_attn_interface as flash_attn_3
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
if num_all_args == 1:
q, k, v = qkv.unbind(dim=1)
cu_seqlens_kv = cu_seqlens_q.clone()
max_q_seqlen = max_kv_seqlen = max(q_seqlen)
elif num_all_args == 2:
k, v = kv.unbind(dim=1)
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
max_q_seqlen = max(q_seqlen)
max_kv_seqlen = max(kv_seqlen)
elif num_all_args == 3:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
max_q_seqlen = max(q_seqlen)
max_kv_seqlen = max(kv_seqlen)
out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)
elif config.ATTN == 'sdpa':
import torch.nn.functional as F
# --- Step 1: Unpack q, k, v based on input format ---
if num_all_args == 1:
# Packed qkv format: [T, 3, H, C]
q, k, v = qkv.unbind(dim=1)
elif num_all_args == 2:
# q and packed kv format: [T_KV, 2, H, C]
k, v = kv.unbind(dim=1)
# q already set in arg parsing above
# else num_all_args == 3: q, k, v already set in arg parsing
# --- Step 2: Extract shapes ---
num_heads = q.shape[1]
head_dim = q.shape[2]
B = len(q_seqlen)
max_q_len = max(q_seqlen)
max_kv_len = max(kv_seqlen)
# --- Step 3: Pad to dense [B, N, H, C] ---
q_padded = torch.zeros(B, max_q_len, num_heads, head_dim,
device=device, dtype=q.dtype)
k_padded = torch.zeros(B, max_kv_len, num_heads, head_dim,
device=device, dtype=k.dtype)
v_padded = torch.zeros(B, max_kv_len, num_heads, head_dim,
device=device, dtype=v.dtype)
q_off, kv_off = 0, 0
for b in range(B):
ql, kvl = q_seqlen[b], kv_seqlen[b]
q_padded[b, :ql] = q[q_off:q_off + ql]
k_padded[b, :kvl] = k[kv_off:kv_off + kvl]
v_padded[b, :kvl] = v[kv_off:kv_off + kvl]
q_off += ql
kv_off += kvl
# --- Step 4: Transpose for SDPA [B, H, N, C] ---
q_t = q_padded.transpose(1, 2)
k_t = k_padded.transpose(1, 2)
v_t = v_padded.transpose(1, 2)
# --- Step 5: Create mask [B, 1, N_q, N_kv] ---
attn_mask = torch.zeros(B, max_q_len, max_kv_len,
dtype=torch.bool, device=device)
for b in range(B):
attn_mask[b, :q_seqlen[b], :kv_seqlen[b]] = True
attn_mask = attn_mask.unsqueeze(1)
# --- Step 6: Run SDPA ---
# Use torch.nn.attention.sdpa_kernel for newer PyTorch
try:
from torch.nn.attention import sdpa_kernel, SDPBackend
with sdpa_kernel([SDPBackend.MATH]):
out_t = F.scaled_dot_product_attention(
q_t, k_t, v_t,
attn_mask=attn_mask,
dropout_p=0.0,
is_causal=False,
)
except ImportError:
# Fallback for older PyTorch
with torch.backends.cuda.sdp_kernel(
enable_flash=False,
enable_math=True,
enable_mem_efficient=True,
):
out_t = F.scaled_dot_product_attention(
q_t, k_t, v_t,
attn_mask=attn_mask,
dropout_p=0.0,
is_causal=False,
)
# --- Step 7: Transpose and unpad ---
out_padded = out_t.transpose(1, 2) # [B, N, H, C]
out = torch.zeros(q.shape[0], num_heads, head_dim,
device=device, dtype=q.dtype)
q_off = 0
for b in range(B):
ql = q_seqlen[b]
out[q_off:q_off + ql] = out_padded[b, :ql]
q_off += ql
# NO EARLY RETURN HERE - let it fall through to common return
# The code below (outside all elif blocks) handles:
# if s is not None: return s.replace(out)
# else: return out.reshape(...)
else:
raise ValueError(f"Unknown attention module: {config.ATTN}")
if s is not None:
return s.replace(out)
else:
return out.reshape(N, L, H, -1)
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from .. import VarLenTensor, SparseTensor
from ..linear import rocm_safe_linear
from .full_attn import sparse_scaled_dot_product_attention
from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
from .rope import SparseRotaryPositionEmbedder
class SparseMultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, dim))
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
x_type = x.dtype
x = x.float()
if isinstance(x, VarLenTensor):
x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale)
else:
x = F.normalize(x, dim=-1) * self.gamma * self.scale
return x.to(x_type)
class SparseMultiHeadAttention(nn.Module):
def __init__(
self,
channels: int,
num_heads: int,
ctx_channels: Optional[int] = None,
type: Literal["self", "cross"] = "self",
attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
qkv_bias: bool = True,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False,
):
super().__init__()
assert channels % num_heads == 0
assert type in ["self", "cross"], f"Invalid attention type: {type}"
assert attn_mode in ["full", "windowed", "double_windowed"], f"Invalid attention mode: {attn_mode}"
assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
if attn_mode == 'double_windowed':
assert window_size % 2 == 0, "Window size must be even for double windowed attention"
assert num_heads % 2 == 0, "Number of heads must be even for double windowed attention"
self.channels = channels
self.head_dim = channels // num_heads
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads
self._type = type
self.attn_mode = attn_mode
self.window_size = window_size
self.shift_window = shift_window
self.use_rope = use_rope
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
else:
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
if self.qk_rms_norm:
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
self.to_out = nn.Linear(channels, channels)
if use_rope:
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq)
@staticmethod
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
if isinstance(x, VarLenTensor):
return x.replace(rocm_safe_linear(x.feats, module.weight, module.bias))
else:
return module(x)
@staticmethod
def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]:
if isinstance(x, VarLenTensor):
return x.reshape(*shape)
else:
return x.reshape(*x.shape[:2], *shape)
def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]:
if isinstance(x, VarLenTensor):
x_feats = x.feats.unsqueeze(0)
else:
x_feats = x
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
if self._type == "self":
qkv = self._linear(self.to_qkv, x)
qkv = self._fused_pre(qkv, num_fused=3)
if self.qk_rms_norm or self.use_rope:
q, k, v = qkv.unbind(dim=-3)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
q, k = self.rope(q, k)
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
if self.attn_mode == "full":
h = sparse_scaled_dot_product_attention(qkv)
elif self.attn_mode == "windowed":
h = sparse_windowed_scaled_dot_product_self_attention(
qkv, self.window_size, shift_window=self.shift_window
)
elif self.attn_mode == "double_windowed":
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
h0 = sparse_windowed_scaled_dot_product_self_attention(
qkv0, self.window_size, shift_window=(0, 0, 0)
)
h1 = sparse_windowed_scaled_dot_product_self_attention(
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
)
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
else:
q = self._linear(self.to_q, x)
q = self._reshape_chs(q, (self.num_heads, -1))
kv = self._linear(self.to_kv, context)
kv = self._fused_pre(kv, num_fused=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k, v = kv.unbind(dim=-3)
k = self.k_rms_norm(k)
h = sparse_scaled_dot_product_attention(q, k, v)
else:
h = sparse_scaled_dot_product_attention(q, kv)
h = self._reshape_chs(h, (-1,))
h = self._linear(self.to_out, h)
return h
from typing import *
import torch
import torch.nn as nn
from ..basic import SparseTensor
class SparseRotaryPositionEmbedder(nn.Module):
def __init__(
self,
head_dim: int,
dim: int = 3,
rope_freq: Tuple[float, float] = (1.0, 10000.0)
):
super().__init__()
assert head_dim % 2 == 0, "Head dim must be divisible by 2"
self.head_dim = head_dim
self.dim = dim
self.rope_freq = rope_freq
self.freq_dim = head_dim // 2 // dim
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
self.freqs = self.freqs.to(indices.device)
phases = torch.outer(indices, self.freqs)
phases = torch.polar(torch.ones_like(phases), phases)
return phases
def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_rotated = x_complex * phases.unsqueeze(-2)
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
return x_embed
def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
q (SparseTensor): [..., N, H, D] tensor of queries
k (SparseTensor): [..., N, H, D] tensor of keys
"""
assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1"
phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}'
phases = q.get_spatial_cache(phases_cache_name)
if phases is None:
coords = q.coords[..., 1:]
phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1)
if phases.shape[-1] < self.head_dim // 2:
padn = self.head_dim // 2 - phases.shape[-1]
phases = torch.cat([phases, torch.polar(
torch.ones(*phases.shape[:-1], padn, device=phases.device),
torch.zeros(*phases.shape[:-1], padn, device=phases.device)
)], dim=-1)
q.register_spatial_cache(phases_cache_name, phases)
q_embed = q.replace(self._rotary_embedding(q.feats, phases))
if k is None:
return q_embed
k_embed = k.replace(self._rotary_embedding(k.feats, phases))
return q_embed, k_embed
\ No newline at end of file
from typing import *
import torch
import math
from .. import SparseTensor
from .. import config
__all__ = [
'sparse_windowed_scaled_dot_product_self_attention',
'sparse_windowed_scaled_dot_product_cross_attention',
]
def calc_window_partition(
tensor: SparseTensor,
window_size: Union[int, Tuple[int, ...]],
shift_window: Union[int, Tuple[int, ...]] = 0,
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
"""
Calculate serialization and partitioning for a set of coordinates.
Args:
tensor (SparseTensor): The input tensor.
window_size (int): The window size to use.
shift_window (Tuple[int, ...]): The shift of serialized coordinates.
Returns:
(torch.Tensor): Forwards indices.
(torch.Tensor): Backwards indices.
(torch.Tensor): Sequence lengths.
(dict): Attn func args.
"""
DIM = tensor.coords.shape[1] - 1
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
shifted_coords = tensor.coords.clone().detach()
shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
fwd_indices = torch.argsort(shifted_indices)
bwd_indices = torch.empty_like(fwd_indices)
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
seq_lens = torch.bincount(shifted_indices)
mask = seq_lens != 0
seq_lens = seq_lens[mask]
if config.ATTN == 'xformers':
if 'xops' not in globals():
import xformers.ops as xops
attn_func_args = {
'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
}
elif config.ATTN == 'flash_attn':
attn_func_args = {
'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
'max_seqlen': torch.max(seq_lens)
}
return fwd_indices, bwd_indices, seq_lens, attn_func_args
def sparse_windowed_scaled_dot_product_self_attention(
qkv: SparseTensor,
window_size: int,
shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
"""
Apply windowed scaled dot product self attention to a sparse tensor.
Args:
qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
window_size (int): The window size to use.
shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
Returns:
(SparseTensor): [N, *, H, C] sparse tensor containing the output features.
"""
assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
if serialization_spatial_cache is None:
fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
else:
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
if config.DEBUG:
start = 0
qkv_coords = qkv.coords[fwd_indices]
for i in range(len(seq_lens)):
seq_coords = qkv_coords[start:start+seq_lens[i]]
assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
start += seq_lens[i]
if config.ATTN == 'xformers':
if 'xops' not in globals():
import xformers.ops as xops
q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
q = q.unsqueeze(0) # [1, M, H, C]
k = k.unsqueeze(0) # [1, M, H, C]
v = v.unsqueeze(0) # [1, M, H, C]
out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
elif config.ATTN == 'flash_attn':
if 'flash_attn' not in globals():
import flash_attn
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
out = out[bwd_indices] # [T, H, C]
if config.DEBUG:
qkv_coords = qkv_coords[bwd_indices]
assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
return qkv.replace(out)
def sparse_windowed_scaled_dot_product_cross_attention(
q: SparseTensor,
kv: SparseTensor,
q_window_size: int,
kv_window_size: int,
q_shift_window: Tuple[int, int, int] = (0, 0, 0),
kv_shift_window: Tuple[int, int, int] = (0, 0, 0),
) -> SparseTensor:
"""
Apply windowed scaled dot product cross attention to two sparse tensors.
Args:
q (SparseTensor): [N, *, H, C] sparse tensor containing Qs.
kv (SparseTensor): [N, *, 2, H, C] sparse tensor containing Ks and Vs.
q_window_size (int): The window size to use for Qs.
kv_window_size (int): The window size to use for Ks and Vs.
q_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Qs.
kv_shift_window (Tuple[int, int, int]): The shift of serialized coordinates for Ks and Vs.
Returns:
(SparseTensor): [N, *, H, C] sparse tensor containing the output features.
"""
assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
q_serialization_spatial_cache_name = f'windowed_attention_{q_window_size}_{q_shift_window}'
q_serialization_spatial_cache = q.get_spatial_cache(q_serialization_spatial_cache_name)
if q_serialization_spatial_cache is None:
q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = calc_window_partition(q, q_window_size, q_shift_window)
q.register_spatial_cache(q_serialization_spatial_cache_name, (q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args))
else:
q_fwd_indices, q_bwd_indices, q_seq_lens, q_attn_func_args = q_serialization_spatial_cache
kv_serialization_spatial_cache_name = f'windowed_attention_{kv_window_size}_{kv_shift_window}'
kv_serialization_spatial_cache = kv.get_spatial_cache(kv_serialization_spatial_cache_name)
if kv_serialization_spatial_cache is None:
kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = calc_window_partition(kv, kv_window_size, kv_shift_window)
kv.register_spatial_cache(kv_serialization_spatial_cache_name, (kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args))
else:
kv_fwd_indices, kv_bwd_indices, kv_seq_lens, kv_attn_func_args = kv_serialization_spatial_cache
assert len(q_seq_lens) == len(kv_seq_lens), "Number of sequences in q and kv must match"
q_feats = q.feats[q_fwd_indices] # [M, H, C]
kv_feats = kv.feats[kv_fwd_indices] # [M, 2, H, C]
if config.ATTN == 'xformers':
if 'xops' not in globals():
import xformers.ops as xops
k, v = kv_feats.unbind(dim=1) # [M, H, C]
q = q.unsqueeze(0) # [1, M, H, C]
k = k.unsqueeze(0) # [1, M, H, C]
v = v.unsqueeze(0) # [1, M, H, C]
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seq_lens, kv_seq_lens)
out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)[0] # [M, H, C]
elif config.ATTN == 'flash_attn':
if 'flash_attn' not in globals():
import flash_attn
out = flash_attn.flash_attn_varlen_kvpacked_func(q_feats, kv_feats,
cu_seqlens_q=q_attn_func_args['cu_seqlens'], cu_seqlens_k=kv_attn_func_args['cu_seqlens'],
max_seqlen_q=q_attn_func_args['max_seqlen'], max_seqlen_k=kv_attn_func_args['max_seqlen'],
) # [M, H, C]
out = out[q_bwd_indices] # [T, H, C]
return q.replace(out)
from typing import *
from fractions import Fraction
import torch
from . import config
__all__ = [
'VarLenTensor',
'varlen_cat',
'varlen_unbind',
'SparseTensor',
'sparse_cat',
'sparse_unbind',
]
class VarLenTensor:
"""
Sequential tensor with variable length.
Args:
feats (torch.Tensor): Features of the varlen tensor.
layout (List[slice]): Layout of the varlen tensor for each batch
"""
def __init__(self, feats: torch.Tensor, layout: List[slice]=None):
self.feats = feats
self.layout = layout if layout is not None else [slice(0, feats.shape[0])]
self._cache = {}
@staticmethod
def layout_from_seqlen(seqlen: list) -> List[slice]:
"""
Create a layout from a tensor of sequence lengths.
"""
layout = []
start = 0
for l in seqlen:
layout.append(slice(start, start + l))
start += l
return layout
@staticmethod
def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor':
"""
Create a VarLenTensor from a list of tensors.
"""
feats = torch.cat(tensor_list, dim=0)
layout = []
start = 0
for tensor in tensor_list:
layout.append(slice(start, start + tensor.shape[0]))
start += tensor.shape[0]
return VarLenTensor(feats, layout)
def to_tensor_list(self) -> List[torch.Tensor]:
"""
Convert a VarLenTensor to a list of tensors.
"""
tensor_list = []
for s in self.layout:
tensor_list.append(self.feats[s])
return tensor_list
def __len__(self) -> int:
return len(self.layout)
@property
def shape(self) -> torch.Size:
return torch.Size([len(self.layout), *self.feats.shape[1:]])
def dim(self) -> int:
return len(self.shape)
@property
def ndim(self) -> int:
return self.dim()
@property
def dtype(self):
return self.feats.dtype
@property
def device(self):
return self.feats.device
@property
def seqlen(self) -> torch.LongTensor:
if 'seqlen' not in self._cache:
self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
return self._cache['seqlen']
@property
def cum_seqlen(self) -> torch.LongTensor:
if 'cum_seqlen' not in self._cache:
self._cache['cum_seqlen'] = torch.cat([
torch.tensor([0], dtype=torch.long, device=self.device),
self.seqlen.cumsum(dim=0)
], dim=0)
return self._cache['cum_seqlen']
@property
def batch_boardcast_map(self) -> torch.LongTensor:
"""
Get the broadcast map for the varlen tensor.
"""
if 'batch_boardcast_map' not in self._cache:
self._cache['batch_boardcast_map'] = torch.repeat_interleave(
torch.arange(len(self.layout), device=self.device),
self.seqlen,
)
return self._cache['batch_boardcast_map']
@overload
def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...
@overload
def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ...
def to(self, *args, **kwargs) -> 'VarLenTensor':
device = None
dtype = None
if len(args) == 2:
device, dtype = args
elif len(args) == 1:
if isinstance(args[0], torch.dtype):
dtype = args[0]
else:
device = args[0]
if 'dtype' in kwargs:
assert dtype is None, "to() received multiple values for argument 'dtype'"
dtype = kwargs['dtype']
if 'device' in kwargs:
assert device is None, "to() received multiple values for argument 'device'"
device = kwargs['device']
non_blocking = kwargs.get('non_blocking', False)
copy = kwargs.get('copy', False)
new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
return self.replace(new_feats)
def type(self, dtype):
new_feats = self.feats.type(dtype)
return self.replace(new_feats)
def cpu(self) -> 'VarLenTensor':
new_feats = self.feats.cpu()
return self.replace(new_feats)
def cuda(self) -> 'VarLenTensor':
new_feats = self.feats.cuda()
return self.replace(new_feats)
def half(self) -> 'VarLenTensor':
new_feats = self.feats.half()
return self.replace(new_feats)
def float(self) -> 'VarLenTensor':
new_feats = self.feats.float()
return self.replace(new_feats)
def detach(self) -> 'VarLenTensor':
new_feats = self.feats.detach()
return self.replace(new_feats)
def reshape(self, *shape) -> 'VarLenTensor':
new_feats = self.feats.reshape(self.feats.shape[0], *shape)
return self.replace(new_feats)
def unbind(self, dim: int) -> List['VarLenTensor']:
return varlen_unbind(self, dim)
def replace(self, feats: torch.Tensor) -> 'VarLenTensor':
new_tensor = VarLenTensor(
feats=feats,
layout=self.layout,
)
new_tensor._cache = self._cache
return new_tensor
def to_dense(self, max_length=None) -> torch.Tensor:
"""
Convert a VarLenTensor to a dense representation without for-loop.
Returns:
dense (torch.Tensor): (N, L, C) dense tensor
mask (torch.BoolTensor): (N, L) mask indicating valid positions
"""
N = len(self)
L = max_length or self.seqlen.max().item()
spatial = self.feats.shape[1:]
idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L)
mask = (idx < self.seqlen.unsqueeze(1))
mapping = mask.reshape(-1).cumsum(dim=0) - 1
dense = self.feats[mapping]
dense = dense.reshape(N, L, *spatial)
return dense, mask
def __neg__(self) -> 'VarLenTensor':
return self.replace(-self.feats)
def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor':
if isinstance(other, torch.Tensor):
try:
other = torch.broadcast_to(other, self.shape)
other = other[self.batch_boardcast_map]
except:
pass
if isinstance(other, VarLenTensor):
other = other.feats
new_feats = op(self.feats, other)
new_tensor = self.replace(new_feats)
return new_tensor
def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
return self.__elemwise__(other, torch.add)
def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
return self.__elemwise__(other, torch.add)
def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
return self.__elemwise__(other, torch.sub)
def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
return self.__elemwise__(other, torch.mul)
def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
return self.__elemwise__(other, torch.mul)
def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
return self.__elemwise__(other, torch.div)
def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor':
return self.__elemwise__(other, lambda x, y: torch.div(y, x))
def __getitem__(self, idx):
if isinstance(idx, int):
idx = [idx]
elif isinstance(idx, slice):
idx = range(*idx.indices(self.shape[0]))
elif isinstance(idx, list):
assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
elif isinstance(idx, torch.Tensor):
if idx.dtype == torch.bool:
assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
idx = idx.nonzero().squeeze(1)
elif idx.dtype in [torch.int32, torch.int64]:
assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
else:
raise ValueError(f"Unknown index type: {idx.dtype}")
else:
raise ValueError(f"Unknown index type: {type(idx)}")
new_feats = []
new_layout = []
start = 0
for new_idx, old_idx in enumerate(idx):
new_feats.append(self.feats[self.layout[old_idx]])
new_layout.append(slice(start, start + len(new_feats[-1])))
start += len(new_feats[-1])
new_feats = torch.cat(new_feats, dim=0).contiguous()
new_tensor = VarLenTensor(feats=new_feats, layout=new_layout)
return new_tensor
def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
if isinstance(dim, int):
dim = (dim,)
if op =='mean':
red = self.feats.mean(dim=dim, keepdim=keepdim)
elif op =='sum':
red = self.feats.sum(dim=dim, keepdim=keepdim)
elif op == 'prod':
red = self.feats.prod(dim=dim, keepdim=keepdim)
else:
raise ValueError(f"Unsupported reduce operation: {op}")
if dim is None or 0 in dim:
return red
red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen)
return red
def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
return self.reduce(op='mean', dim=dim, keepdim=keepdim)
def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
return self.reduce(op='sum', dim=dim, keepdim=keepdim)
def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
return self.reduce(op='prod', dim=dim, keepdim=keepdim)
def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor:
mean = self.mean(dim=dim, keepdim=True)
mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True)
std = (mean2 - mean ** 2).sqrt()
return std
def __repr__(self) -> str:
return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor:
"""
Concatenate a list of varlen tensors.
Args:
inputs (List[VarLenTensor]): List of varlen tensors to concatenate.
"""
if dim == 0:
new_feats = torch.cat([input.feats for input in inputs], dim=0)
start = 0
new_layout = []
for input in inputs:
for l in input.layout:
new_layout.append(slice(start, start + l.stop - l.start))
start += l.stop - l.start
output = VarLenTensor(feats=new_feats, layout=new_layout)
else:
feats = torch.cat([input.feats for input in inputs], dim=dim)
output = inputs[0].replace(feats)
return output
def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]:
"""
Unbind a varlen tensor along a dimension.
Args:
input (VarLenTensor): Varlen tensor to unbind.
dim (int): Dimension to unbind.
"""
if dim == 0:
return [input[i] for i in range(len(input))]
else:
feats = input.feats.unbind(dim)
return [input.replace(f) for f in feats]
class SparseTensor(VarLenTensor):
"""
Sparse tensor with support for both torchsparse and spconv backends.
Parameters:
- feats (torch.Tensor): Features of the sparse tensor.
- coords (torch.Tensor): Coordinates of the sparse tensor.
- shape (torch.Size): Shape of the sparse tensor.
- layout (List[slice]): Layout of the sparse tensor for each batch
- data (SparseTensorData): Sparse tensor data used for convolusion
NOTE:
- Data corresponding to a same batch should be contiguous.
- Coords should be in [0, 1023]
"""
SparseTensorData = None
@overload
def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ...
@overload
def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ...
def __init__(self, *args, **kwargs):
# Lazy import of sparse tensor backend
if self.SparseTensorData is None:
import importlib
if config.CONV == 'torchsparse':
self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor
elif config.CONV == 'spconv':
self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
method_id = 0
if len(args) != 0:
method_id = 0 if isinstance(args[0], torch.Tensor) else 1
else:
method_id = 1 if 'data' in kwargs else 0
if method_id == 0:
feats, coords, shape = args + (None,) * (3 - len(args))
if 'feats' in kwargs:
feats = kwargs['feats']
del kwargs['feats']
if 'coords' in kwargs:
coords = kwargs['coords']
del kwargs['coords']
if 'shape' in kwargs:
shape = kwargs['shape']
del kwargs['shape']
if config.CONV == 'torchsparse':
self.data = self.SparseTensorData(feats, coords, **kwargs)
elif config.CONV == 'spconv':
spatial_shape = list(coords.max(0)[0] + 1)
self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs)
self.data._features = feats
else:
self.data = {
'feats': feats,
'coords': coords,
}
elif method_id == 1:
data, shape = args + (None,) * (2 - len(args))
if 'data' in kwargs:
data = kwargs['data']
del kwargs['data']
if 'shape' in kwargs:
shape = kwargs['shape']
del kwargs['shape']
self.data = data
self._shape = shape
self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1)))
self._spatial_cache = kwargs.get('spatial_cache', {})
if config.DEBUG:
try:
assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
for i in range(self.shape[0]):
assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
except Exception as e:
print('Debugging information:')
print(f"- Shape: {self.shape}")
print(f"- Layout: {self.layout}")
print(f"- Scale: {self._scale}")
print(f"- Coords: {self.coords}")
raise e
@staticmethod
def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor':
"""
Create a SparseTensor from a list of tensors.
"""
feats = torch.cat(feats_list, dim=0)
coords = []
for i, coord in enumerate(coords_list):
coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1)
coords.append(coord)
coords = torch.cat(coords, dim=0)
return SparseTensor(feats, coords)
def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Convert a SparseTensor to list of tensors.
"""
feats_list = []
coords_list = []
for s in self.layout:
feats_list.append(self.feats[s])
coords_list.append(self.coords[s])
return feats_list, coords_list
def __len__(self) -> int:
return len(self.layout)
def __cal_shape(self, feats, coords):
shape = []
shape.append(coords[:, 0].contiguous().max().item() + 1)
shape.extend([*feats.shape[1:]])
return torch.Size(shape)
def __cal_layout(self, coords, batch_size):
seq_len = torch.bincount(coords[:, 0].contiguous(), minlength=batch_size)
offset = torch.cumsum(seq_len, dim=0)
layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
return layout
def __cal_spatial_shape(self, coords):
return torch.Size((coords[:, 1:].contiguous().max(0)[0] + 1).tolist())
@property
def shape(self) -> torch.Size:
if self._shape is None:
self._shape = self.__cal_shape(self.feats, self.coords)
return self._shape
@property
def layout(self) -> List[slice]:
layout = self.get_spatial_cache('layout')
if layout is None:
layout = self.__cal_layout(self.coords, self.shape[0])
self.register_spatial_cache('layout', layout)
return layout
@property
def spatial_shape(self) -> torch.Size:
spatial_shape = self.get_spatial_cache('shape')
if spatial_shape is None:
spatial_shape = self.__cal_spatial_shape(self.coords)
self.register_spatial_cache('shape', spatial_shape)
return spatial_shape
@property
def feats(self) -> torch.Tensor:
if config.CONV == 'torchsparse':
return self.data.F
elif config.CONV == 'spconv':
return self.data.features
else:
return self.data['feats']
@feats.setter
def feats(self, value: torch.Tensor):
if config.CONV == 'torchsparse':
self.data.F = value
elif config.CONV == 'spconv':
self.data.features = value
else:
self.data['feats'] = value
@property
def coords(self) -> torch.Tensor:
if config.CONV == 'torchsparse':
return self.data.C
elif config.CONV == 'spconv':
return self.data.indices
else:
return self.data['coords']
@coords.setter
def coords(self, value: torch.Tensor):
if config.CONV == 'torchsparse':
self.data.C = value
elif config.CONV == 'spconv':
self.data.indices = value
else:
self.data['coords'] = value
@property
def dtype(self):
return self.feats.dtype
@property
def device(self):
return self.feats.device
@property
def seqlen(self) -> torch.LongTensor:
seqlen = self.get_spatial_cache('seqlen')
if seqlen is None:
seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device)
self.register_spatial_cache('seqlen', seqlen)
return seqlen
@property
def cum_seqlen(self) -> torch.LongTensor:
cum_seqlen = self.get_spatial_cache('cum_seqlen')
if cum_seqlen is None:
cum_seqlen = torch.cat([
torch.tensor([0], dtype=torch.long, device=self.device),
self.seqlen.cumsum(dim=0)
], dim=0)
self.register_spatial_cache('cum_seqlen', cum_seqlen)
return cum_seqlen
@property
def batch_boardcast_map(self) -> torch.LongTensor:
"""
Get the broadcast map for the varlen tensor.
"""
batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map')
if batch_boardcast_map is None:
batch_boardcast_map = torch.repeat_interleave(
torch.arange(len(self.layout), device=self.device),
self.seqlen,
)
self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map)
return batch_boardcast_map
@overload
def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...
@overload
def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ...
def to(self, *args, **kwargs) -> 'SparseTensor':
device = None
dtype = None
if len(args) == 2:
device, dtype = args
elif len(args) == 1:
if isinstance(args[0], torch.dtype):
dtype = args[0]
else:
device = args[0]
if 'dtype' in kwargs:
assert dtype is None, "to() received multiple values for argument 'dtype'"
dtype = kwargs['dtype']
if 'device' in kwargs:
assert device is None, "to() received multiple values for argument 'device'"
device = kwargs['device']
non_blocking = kwargs.get('non_blocking', False)
copy = kwargs.get('copy', False)
new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy)
new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy)
return self.replace(new_feats, new_coords)
def type(self, dtype):
new_feats = self.feats.type(dtype)
return self.replace(new_feats)
def cpu(self) -> 'SparseTensor':
new_feats = self.feats.cpu()
new_coords = self.coords.cpu()
return self.replace(new_feats, new_coords)
def cuda(self) -> 'SparseTensor':
new_feats = self.feats.cuda()
new_coords = self.coords.cuda()
return self.replace(new_feats, new_coords)
def half(self) -> 'SparseTensor':
new_feats = self.feats.half()
return self.replace(new_feats)
def float(self) -> 'SparseTensor':
new_feats = self.feats.float()
return self.replace(new_feats)
def detach(self) -> 'SparseTensor':
new_coords = self.coords.detach()
new_feats = self.feats.detach()
return self.replace(new_feats, new_coords)
def reshape(self, *shape) -> 'SparseTensor':
new_feats = self.feats.reshape(self.feats.shape[0], *shape)
return self.replace(new_feats)
def unbind(self, dim: int) -> List['SparseTensor']:
return sparse_unbind(self, dim)
def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
if config.CONV == 'torchsparse':
new_data = self.SparseTensorData(
feats=feats,
coords=self.data.coords if coords is None else coords,
stride=self.data.stride,
spatial_range=self.data.spatial_range,
)
new_data._caches = self.data._caches
elif config.CONV == 'spconv':
new_data = self.SparseTensorData(
self.data.features.reshape(self.data.features.shape[0], -1),
self.data.indices,
self.data.spatial_shape,
self.data.batch_size,
self.data.grid,
self.data.voxel_num,
self.data.indice_dict
)
new_data._features = feats
new_data.benchmark = self.data.benchmark
new_data.benchmark_record = self.data.benchmark_record
new_data.thrust_allocator = self.data.thrust_allocator
new_data._timer = self.data._timer
new_data.force_algo = self.data.force_algo
new_data.int8_scale = self.data.int8_scale
if coords is not None:
new_data.indices = coords
else:
new_data = {
'feats': feats,
'coords': self.data['coords'] if coords is None else coords,
}
new_tensor = SparseTensor(
new_data,
shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None,
scale=self._scale,
spatial_cache=self._spatial_cache
)
return new_tensor
def to_dense(self) -> torch.Tensor:
if config.CONV == 'torchsparse':
return self.data.dense()
elif config.CONV == 'spconv':
return self.data.dense()
else:
spatial_shape = self.spatial_shape
ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device)
idx = [self.coords[:, 0].contiguous(), slice(None)] + self.coords[:, 1:].contiguous().unbind(1)
ret[tuple(idx)] = self.feats
return ret
@staticmethod
def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
N, C = dim
x = torch.arange(aabb[0], aabb[3] + 1)
y = torch.arange(aabb[1], aabb[4] + 1)
z = torch.arange(aabb[2], aabb[5] + 1)
coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
coords = torch.cat([
torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
coords.repeat(N, 1),
], dim=1).to(dtype=torch.int32, device=device)
feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
return SparseTensor(feats=feats, coords=coords)
def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
new_cache = {}
for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
if k in self._spatial_cache:
new_cache[k] = self._spatial_cache[k]
if k in other._spatial_cache:
if k not in new_cache:
new_cache[k] = other._spatial_cache[k]
else:
new_cache[k].update(other._spatial_cache[k])
return new_cache
def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor':
if isinstance(other, torch.Tensor):
try:
other = torch.broadcast_to(other, self.shape)
other = other[self.batch_boardcast_map]
except:
pass
if isinstance(other, VarLenTensor):
other = other.feats
new_feats = op(self.feats, other)
new_tensor = self.replace(new_feats)
if isinstance(other, SparseTensor):
new_tensor._spatial_cache = self.__merge_sparse_cache(other)
return new_tensor
def __getitem__(self, idx):
if isinstance(idx, int):
idx = [idx]
elif isinstance(idx, slice):
idx = range(*idx.indices(self.shape[0]))
elif isinstance(idx, list):
assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}"
elif isinstance(idx, torch.Tensor):
if idx.dtype == torch.bool:
assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
idx = idx.nonzero().squeeze(1)
elif idx.dtype in [torch.int32, torch.int64]:
assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
else:
raise ValueError(f"Unknown index type: {idx.dtype}")
else:
raise ValueError(f"Unknown index type: {type(idx)}")
new_coords = []
new_feats = []
new_layout = []
new_shape = torch.Size([len(idx)] + list(self.shape[1:]))
start = 0
for new_idx, old_idx in enumerate(idx):
new_coords.append(self.coords[self.layout[old_idx]].clone())
new_coords[-1][:, 0] = new_idx
new_feats.append(self.feats[self.layout[old_idx]])
new_layout.append(slice(start, start + len(new_coords[-1])))
start += len(new_coords[-1])
new_coords = torch.cat(new_coords, dim=0).contiguous()
new_feats = torch.cat(new_feats, dim=0).contiguous()
new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape)
new_tensor.register_spatial_cache('layout', new_layout)
return new_tensor
def clear_spatial_cache(self) -> None:
"""
Clear all spatial caches.
"""
self._spatial_cache = {}
def register_spatial_cache(self, key, value) -> None:
"""
Register a spatial cache.
The spatial cache can be any thing you want to cache.
The registery and retrieval of the cache is based on current scale.
"""
scale_key = str(self._scale)
if scale_key not in self._spatial_cache:
self._spatial_cache[scale_key] = {}
self._spatial_cache[scale_key][key] = value
def get_spatial_cache(self, key=None):
"""
Get a spatial cache.
"""
scale_key = str(self._scale)
cur_scale_cache = self._spatial_cache.get(scale_key, {})
if key is None:
return cur_scale_cache
return cur_scale_cache.get(key, None)
def __repr__(self) -> str:
return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})"
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
"""
Concatenate a list of sparse tensors.
Args:
inputs (List[SparseTensor]): List of sparse tensors to concatenate.
"""
if dim == 0:
start = 0
coords = []
for input in inputs:
coords.append(input.coords.clone())
coords[-1][:, 0] += start
start += input.shape[0]
coords = torch.cat(coords, dim=0)
feats = torch.cat([input.feats for input in inputs], dim=0)
output = SparseTensor(
coords=coords,
feats=feats,
)
else:
feats = torch.cat([input.feats for input in inputs], dim=dim)
output = inputs[0].replace(feats)
return output
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
"""
Unbind a sparse tensor along a dimension.
Args:
input (SparseTensor): Sparse tensor to unbind.
dim (int): Dimension to unbind.
"""
if dim == 0:
return [input[i] for i in range(input.shape[0])]
else:
feats = input.feats.unbind(dim)
return [input.replace(f) for f in feats]
from typing import *
CONV = 'flex_gemm'
DEBUG = False
ATTN = 'flash_attn'
# ROCm GFX1201 workaround: when True, use chunked explicit-GEMM (im2col + torch.mm) instead
# of flex_gemm Triton kernels for any sparse conv where N > ROCM_SAFE_CHUNK.
# Set ROCM_SAFE_SPCONV=1 in env to enable, or call set_rocm_safe_spconv(True).
ROCM_SAFE_SPCONV = False
def __from_env():
import os
global CONV
global DEBUG
global ATTN
global ROCM_SAFE_SPCONV
env_sparse_conv_backend = os.environ.get('SPARSE_CONV_BACKEND')
env_sparse_debug = os.environ.get('SPARSE_DEBUG')
env_sparse_attn_backend = os.environ.get('SPARSE_ATTN_BACKEND')
if env_sparse_attn_backend is None:
env_sparse_attn_backend = os.environ.get('ATTN_BACKEND')
if env_sparse_conv_backend is not None and env_sparse_conv_backend in ['none', 'spconv', 'torchsparse', 'flex_gemm']:
CONV = env_sparse_conv_backend
if env_sparse_debug is not None:
DEBUG = env_sparse_debug == '1'
if env_sparse_attn_backend is not None and env_sparse_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'sdpa']:
ATTN = env_sparse_attn_backend
if os.environ.get('ROCM_SAFE_SPCONV') == '1':
ROCM_SAFE_SPCONV = True
print(f"[SPARSE] Conv backend: {CONV}; Attention backend: {ATTN}; ROCM_SAFE_SPCONV: {ROCM_SAFE_SPCONV}")
__from_env()
def set_conv_backend(backend: Literal['none', 'spconv', 'torchsparse', 'flex_gemm']):
global CONV
CONV = backend
def set_debug(debug: bool):
global DEBUG
DEBUG = debug
def set_attn_backend(backend: Literal['xformers', 'flash_attn', 'sdpa']):
global ATTN
ATTN = backend
def set_rocm_safe_spconv(enabled: bool):
global ROCM_SAFE_SPCONV
ROCM_SAFE_SPCONV = enabled
from .conv import SparseConv3d, SparseInverseConv3d
from . import config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment