Commit 9c0053b7 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #953 canceled with stages
import torch
import torch.nn as nn
import math
class PositionalEmbedding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEmbedding, self).__init__()
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model).float()
pe.require_grad = False
position = torch.arange(0, max_len).float().unsqueeze(1)
div_term = (torch.arange(0, d_model, 2).float()
* -(math.log(10000.0) / d_model)).exp()
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return self.pe[:, :x.size(1)]
class TokenEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super(TokenEmbedding, self).__init__()
padding = 1 if torch.__version__ >= '1.5.0' else 2
self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
kernel_size=3, padding=padding, padding_mode='circular', bias=False)
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(
m.weight, mode='fan_in', nonlinearity='leaky_relu')
def forward(self, x):
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
return x
class FixedEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super(FixedEmbedding, self).__init__()
w = torch.zeros(c_in, d_model).float()
w.require_grad = False
position = torch.arange(0, c_in).float().unsqueeze(1)
div_term = (torch.arange(0, d_model, 2).float()
* -(math.log(10000.0) / d_model)).exp()
w[:, 0::2] = torch.sin(position * div_term)
w[:, 1::2] = torch.cos(position * div_term)
self.emb = nn.Embedding(c_in, d_model)
self.emb.weight = nn.Parameter(w, requires_grad=False)
def forward(self, x):
return self.emb(x).detach()
class TemporalEmbedding(nn.Module):
def __init__(self, d_model, embed_type='fixed', freq='h'):
super(TemporalEmbedding, self).__init__()
minute_size = 4
hour_size = 24
weekday_size = 7
day_size = 32
month_size = 13
Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
if freq == 't':
self.minute_embed = Embed(minute_size, d_model)
self.hour_embed = Embed(hour_size, d_model)
self.weekday_embed = Embed(weekday_size, d_model)
self.day_embed = Embed(day_size, d_model)
self.month_embed = Embed(month_size, d_model)
def forward(self, x):
x = x.long()
minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
self, 'minute_embed') else 0.
hour_x = self.hour_embed(x[:, :, 3])
weekday_x = self.weekday_embed(x[:, :, 2])
day_x = self.day_embed(x[:, :, 1])
month_x = self.month_embed(x[:, :, 0])
return hour_x + weekday_x + day_x + month_x + minute_x
class TimeFeatureEmbedding(nn.Module):
def __init__(self, d_model, embed_type='timeF', freq='h'):
super(TimeFeatureEmbedding, self).__init__()
freq_map = {'h': 4, 't': 5, 's': 6,
'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
d_inp = freq_map[freq]
self.embed = nn.Linear(d_inp, d_model, bias=False)
def forward(self, x):
return self.embed(x)
class DataEmbedding(nn.Module):
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
super(DataEmbedding, self).__init__()
self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
self.position_embedding = PositionalEmbedding(d_model=d_model)
self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
d_model=d_model, embed_type=embed_type, freq=freq)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, x_mark):
if x_mark is None:
x = self.value_embedding(x) + self.position_embedding(x)
else:
x = self.value_embedding(
x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
return self.dropout(x)
class DataEmbedding_inverted(nn.Module):
def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
super(DataEmbedding_inverted, self).__init__()
self.value_embedding = nn.Linear(c_in, d_model)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, x_mark):
x = x.permute(0, 2, 1)
# x: [Batch Variate Time]
if x_mark is None:
x = self.value_embedding(x)
else:
# the potential to take covariates (e.g. timestamps) as tokens
x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
# x: [Batch Variate d_model]
return self.dropout(x)
import torch
import torch.nn as nn
import numpy as np
from math import sqrt
from utils.masking import TriangularCausalMask, ProbMask
from reformer_pytorch import LSHSelfAttention
from einops import rearrange
# Code implementation from https://github.com/thuml/Flowformer
class FlowAttention(nn.Module):
def __init__(self, attention_dropout=0.1):
super(FlowAttention, self).__init__()
self.dropout = nn.Dropout(attention_dropout)
def kernel_method(self, x):
return torch.sigmoid(x)
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
# kernel
queries = self.kernel_method(queries)
keys = self.kernel_method(keys)
# incoming and outgoing
normalizer_row = 1.0 / (torch.einsum("nhld,nhd->nhl", queries + 1e-6, keys.sum(dim=2) + 1e-6))
normalizer_col = 1.0 / (torch.einsum("nhsd,nhd->nhs", keys + 1e-6, queries.sum(dim=2) + 1e-6))
# reweighting
normalizer_row_refine = (
torch.einsum("nhld,nhd->nhl", queries + 1e-6, (keys * normalizer_col[:, :, :, None]).sum(dim=2) + 1e-6))
normalizer_col_refine = (
torch.einsum("nhsd,nhd->nhs", keys + 1e-6, (queries * normalizer_row[:, :, :, None]).sum(dim=2) + 1e-6))
# competition and allocation
normalizer_row_refine = torch.sigmoid(
normalizer_row_refine * (float(queries.shape[2]) / float(keys.shape[2])))
normalizer_col_refine = torch.softmax(normalizer_col_refine, dim=-1) * keys.shape[2] # B h L vis
# multiply
kv = keys.transpose(-2, -1) @ (values * normalizer_col_refine[:, :, :, None])
x = (((queries @ kv) * normalizer_row[:, :, :, None]) * normalizer_row_refine[:, :, :, None]).transpose(1,
2).contiguous()
return x, None
# Code implementation from https://github.com/shreyansh26/FlashAttention-PyTorch
class FlashAttention(nn.Module):
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
super(FlashAttention, self).__init__()
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def flash_attention_forward(self, Q, K, V, mask=None):
BLOCK_SIZE = 32
NEG_INF = -1e10 # -infinity
EPSILON = 1e-10
# mask = torch.randint(0, 2, (128, 8)).to(device='cuda')
O = torch.zeros_like(Q, requires_grad=True)
l = torch.zeros(Q.shape[:-1])[..., None]
m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF
O = O.to(device='cuda')
l = l.to(device='cuda')
m = m.to(device='cuda')
Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
KV_BLOCK_SIZE = BLOCK_SIZE
Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
if mask is not None:
mask_BLOCKS = list(torch.split(mask, KV_BLOCK_SIZE, dim=1))
Tr = len(Q_BLOCKS)
Tc = len(K_BLOCKS)
O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))
for j in range(Tc):
Kj = K_BLOCKS[j]
Vj = V_BLOCKS[j]
if mask is not None:
maskj = mask_BLOCKS[j]
for i in range(Tr):
Qi = Q_BLOCKS[i]
Oi = O_BLOCKS[i]
li = l_BLOCKS[i]
mi = m_BLOCKS[i]
scale = 1 / np.sqrt(Q.shape[-1])
Qi_scaled = Qi * scale
S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj)
if mask is not None:
# Masking
maskj_temp = rearrange(maskj, 'b j -> b 1 1 j')
S_ij = torch.where(maskj_temp > 0, S_ij, NEG_INF)
m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
P_ij = torch.exp(S_ij - m_block_ij)
if mask is not None:
# Masking
P_ij = torch.where(maskj_temp > 0, P_ij, 0.)
l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON
P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)
mi_new = torch.maximum(m_block_ij, mi)
li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij
O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi + (
torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
l_BLOCKS[i] = li_new
m_BLOCKS[i] = mi_new
O = torch.cat(O_BLOCKS, dim=2)
l = torch.cat(l_BLOCKS, dim=2)
m = torch.cat(m_BLOCKS, dim=2)
return O, l, m
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
res = \
self.flash_attention_forward(queries.permute(0, 2, 1, 3), keys.permute(0, 2, 1, 3), values.permute(0, 2, 1, 3),
attn_mask)[0]
return res.permute(0, 2, 1, 3).contiguous(), None
class FullAttention(nn.Module):
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
super(FullAttention, self).__init__()
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, H, E = queries.shape
_, S, _, D = values.shape
scale = self.scale or 1. / sqrt(E)
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
if self.mask_flag:
if attn_mask is None:
attn_mask = TriangularCausalMask(B, L, device=queries.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
V = torch.einsum("bhls,bshd->blhd", A, values)
if self.output_attention:
return (V.contiguous(), A)
else:
return (V.contiguous(), None)
# Code implementation from https://github.com/zhouhaoyi/Informer2020
class ProbAttention(nn.Module):
def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
super(ProbAttention, self).__init__()
self.factor = factor
self.scale = scale
self.mask_flag = mask_flag
self.output_attention = output_attention
self.dropout = nn.Dropout(attention_dropout)
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
# Q [B, H, L, D]
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
# calculate the sampled Q_K
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
# real U = U_part(factor*ln(L_k))*L_q
index_sample = torch.randint(L_K, (L_Q, sample_k))
K_sample = K_expand[:, :, torch.arange(
L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(
Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
# find the Top_k query with sparisty measurement
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M_top = M.topk(n_top, sorted=False)[1]
# use the reduced Q to calculate Q_K
Q_reduce = Q[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
M_top, :] # factor*ln(L_q)
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
return Q_K, M_top
def _get_initial_context(self, V, L_Q):
B, H, L_V, D = V.shape
if not self.mask_flag:
# V_sum = V.sum(dim=-2)
V_sum = V.mean(dim=-2)
contex = V_sum.unsqueeze(-2).expand(B, H,
L_Q, V_sum.shape[-1]).clone()
else: # use mask
# requires that L_Q == L_V, i.e. for self-attention only
assert (L_Q == L_V)
contex = V.cumsum(dim=-2)
return contex
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
B, H, L_V, D = V.shape
if self.mask_flag:
attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
scores.masked_fill_(attn_mask.mask, -np.inf)
attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
context_in[torch.arange(B)[:, None, None],
torch.arange(H)[None, :, None],
index, :] = torch.matmul(attn, V).type_as(context_in)
if self.output_attention:
attns = (torch.ones([B, H, L_V, L_V]) /
L_V).type_as(attn).to(attn.device)
attns[torch.arange(B)[:, None, None], torch.arange(H)[
None, :, None], index, :] = attn
return (context_in, attns)
else:
return (context_in, None)
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L_Q, H, D = queries.shape
_, L_K, _, _ = keys.shape
queries = queries.transpose(2, 1)
keys = keys.transpose(2, 1)
values = values.transpose(2, 1)
U_part = self.factor * \
np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
u = self.factor * \
np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)
U_part = U_part if U_part < L_K else L_K
u = u if u < L_Q else L_Q
scores_top, index = self._prob_QK(
queries, keys, sample_k=U_part, n_top=u)
# add scale factor
scale = self.scale or 1. / sqrt(D)
if scale is not None:
scores_top = scores_top * scale
# get the context
context = self._get_initial_context(values, L_Q)
# update the context with selected top_k queries
context, attn = self._update_context(
context, values, scores_top, index, L_Q, attn_mask)
return context.contiguous(), attn
class AttentionLayer(nn.Module):
def __init__(self, attention, d_model, n_heads, d_keys=None,
d_values=None):
super(AttentionLayer, self).__init__()
d_keys = d_keys or (d_model // n_heads)
d_values = d_values or (d_model // n_heads)
self.inner_attention = attention
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
self.value_projection = nn.Linear(d_model, d_values * n_heads)
self.out_projection = nn.Linear(d_values * n_heads, d_model)
self.n_heads = n_heads
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
B, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
queries = self.query_projection(queries).view(B, L, H, -1)
keys = self.key_projection(keys).view(B, S, H, -1)
values = self.value_projection(values).view(B, S, H, -1)
out, attn = self.inner_attention(
queries,
keys,
values,
attn_mask,
tau=tau,
delta=delta
)
out = out.view(B, L, -1)
return self.out_projection(out), attn
class ReformerLayer(nn.Module):
def __init__(self, attention, d_model, n_heads, d_keys=None,
d_values=None, causal=False, bucket_size=4, n_hashes=4):
super().__init__()
self.bucket_size = bucket_size
self.attn = LSHSelfAttention(
dim=d_model,
heads=n_heads,
bucket_size=bucket_size,
n_hashes=n_hashes,
causal=causal
)
def fit_length(self, queries):
# inside reformer: assert N % (bucket_size * 2) == 0
B, N, C = queries.shape
if N % (self.bucket_size * 2) == 0:
return queries
else:
# fill the time series
fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
def forward(self, queries, keys, values, attn_mask, tau, delta):
# in Reformer: defalut queries=keys
B, N, C = queries.shape
queries = self.attn(self.fit_length(queries))[:, :N, :]
return queries, None
import torch.nn as nn
import torch.nn.functional as F
class ConvLayer(nn.Module):
def __init__(self, c_in):
super(ConvLayer, self).__init__()
self.downConv = nn.Conv1d(in_channels=c_in,
out_channels=c_in,
kernel_size=3,
padding=2,
padding_mode='circular')
self.norm = nn.BatchNorm1d(c_in)
self.activation = nn.ELU()
self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
def forward(self, x):
x = self.downConv(x.permute(0, 2, 1))
x = self.norm(x)
x = self.activation(x)
x = self.maxPool(x)
x = x.transpose(1, 2)
return x
class EncoderLayer(nn.Module):
def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
super(EncoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.attention = attention
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, attn_mask=None, tau=None, delta=None):
new_x, attn = self.attention(
x, x, x,
attn_mask=attn_mask,
tau=tau, delta=delta
)
x = x + self.dropout(new_x)
y = x = self.norm1(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm2(x + y), attn
class Encoder(nn.Module):
def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
super(Encoder, self).__init__()
self.attn_layers = nn.ModuleList(attn_layers)
self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
self.norm = norm_layer
def forward(self, x, attn_mask=None, tau=None, delta=None):
# x [B, L, D]
attns = []
if self.conv_layers is not None:
for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
delta = delta if i == 0 else None
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
x = conv_layer(x)
attns.append(attn)
x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
attns.append(attn)
else:
for attn_layer in self.attn_layers:
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
attns.append(attn)
if self.norm is not None:
x = self.norm(x)
return x, attns
class DecoderLayer(nn.Module):
def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
dropout=0.1, activation="relu"):
super(DecoderLayer, self).__init__()
d_ff = d_ff or 4 * d_model
self.self_attention = self_attention
self.cross_attention = cross_attention
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.activation = F.relu if activation == "relu" else F.gelu
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
x = x + self.dropout(self.self_attention(
x, x, x,
attn_mask=x_mask,
tau=tau, delta=None
)[0])
x = self.norm1(x)
x = x + self.dropout(self.cross_attention(
x, cross, cross,
attn_mask=cross_mask,
tau=tau, delta=delta
)[0])
y = x = self.norm2(x)
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
y = self.dropout(self.conv2(y).transpose(-1, 1))
return self.norm3(x + y)
class Decoder(nn.Module):
def __init__(self, layers, norm_layer=None, projection=None):
super(Decoder, self).__init__()
self.layers = nn.ModuleList(layers)
self.norm = norm_layer
self.projection = projection
def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
for layer in self.layers:
x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
if self.norm is not None:
x = self.norm(x)
if self.projection is not None:
x = self.projection(x)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment