"PyTorch/NLP/vscode:/vscode.git/clone" did not exist on "c056df7823f6c8d5ad4234870596871f9ece9df1"
Commit f05e915f authored by weishb's avatar weishb
Browse files

首次提交

parent 297bf637
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
FLEX_GEMM_ALGO = 'masked_implicit_gemm' # 'explicit_gemm', 'implicit_gemm', 'implicit_gemm_splitk', 'masked_implicit_gemm', 'masked_implicit_gemm_splitk'
FLEX_GEMM_HASHMAP_RATIO = 2.0 # Ratio of hashmap size to input size
from .. import config
import importlib
import torch
import torch.nn as nn
from .. import SparseTensor
_backends = {}
class SparseConv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
super(SparseConv3d, self).__init__()
if config.CONV not in _backends:
_backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__)
_backends[config.CONV].sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key)
def forward(self, x: SparseTensor) -> SparseTensor:
return _backends[config.CONV].sparse_conv3d_forward(self, x)
class SparseInverseConv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
super(SparseInverseConv3d, self).__init__()
if config.CONV not in _backends:
_backends[config.CONV] = importlib.import_module(f'..conv_{config.CONV}', __name__)
_backends[config.CONV].sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, bias, indice_key)
def forward(self, x: SparseTensor) -> SparseTensor:
return _backends[config.CONV].sparse_inverse_conv3d_forward(self, x)
import math
import torch
import torch.nn as nn
from .. import SparseTensor
from . import config
from .. import config as sparse_config
from ..linear import ROCM_SAFE_CHUNK
import flex_gemm
from flex_gemm.ops.spconv import sparse_submanifold_conv3d
from flex_gemm.ops.spconv.submanifold_conv3d import SubMConv3dFunction, SubMConv3dNeighborCache
from flex_gemm.ops import utils as flex_utils
import flex_gemm.kernels as flex_kernels
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
assert stride == 1 and (padding is None), 'Currently flex_gemm implementation only support submanifold sparse convolution (stride=1, padding=None)'
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3
self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3
self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size)))
if bias:
self.bias = nn.Parameter(torch.empty(out_channels))
else:
self.register_parameter("bias", None)
# initialize parameters
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
torch.nn.init.uniform_(self.bias, -bound, bound)
# Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci)
self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous())
def _sparse_conv3d_explicit_gemm_chunked(feats, neighbor_map, weight, bias, N, V, Co, Ci):
"""
Chunked explicit-GEMM sparse conv: im2col + torch.mm in ROCM_SAFE_CHUNK-sized pieces.
Avoids the flex_gemm Triton kernel for large N on ROCm GFX1201.
"""
# weight: [Co, V, Ci] (reshaped from [Co, Kd, Kh, Kw, Ci])
# neighbor_map: [N, V] uint32 - 0xffffffff means no neighbor
weight_2d = weight.view(Co, V * Ci).t().contiguous() # [V*Ci, Co]
output = torch.zeros(N, Co, device=feats.device, dtype=feats.dtype)
for s in range(0, N, ROCM_SAFE_CHUNK):
e = min(s + ROCM_SAFE_CHUNK, N)
chunk_size = e - s
nm = neighbor_map[s:e].long() # [chunk, V]
# im2col: [chunk, V*Ci]
im2col = torch.zeros(chunk_size * V, Ci, device=feats.device, dtype=feats.dtype)
flat_nm = nm.view(-1) # [chunk*V]
valid = flat_nm != 0xffffffff
# clamp invalid indices to 0 to avoid index-out-of-bounds, then mask
safe_nm = flat_nm.clone()
safe_nm[~valid] = 0
im2col[valid] = feats[safe_nm[valid]]
im2col = im2col.view(chunk_size, V * Ci)
# GEMM: [chunk, V*Ci] @ [V*Ci, Co] -> [chunk, Co]
output[s:e] = torch.mm(im2col, weight_2d)
if bias is not None:
output = output + bias
return output
def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
flex_gemm.ops.spconv.set_algorithm(config.FLEX_GEMM_ALGO)
flex_gemm.ops.spconv.set_hashmap_ratio(config.FLEX_GEMM_HASHMAP_RATIO)
Co, Kd, Kh, Kw, Ci = self.weight.shape
N = x.feats.shape[0]
V = Kd * Kh * Kw
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}'
neighbor_cache = x.get_spatial_cache(neighbor_cache_key)
# ROCm safe spconv: build neighbor map normally, then use chunked torch.mm instead of Triton
if sparse_config.ROCM_SAFE_SPCONV and N > ROCM_SAFE_CHUNK:
from flex_gemm.ops.spconv.submanifold_conv3d import SubMConv3dFunction
from flex_gemm.ops import utils as flex_utils
import flex_gemm.kernels as flex_kernels
from flex_gemm.ops.spconv import Algorithm
if neighbor_cache is None:
# Build neighbor map using the HIP hash kernel (small/fast operation)
hashmap_keys, hashmap_vals = flex_utils.init_hashmap(
torch.Size([*x.shape, *x.spatial_shape]),
int(config.FLEX_GEMM_HASHMAP_RATIO * N),
x.feats.device,
)
neighbor_map = flex_kernels.cuda.hashmap_build_submanifold_conv_neighbour_map_cuda(
hashmap_keys, hashmap_vals,
x.coords,
x.spatial_shape[0], x.spatial_shape[1], x.spatial_shape[2],
Kw, Kh, Kd,
self.dilation[0], self.dilation[1], self.dilation[2],
)
# Store minimal cache so we skip rebuild next call
from flex_gemm.ops.spconv.submanifold_conv3d import SubMConv3dNeighborCache
neighbor_cache_ = SubMConv3dNeighborCache(neighbor_map=neighbor_map)
x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)
else:
neighbor_map = neighbor_cache['neighbor_map']
weight_flat = self.weight.reshape(Co, V, Ci)
out = _sparse_conv3d_explicit_gemm_chunked(
x.feats, neighbor_map, weight_flat, self.bias, N, V, Co, Ci
)
print(f"[ROCM_SAFE_SPCONV] N={N} used chunked explicit GEMM (V={V})")
return x.replace(out)
# Normal path: flex_gemm Triton kernel
out, neighbor_cache_ = sparse_submanifold_conv3d(
x.feats,
x.coords,
torch.Size([*x.shape, *x.spatial_shape]),
self.weight,
self.bias,
neighbor_cache,
self.dilation
)
if neighbor_cache is None:
x.register_spatial_cache(neighbor_cache_key, neighbor_cache_)
return x.replace(out)
def sparse_inverse_conv3d_init(self, *args, **kwargs):
raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet')
def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
raise NotImplementedError('SparseInverseConv3d with flex_gemm is not implemented yet')
"""
Native PyTorch sparse convolution implementation.
No CUDA kernels, no Triton - works on any PyTorch backend (CUDA, ROCm, CPU).
MUCH SLOWER than optimized backends but numerically stable.
Use for debugging or when other backends fail.
"""
import math
import torch
import torch.nn as nn
from .. import SparseTensor
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
"""
Initialize sparse 3D convolution layer.
"""
assert stride == 1 and padding is None, \
'Native implementation only supports submanifold sparse convolution (stride=1, padding=None)'
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size,) * 3
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride,) * 3
self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation,) * 3
# Weight shape: (out_channels, kernel_d, kernel_h, kernel_w, in_channels)
# Matches FlexGEMM's layout for compatibility
self.weight = nn.Parameter(torch.empty((out_channels, *self.kernel_size, in_channels)))
if bias:
self.bias = nn.Parameter(torch.empty(out_channels))
else:
self.register_parameter("bias", None)
# Initialize parameters
torch.nn.init.kaiming_uniform_(self.weight.view(out_channels, -1, in_channels), a=math.sqrt(5))
if self.bias is not None:
fan_in = in_channels * math.prod(self.kernel_size)
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
torch.nn.init.uniform_(self.bias, -bound, bound)
def _build_coord_map(coords):
"""Build hash map from coordinates to indices."""
coord_map = {}
for i in range(coords.shape[0]):
key = (int(coords[i, 0].item()), int(coords[i, 1].item()),
int(coords[i, 2].item()), int(coords[i, 3].item()))
coord_map[key] = i
return coord_map
def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
"""
Forward pass for native sparse 3D convolution.
Uses precomputed neighbor cache for efficiency.
"""
coords = x.coords # [N, 4] - (batch_idx, x, y, z)
feats = x.feats # [N, C_in]
N = coords.shape[0]
C_out = self.weight.shape[0]
Kd, Kh, Kw = self.kernel_size
dk, dh, dw = self.dilation
device = feats.device
dtype = feats.dtype
# Center offsets
kd_c, kh_c, kw_c = Kd // 2, Kh // 2, Kw // 2
# Check for cached neighbor list
cache_key = f'NativeConv3d_neighbors_{Kw}x{Kh}x{Kd}_d{dk}{dh}{dw}'
neighbor_cache = x.get_spatial_cache(cache_key)
if neighbor_cache is None:
# Build coordinate map
coord_map = _build_coord_map(coords)
# Build neighbor lists for each voxel
# neighbor_list[i] = [(kernel_idx, neighbor_feat_idx), ...]
neighbor_lists = []
for i in range(N):
b = int(coords[i, 0].item())
cx, cy, cz = int(coords[i, 1].item()), int(coords[i, 2].item()), int(coords[i, 3].item())
neighbors = []
for kd in range(Kd):
for kh in range(Kh):
for kw in range(Kw):
nx = cx + (kd - kd_c) * dk
ny = cy + (kh - kh_c) * dh
nz = cz + (kw - kw_c) * dw
key = (b, nx, ny, nz)
if key in coord_map:
# Store as flat kernel index and neighbor index
k_idx = kd * Kh * Kw + kh * Kw + kw
neighbors.append((k_idx, coord_map[key]))
neighbor_lists.append(neighbors)
neighbor_cache = neighbor_lists
x.register_spatial_cache(cache_key, neighbor_cache)
# Compute output
output = torch.zeros(N, C_out, device=device, dtype=dtype)
# Flatten weight for faster indexing: [Kd*Kh*Kw, C_out, C_in]
weight_flat = self.weight.view(-1, C_out, self.in_channels)
for i in range(N):
for k_idx, n_idx in neighbor_cache[i]:
# weight_flat[k_idx] has shape [C_out, C_in]
output[i] += feats[n_idx] @ weight_flat[k_idx].T
if self.bias is not None:
output = output + self.bias
return x.replace(output)
def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
"""
Initialize sparse inverse 3D convolution (transposed/deconvolution).
This is used in the decoder for upsampling.
"""
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size,) * 3
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride,) * 3
self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation,) * 3
# Weight shape: (in_channels, kernel_d, kernel_h, kernel_w, out_channels)
# Note: For transposed conv, we swap in/out channels
self.weight = nn.Parameter(torch.empty((in_channels, *self.kernel_size, out_channels)))
if bias:
self.bias = nn.Parameter(torch.empty(out_channels))
else:
self.register_parameter("bias", None)
# Initialize
torch.nn.init.kaiming_uniform_(self.weight.view(in_channels, -1, out_channels), a=math.sqrt(5))
if self.bias is not None:
fan_in = in_channels * math.prod(self.kernel_size)
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
torch.nn.init.uniform_(self.bias, -bound, bound)
def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
"""
Forward pass for sparse inverse 3D convolution.
For inverse convolution, each input voxel scatters to multiple output positions.
This is essentially the transpose of the forward convolution.
NOTE: This implementation assumes stride=1 (no actual upsampling).
For stride>1 upsampling, the output coordinates would be different from input.
"""
coords = x.coords # [N, 4]
feats = x.feats # [N, C_in]
N = coords.shape[0]
C_out = self.weight.shape[-1] # out_channels is last dim for inverse conv
Kd, Kh, Kw = self.kernel_size
dk, dh, dw = self.dilation
device = feats.device
dtype = feats.dtype
kd_c, kh_c, kw_c = Kd // 2, Kh // 2, Kw // 2
# Build coordinate map
coord_map = _build_coord_map(coords)
# For stride=1 inverse conv, output has same coordinates as input
# Each output position accumulates from neighbors
output = torch.zeros(N, C_out, device=device, dtype=dtype)
# Flatten weight: [Kd*Kh*Kw, C_in, C_out]
weight_flat = self.weight.view(-1, self.in_channels, C_out)
# For each input voxel, scatter to its neighbors
for i in range(N):
b = int(coords[i, 0].item())
cx, cy, cz = int(coords[i, 1].item()), int(coords[i, 2].item()), int(coords[i, 3].item())
for kd in range(Kd):
for kh in range(Kh):
for kw in range(Kw):
# Neighbor coordinate
nx = cx + (kd - kd_c) * dk
ny = cy + (kh - kh_c) * dh
nz = cz + (kw - kw_c) * dw
key = (b, nx, ny, nz)
if key in coord_map:
n_idx = coord_map[key]
k_idx = kd * Kh * Kw + kh * Kw + kw
# Scatter: output[neighbor] += feats[i] @ weight[k_idx]
output[n_idx] += feats[i] @ weight_flat[k_idx]
if self.bias is not None:
output = output + self.bias
return x.replace(output)
# ============================================================================
# Vectorized implementation (faster but more memory)
# ============================================================================
def sparse_conv3d_forward_vectorized(self, x: SparseTensor) -> SparseTensor:
"""
Vectorized implementation using batch operations.
Faster than loop version but uses more memory.
"""
coords = x.coords
feats = x.feats
N = coords.shape[0]
C_in = feats.shape[1]
C_out = self.weight.shape[0]
Kd, Kh, Kw = self.kernel_size
dk, dh, dw = self.dilation
device = feats.device
dtype = feats.dtype
kd_c, kh_c, kw_c = Kd // 2, Kh // 2, Kw // 2
# Build coordinate map
coord_map = _build_coord_map(coords)
# Build all neighbor pairs
src_indices = [] # Input voxel index
dst_indices = [] # Output voxel index (same as src for stride=1)
kernel_indices = [] # Which kernel weight to use
for i in range(N):
b = int(coords[i, 0].item())
cx, cy, cz = int(coords[i, 1].item()), int(coords[i, 2].item()), int(coords[i, 3].item())
for kd in range(Kd):
for kh in range(Kh):
for kw in range(Kw):
nx = cx + (kd - kd_c) * dk
ny = cy + (kh - kh_c) * dh
nz = cz + (kw - kw_c) * dw
key = (b, nx, ny, nz)
if key in coord_map:
n_idx = coord_map[key]
k_idx = kd * Kh * Kw + kh * Kw + kw
src_indices.append(n_idx)
dst_indices.append(i)
kernel_indices.append(k_idx)
if len(src_indices) == 0:
# No neighbors found
output = torch.zeros(N, C_out, device=device, dtype=dtype)
if self.bias is not None:
output = output + self.bias
return x.replace(output)
# Convert to tensors
src_indices = torch.tensor(src_indices, device=device, dtype=torch.long)
dst_indices = torch.tensor(dst_indices, device=device, dtype=torch.long)
kernel_indices = torch.tensor(kernel_indices, device=device, dtype=torch.long)
# Gather features: [num_pairs, C_in]
pair_feats = feats[src_indices]
# Gather weights: [num_pairs, C_out, C_in]
weight_flat = self.weight.view(-1, C_out, C_in)
pair_weights = weight_flat[kernel_indices]
# Compute contributions: [num_pairs, C_out]
# pair_feats @ pair_weights.T -> but we need batch matmul
contributions = torch.bmm(pair_weights, pair_feats.unsqueeze(-1)).squeeze(-1)
# Scatter to output
output = torch.zeros(N, C_out, device=device, dtype=dtype)
output.scatter_add_(0, dst_indices.unsqueeze(-1).expand(-1, C_out), contributions)
if self.bias is not None:
output = output + self.bias
return x.replace(output)
\ No newline at end of file
import torch
import torch.nn as nn
from .. import SparseTensor
from . import config
import spconv.pytorch as spconv
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
algo = None
if config.SPCONV_ALGO == 'native':
algo = spconv.ConvAlgo.Native
elif config.SPCONV_ALGO == 'implicit_gemm':
algo = spconv.ConvAlgo.MaskImplicitGemm
if stride == 1 and (padding is None):
self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
else:
self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
self.padding = padding
def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
new_data = self.conv(x.data)
new_shape = [x.shape[0], self.conv.out_channels]
new_layout = None if spatial_changed else x.layout
if spatial_changed and (x.shape[0] != 1):
# spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords
fwd = new_data.indices[:, 0].argsort()
bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
sorted_feats = new_data.features[fwd]
sorted_coords = new_data.indices[fwd]
unsorted_data = new_data
new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
out = SparseTensor(
new_data, shape=torch.Size(new_shape), layout=new_layout,
scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
spatial_cache=x._spatial_cache,
)
if spatial_changed and (x.shape[0] != 1):
out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
return out
def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
spatial_changed = any(s != 1 for s in self.stride)
if spatial_changed:
# recover the original spconv order
data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
data = data.replace_feature(x.feats[bwd])
else:
data = x.data
new_data = self.conv(data)
new_shape = [x.shape[0], self.conv.out_channels]
new_layout = None if spatial_changed else x.layout
out = SparseTensor(
new_data, shape=torch.Size(new_shape), layout=new_layout,
scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
spatial_cache=x._spatial_cache,
)
return out
import torch
import torch.nn as nn
from .. import SparseTensor
import torchsparse
def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
def sparse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
out = self.conv(x.data)
new_shape = [x.shape[0], self.conv.out_channels]
out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
out._spatial_cache = x._spatial_cache
out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
return out
def sparse_inverse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
def sparse_inverse_conv3d_forward(self, x: SparseTensor) -> SparseTensor:
out = self.conv(x.data)
new_shape = [x.shape[0], self.conv.out_channels]
out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
out._spatial_cache = x._spatial_cache
out._scale = tuple([s / stride for s, stride in zip(x._scale, self.conv.stride)])
return out
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import VarLenTensor
__all__ = [
'SparseLinear',
'ROCM_SAFE_CHUNK',
'rocm_safe_linear',
]
# ROCm GFX1201 (RX 9070 XT) bug workaround:
# hipBLASLt and rocBLAS GEMM kernels corrupt memory (→ NaN) when N > ~800k
# for shapes like [N, K] @ [K, M] with small K/M. Chunking keeps each
# dispatch below the confirmed-safe threshold of 524288 rows.
ROCM_SAFE_CHUNK = 524_288
def rocm_safe_linear(feats: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
"""F.linear with ROCm large-N chunking workaround."""
N = feats.shape[0]
if N <= ROCM_SAFE_CHUNK:
return F.linear(feats, weight, bias)
out = torch.empty(N, weight.shape[0], device=feats.device, dtype=feats.dtype)
for s in range(0, N, ROCM_SAFE_CHUNK):
e = min(s + ROCM_SAFE_CHUNK, N)
out[s:e] = F.linear(feats[s:e], weight, bias)
return out
class SparseLinear(nn.Linear):
def __init__(self, in_features, out_features, bias=True):
super(SparseLinear, self).__init__(in_features, out_features, bias)
#def forward(self, input: VarLenTensor) -> VarLenTensor:
# return input.replace(super().forward(input.feats))
def forward(self, input):
feats = input.feats if hasattr(input, 'feats') else input
out = rocm_safe_linear(feats, self.weight, self.bias)
if hasattr(input, 'replace'):
return input.replace(out)
return out
\ No newline at end of file
import torch
import torch.nn as nn
from . import VarLenTensor
__all__ = [
'SparseReLU',
'SparseSiLU',
'SparseGELU',
'SparseActivation'
]
class SparseReLU(nn.ReLU):
def forward(self, input: VarLenTensor) -> VarLenTensor:
return input.replace(super().forward(input.feats))
class SparseSiLU(nn.SiLU):
def forward(self, input: VarLenTensor) -> VarLenTensor:
return input.replace(super().forward(input.feats))
class SparseGELU(nn.GELU):
def forward(self, input: VarLenTensor) -> VarLenTensor:
return input.replace(super().forward(input.feats))
class SparseActivation(nn.Module):
def __init__(self, activation: nn.Module):
super().__init__()
self.activation = activation
def forward(self, input: VarLenTensor) -> VarLenTensor:
return input.replace(self.activation(input.feats))
import torch
import torch.nn as nn
from ..utils import manual_cast
from . import VarLenTensor
from . import config
__all__ = [
'SparseGroupNorm',
'SparseLayerNorm',
'SparseGroupNorm32',
'SparseLayerNorm32',
]
class SparseGroupNorm(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
def forward(self, input: VarLenTensor) -> VarLenTensor:
nfeats = torch.zeros_like(input.feats)
for k in range(input.shape[0]):
bfeats = input.feats[input.layout[k]]
bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
bfeats = super().forward(bfeats)
bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
nfeats[input.layout[k]] = bfeats
return input.replace(nfeats)
class SparseLayerNorm(nn.LayerNorm):
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
def forward(self, input: VarLenTensor) -> VarLenTensor:
nfeats = torch.zeros_like(input.feats)
for k in range(input.shape[0]):
bfeats = input.feats[input.layout[k]]
bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
bfeats = super().forward(bfeats)
bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
nfeats[input.layout[k]] = bfeats
return input.replace(nfeats)
class SparseGroupNorm32(SparseGroupNorm):
"""
A GroupNorm layer that converts to float32 before the forward pass.
"""
def forward(self, x: VarLenTensor) -> VarLenTensor:
x_dtype = x.dtype
x = manual_cast(x, torch.float32)
o = super().forward(x)
return manual_cast(o, x_dtype)
class SparseLayerNorm32(SparseLayerNorm):
"""
A LayerNorm layer that converts to float32 before the forward pass.
"""
def forward(self, x: VarLenTensor) -> VarLenTensor:
x_dtype = x.dtype
x = manual_cast(x, torch.float32)
o = super().forward(x)
return manual_cast(o, x_dtype)
from .basic import *
from .spatial2channel import *
from typing import *
import torch
import torch.nn as nn
from .. import SparseTensor
__all__ = [
'SparseDownsample',
'SparseUpsample',
]
class SparseDownsample(nn.Module):
"""
Downsample a sparse tensor by a factor of `factor`.
Implemented as average pooling.
"""
def __init__(self, factor: int, mode: Literal['mean', 'max'] = 'mean'):
super(SparseDownsample, self).__init__()
self.factor = factor
self.mode = mode
assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}'
def forward(self, x: SparseTensor) -> SparseTensor:
cache = x.get_spatial_cache(f'downsample_{self.factor}')
if cache is None:
DIM = x.coords.shape[-1] - 1
coord = list(x.coords.unbind(dim=-1))
for i in range(DIM):
coord[i+1] = coord[i+1] // self.factor
MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
code = sum([c * o for c, o in zip(coord, OFFSET)])
code, idx = code.unique(return_inverse=True)
new_coords = torch.stack(
[code // OFFSET[0]] +
[(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
dim=-1
)
else:
new_coords, idx = cache
new_feats = torch.scatter_reduce(
torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype),
dim=0,
index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]),
src=x.feats,
reduce=self.mode,
include_self=False,
)
out = SparseTensor(new_feats, new_coords, x._shape)
out._scale = tuple([s * self.factor for s in x._scale])
out._spatial_cache = x._spatial_cache
if cache is None:
x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx))
out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx))
out.register_spatial_cache(f'shape', torch.Size(MAX))
if self.training:
subidx = x.coords[:, 1:] % self.factor
subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
subdivision[idx, subidx] = True
out.register_spatial_cache(f'subdivision', subdivision)
return out
class SparseUpsample(nn.Module):
"""
Upsample a sparse tensor by a factor of `factor`.
Implemented as nearest neighbor interpolation.
"""
def __init__(
self, factor: int
):
super(SparseUpsample, self).__init__()
self.factor = factor
def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
DIM = x.coords.shape[-1] - 1
cache = x.get_spatial_cache(f'upsample_{self.factor}')
if cache is None:
if subdivision is None:
raise ValueError('Cache not found. Provide subdivision tensor or pair SparseUpsample with SparseDownsample.')
else:
sub = subdivision.feats
N_leaf = sub.sum(dim=-1)
subidx = sub.nonzero()[:, -1]
new_coords = x.coords.clone().detach()
new_coords[:, 1:] *= self.factor
new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
for i in range(DIM):
new_coords[:, i+1] += subidx // self.factor ** i % self.factor
idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
else:
new_coords, idx = cache
new_feats = x.feats[idx]
out = SparseTensor(new_feats, new_coords, x._shape)
out._scale = tuple([s / self.factor for s in x._scale])
if cache is not None: # only keep cache when subdiv following it
out._spatial_cache = x._spatial_cache
return out
\ No newline at end of file
from typing import *
import torch
import torch.nn as nn
from .. import SparseTensor
class SparseSpatial2Channel(nn.Module):
"""
Downsample a sparse tensor by a factor of `factor`.
Implemented as rearranging its features from spatial to channel.
"""
def __init__(self, factor: int = 2):
super(SparseSpatial2Channel, self).__init__()
self.factor = factor
def forward(self, x: SparseTensor) -> SparseTensor:
DIM = x.coords.shape[-1] - 1
cache = x.get_spatial_cache(f'spatial2channel_{self.factor}')
if cache is None:
coord = list(x.coords.unbind(dim=-1))
for i in range(DIM):
coord[i+1] = coord[i+1] // self.factor
subidx = x.coords[:, 1:] % self.factor
subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)])
MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
code = sum([c * o for c, o in zip(coord, OFFSET)])
code, idx = code.unique(return_inverse=True)
new_coords = torch.stack(
[code // OFFSET[0]] +
[(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
dim=-1
)
else:
new_coords, idx, subidx = cache
new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype)
new_feats[idx * self.factor ** DIM + subidx] = x.feats
out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM]))
out._scale = tuple([s * self.factor for s in x._scale])
out._spatial_cache = x._spatial_cache
if cache is None:
x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx))
out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx))
out.register_spatial_cache(f'shape', torch.Size(MAX))
if self.training:
subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
subdivision[idx, subidx] = True
out.register_spatial_cache(f'subdivision', subdivision)
return out
class SparseChannel2Spatial(nn.Module):
"""
Upsample a sparse tensor by a factor of `factor`.
Implemented as rearranging its features from channel to spatial.
"""
def __init__(self, factor: int = 2):
super(SparseChannel2Spatial, self).__init__()
self.factor = factor
def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) -> SparseTensor:
DIM = x.coords.shape[-1] - 1
cache = x.get_spatial_cache(f'channel2spatial_{self.factor}')
if cache is None:
if subdivision is None:
raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.')
else:
sub = subdivision.feats # [N, self.factor ** DIM]
N_leaf = sub.sum(dim=-1) # [N]
subidx = sub.nonzero()[:, -1]
new_coords = x.coords.clone().detach()
new_coords[:, 1:] *= self.factor
new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0])
for i in range(DIM):
new_coords[:, i+1] += subidx // self.factor ** i % self.factor
idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0])
else:
new_coords, idx, subidx = cache
x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1)
new_feats = x_feats[idx * self.factor ** DIM + subidx]
out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM]))
out._scale = tuple([s / self.factor for s in x._scale])
if cache is not None: # only keep cache when subdiv following it
out._spatial_cache = x._spatial_cache
return out
from .blocks import *
from .modulated import *
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment