"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f6fb3282b18f44f14bcb95a34a16203906df992a"
Commit 2338a26e authored by Jiezhong Qiu

incorporate customized cuda moe into xfmr-xl

parent d036918c
@@ -5,8 +5,6 @@ import torch
 import moe_cuda
 
-torch.manual_seed(42)
-torch.cuda.manual_seed(42)
 
 class MOEFunction(Function):
     @staticmethod
@@ -21,12 +19,12 @@ class MOEFunction(Function):
 
     @staticmethod
     def backward(ctx, grad_out):
-        print("grad_out", grad_out)
-        print("input", ctx.saved_tensors[0])
+        # print("grad_out", grad_out)
+        # print("input", ctx.saved_tensors[0])
        grad_inp, grad_weight = moe_cuda.backward(
            grad_out.contiguous(), *ctx.saved_tensors)
        out_feat, in_feat = grad_weight.size()[1:]
-        print("grad_weight_column_major", grad_weight.flatten())
+        # print("grad_weight_column_major", grad_weight.flatten())
        grad_weight_row_major = grad_weight.view(-1, in_feat, out_feat).transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
        return grad_inp, None, grad_weight_row_major
@@ -47,7 +45,7 @@ class MOELayer(nn.Module):
             self.weight.data[i] = linear.weight.data
 
     def forward(self, inp, gate):
-        return MOEFunction.apply(inp, gate, self.weight)
+        return MOEFunction.apply(inp, gate.int(), self.weight)
 
 
 class MOELayer_raw(nn.Module):
@@ -75,6 +73,8 @@ class MOELayer_raw(nn.Module):
 
 def test():
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
     batch_size = 4
     num_expert = 4
     in_feat = 2
......
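The changes above move the seeding into test() and make MOELayer cast the gate indices to int before calling the CUDA kernel. Below is a minimal usage sketch of that layer, assuming the moe_cuda extension is built and a CUDA device is available; the shapes mirror the constants in test(), and out_feat is chosen here purely for illustration.

import torch
from cuda.moe import MOELayer   # import path as used in the model code below

batch_size, num_expert, in_feat, out_feat = 4, 4, 2, 3   # out_feat is illustrative

moe = MOELayer(num_expert=num_expert, in_feat=in_feat, out_feat=out_feat).cuda()
inp = torch.randn(batch_size, in_feat, device='cuda', requires_grad=True)
gate = torch.randint(num_expert, (batch_size,), device='cuda')   # one expert index per row

out = moe(inp, gate)       # forward path: MOEFunction.apply(inp, gate.int(), weight)
out.sum().backward()       # backward path: moe_cuda.backward, grad_weight transposed to row-major
print(out.shape, inp.grad.shape)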
@@ -9,6 +9,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 # import torch_sparse
+from cuda.moe import MOELayer
+
 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
 from log_uniform_sampler import LogUniformSampler, sample_logits
@@ -31,9 +33,76 @@ class PositionalEmbedding(nn.Module):
         else:
             return pos_emb[:,None,:]
 
-class MoEPositionwiseFF(nn.Module):
+class CustomizedMoEPositionwiseFF(nn.Module):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=2, num_expert=4):
+        super(CustomizedMoEPositionwiseFF, self).__init__()
+        print("CustomizedMoEPositionwiseFF num_expert=%d top_k=%d" % (num_expert, top_k))
+        self.top_k = top_k
+        assert num_expert >= top_k
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout
+        self.gate = nn.Linear(d_model, num_expert)  # one gating logit per expert
+        self.moe1 = MOELayer(num_expert=num_expert, in_feat=d_model, out_feat=d_inner)
+        self.moe2 = MOELayer(num_expert=num_expert, in_feat=d_inner, out_feat=d_model)
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.pre_lnorm = pre_lnorm
+        self.dropout = nn.Dropout(dropout)
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        pass
+
+    def forward(self, inp):
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+
+        gate = self.gate(inp)
+        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
+        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
+        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)  # (BxL) x 1 x top_k
+        gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)
+
+        core_out = []
+        inp = inp.view(-1, self.d_model)
+        # inp = F.pad(inp, pad=(0, 1), mode='constant', value=1.0)
+        for i in range(self.top_k):
+            print("top %d" % i)
+            gate_idx = gate_top_k_idx[:, i].contiguous()
+            print(inp.size(), gate_idx.size())
+            x = self.moe1(inp, gate_idx)
+            x = self.dropout(F.relu(x))
+            # x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
+            x = self.moe2(x, gate_idx)
+            x = self.dropout(x)  # (BxL) x d_model
+            core_out.append(x.unsqueeze(1))  # (BxL) x 1 x d_model
+
+        core_out = torch.cat(core_out, dim=1)  # (BxL) x top_k x d_model
+        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
+        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
+
+        output = core_out + residual
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)
+        return output
+
+
+class MoEPositionwiseFFRaw(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
-        super(MoEPositionwiseFF, self).__init__()
+        super(MoEPositionwiseFFRaw, self).__init__()
         print("MoEPositionwiseFF")
         self.top_k = top_k
@@ -820,7 +889,7 @@ class DecoderLayer(nn.Module):
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                              pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -840,7 +909,7 @@ class RelLearnableDecoderLayer(nn.Module):
 
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                   **kwargs)
-        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                              pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -861,7 +930,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
 
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                          d_head, dropout, **kwargs)
-        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                              pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
......
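For reference, the gate-and-combine arithmetic performed by the new CustomizedMoEPositionwiseFF can be reproduced with plain PyTorch. A minimal sketch follows, with ordinary nn.Linear experts standing in for the CUDA MOELayer and a per-token Python loop standing in for the kernel's per-expert batching; it is illustrative only, not the commit's implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

num_expert, top_k, d_model, d_inner = 4, 2, 8, 16
experts = nn.ModuleList(nn.Sequential(nn.Linear(d_model, d_inner), nn.ReLU(),
                                      nn.Linear(d_inner, d_model)) for _ in range(num_expert))
gate = nn.Linear(d_model, num_expert)         # one logit per expert, as in the layer above

x = torch.randn(5, 3, d_model)                # (B, L, d_model)
logits = gate(x)                              # (B, L, num_expert)
val, idx = torch.topk(logits, k=top_k, dim=-1)
score = F.softmax(val.view(-1, top_k), dim=-1).unsqueeze(1)   # (BxL) x 1 x top_k
idx = idx.view(-1, top_k)
flat = x.view(-1, d_model)

outs = []
for i in range(top_k):
    # route every token through its i-th selected expert
    sel = idx[:, i]
    y = torch.stack([experts[e](t) for e, t in zip(sel.tolist(), flat)])
    outs.append(y.unsqueeze(1))               # (BxL) x 1 x d_model

combined = torch.bmm(score, torch.cat(outs, dim=1)).view_as(x)  # softmax-weighted mix of top-k experts
print(combined.shape)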
 #!/bin/bash
+export LD_LIBRARY_PATH=/home/jiezhong/miniconda3/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 
 if [[ $1 == 'train' ]]; then
     echo 'Run training...'
     python train.py \
......