        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        # default to contiguous positions 0..seq_len-1 when none are provided
        if pos is None:
            pos = torch.arange(seq_len, device = device)

        # offset positions for sequences that start later (e.g. left padding), clamping at 0
        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        return pos_emb
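# A minimal usage sketch (hypothetical names and shapes, assuming this forward method
# belongs to an AbsolutePositionalEmbedding(dim, max_seq_len) module with a learned
# nn.Embedding table as suggested by self.emb / self.max_seq_len above):
#
#   pos_emb = AbsolutePositionalEmbedding(dim = 512, max_seq_len = 1024)
#   tokens  = torch.randn(2, 128, 512)         # (batch, seq, dim) token embeddings
#   tokens  = tokens + pos_emb(tokens)         # add absolute positional information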
class ScaledSinusoidalEmbedding(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        # sinusoidal embeddings interleave sin/cos pairs, so the dimension must be even
        assert (dim % 2) == 0, 'dimension must be divisible by 2'