Commit 9c92be55 authored by Rick Ho

fit fmoe in transformer-xl

parent 5e9bb2e9
@@ -9,8 +9,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 # import torch_sparse
-from cuda.moe import MOELayer
 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
 from log_uniform_sampler import LogUniformSampler, sample_logits
@@ -33,81 +31,8 @@ class PositionalEmbedding(nn.Module):
         else:
             return pos_emb[:,None,:]
-class CustomizedMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=2, num_expert=64):
-        super(CustomizedMoEPositionwiseFF, self).__init__()
-        print("CustomizedMoEPositionwiseFF num_expert=%d top_k=%d" % (num_expert, top_k))
-        self.top_k = top_k
-        assert num_expert >= top_k
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-        self.gate = nn.Linear(d_model, num_expert)
-        self.moe1 = MOELayer(num_expert=num_expert, in_feat=d_model+1, out_feat=d_inner)
-        self.moe2 = MOELayer(num_expert=num_expert, in_feat=d_inner+1, out_feat=d_model)
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        self.dropout = nn.Dropout(dropout)
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        pass
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False) # [.. x top_k]
-        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
-        # gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1) # (BxL) x 1 x top_k
-        # gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)
-        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1) # (BxL) x 1 x top_k
-        gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
-        #core_out = []
-        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k, dim=0) # (BxLxtop_k) x d_model
-        inp = F.pad(inp, pad=(0, 1), mode='constant', value=1.0)
-        x = self.moe1(inp, gate_top_k_idx)
-        x = self.dropout(F.relu(x))
-        x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
-        x = self.moe2(x, gate_top_k_idx)
-        x = self.dropout(x) # (BxLxtop_k) x d_model
-        core_out = x.view(-1, self.top_k, self.d_model) # (BxL) x top_k x d_model
-        """
-        for i in range(self.top_k):
-            gate_idx = gate_top_k_idx[:, i].contiguous()
-            x = self.moe1(inp, gate_idx)
-            x = self.dropout(F.relu(x))
-            x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
-            x = self.moe2(x, gate_idx)
-            x = self.dropout(x) # (BxL) x d_model
-            core_out.append(x.unsqueeze(1)) # (BxL) x 1 x d_model
-        core_out = torch.cat(core_out, dim=1) # (BxL) x top_k x d_model
-        """
-        core_out = torch.bmm(gate_score, core_out) # (BxL) x 1 x d_model
-        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output
-
-# A baseline naive slow implementation
 class MoEPositionwiseFFRaw(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
         super(MoEPositionwiseFFRaw, self).__init__()
@@ -158,7 +83,7 @@ class MoEPositionwiseFFRaw(nn.Module):
         output = self.layer_norm(output)
         return output
-        # return output, relu_out.detach()
 def my_topk(x, k, inplace=True):
     y = x if inplace else x.clone()
@@ -891,6 +816,21 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         return output
+from fmoe import FMoETransformerMLP
+
+class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+        def activation(x):
+            return self.dropout(F.relu(x))
+        super().__init__(num_expert=8, d_model=d_model, d_hidden=d_inner,
+                pre_lnorm=pre_lnorm, activation=activation)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x, bias = super().forward(x)
+        return x + bias
 class DecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
         super(DecoderLayer, self).__init__()
...
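For comparison with the new fmoe-based layer added above, the deleted CustomizedMoEPositionwiseFF mixed its top-k expert outputs with a softmax over the selected gate logits. Below is a condensed, self-contained sketch of just that mixing step; the sizes are illustrative and the tensors are stand-ins that only mirror the names and shape comments in the deleted code:

import torch
import torch.nn.functional as F

tokens, top_k, d_model = 8, 2, 16                # illustrative sizes only
gate_top_k_val = torch.randn(tokens, top_k)      # logits of the k experts picked per token
core_out = torch.randn(tokens, top_k, d_model)   # the k expert FF outputs per token

gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)  # tokens x 1 x top_k
mixed = torch.bmm(gate_score, core_out).squeeze(1)           # tokens x d_model, softmax-weighted sum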
echo "=== Acquiring datasets ===" echo "=== Acquiring datasets ==="
echo "---" echo "---"
mkdir -p data mkdir -p ../data
cd data cd ../data
if [[ ! -d 'wikitext-2' ]]; then if [[ ! -d 'wikitext-2' ]]; then
echo "- Downloading WikiText-2 (WT2)" echo "- Downloading WikiText-2 (WT2)"
......
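A minimal sketch of exercising the new MoE feed-forward layer as a drop-in for the dense PositionwiseFF, assuming the fmoe package is importable and that the layer accepts the (seq_len, batch, d_model) tensors used elsewhere in Transformer-XL; the hyper-parameters and shape below are illustrative, not taken from this commit:

import torch

ff = CustomizedMoEPositionwiseFF(d_model=512, d_inner=2048, dropout=0.1, pre_lnorm=False)
x = torch.randn(36, 4, 512)    # (seq_len, batch, d_model), illustrative shape
y = ff(x)                      # super().forward() output plus the returned bias term
assert y.shape == x.shape      # expected to preserve the input shape, like PositionwiseFF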