OpenDAS / FastMoE · Commits

Commit 03b2a725 (unverified)
Authored Feb 25, 2021 by Rick Ho, committed by GitHub on Feb 25, 2021

Merge pull request #6 from xfmr-xl

Test Transformer-XL

Parents: e86dea53, 0a942e3f
Showing 6 changed files with 88 additions and 620 deletions (+88 -620):

  .gitignore                                                  +3   -0
  examples/transformer-xl/mem_transformer.py                  +59  -502
  examples/transformer-xl/scripts/run_enwik8_base_gshard.sh   +1   -0
  examples/transformer-xl/train.py                            +9   -110
  examples/transformer-xl/utils/proj_adaptive_softmax.py      +10  -6
  fmoe/transformer.py                                         +6   -2
.gitignore

@@ -10,3 +10,6 @@ a.out
 build
 *swp
 logs
+examples/transformer-xl/data
+examples/data
+examples/transformer-xl/LM-TFM-enwik8
examples/transformer-xl/mem_transformer.py

(This diff is collapsed: +59 -502.)
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh

@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
         --batch_size 22 \
         --multi_gpu \
         --gpu0_bsz 4 \
+        --moe --moe-num-expert 64 --moe-top-k 2 \
         ${@:2}
 elif [[ $1 == 'eval' ]]; then
     echo 'Run evaluation...'
examples/transformer-xl/train.py

@@ -4,7 +4,6 @@ import time
 import math
 import os, sys
 import itertools
-import pathlib
 import numpy as np

@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM
 from utils.exp_utils import create_exp_dir
 from utils.data_parallel import BalancedDataParallel
 
-class AverageMeter(object):
-    """Computes and stores the average and current value.
-       Examples::
-           >>> # Initialize a meter to record loss
-           >>> losses = AverageMeter()
-           >>> # Update meter after every minibatch update
-           >>> losses.update(loss_value, batch_size)
-    """
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
-
-    def update(self, val, n=1):
-        self.val = val
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
 parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
 parser.add_argument('--data', type=str, default='../data/wikitext-103',
                     help='location of the data corpus')

@@ -167,8 +141,15 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe', action='store_true',
+                    help='replace position-wise ffn with moe position-wise ffn')
+parser.add_argument('--moe-num-expert', type=int, default=64,
+                    help='number of experts in MoE')
+parser.add_argument('--moe-top-k', type=int, default=2,
+                    help='top_k experts in hard gate of moe')
 
 args = parser.parse_args()
 args.tied = not args.not_tied
+assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
 
 if args.d_embed < 0:
     args.d_embed = args.d_model
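
The three options added here can be exercised in isolation. Below is a minimal, hypothetical sketch (not the full train.py argument list) that mirrors the new flags and the moe-num-expert >= moe-top-k guard shown above; the hard-coded argv list is illustrative and matches the flags that run_enwik8_base_gshard.sh now passes.

# Minimal, self-contained sketch of the new MoE options (illustrative subset
# of train.py's arguments). The argv list below stands in for
# "--moe --moe-num-expert 64 --moe-top-k 2" from the shell script.
import argparse

parser = argparse.ArgumentParser(description='MoE options (illustrative subset)')
parser.add_argument('--moe', action='store_true',
                    help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top-k', type=int, default=2,
                    help='top_k experts in hard gate of moe')

args = parser.parse_args(['--moe', '--moe-num-expert', '64', '--moe-top-k', '2'])
# Same sanity check as the one added to train.py: top-k routing cannot
# select more experts than exist.
assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
print(args.moe, args.moe_num_expert, args.moe_top_k)  # True 64 2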
@@ -305,7 +286,8 @@ else:
         tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
-        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
+        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
+        moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
 
 args.n_all_param = sum([p.nelement() for p in model.parameters()])

@@ -412,9 +394,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
 def evaluate(eval_iter):
     # Turn on evaluation mode which disables dropout.
     model.eval()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None
 
     # If the model does not use memory at all, make the ext_len longer.
     # Otherwise, make the mem_len longer and keep the ext_len the same.

@@ -434,33 +413,15 @@ def evaluate(eval_iter):
                 break
             ret = model(data, target, *mems)
             loss, mems = ret[0], ret[1:]
-            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.mean()
             total_loss += seq_len * loss.float().item()
             total_len += seq_len
-            # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-            # if avg_nnzs is None:
-            #     n_layer = len(acts)
-            #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-            #     d_inner = acts[0].size(-1)
-            #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-            #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-            # for i, act in enumerate(acts):
-            #     nnz = act.sum().item() / act.numel()
-            #     avg_nnzs[i].update(nnz)
-            #     act_hist[i] += torch.sum(act, dim=[0, 1])
-            #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-            #     co_act_hist[i] += co_act
 
     # Switch back to the training mode
     model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
     model.train()
 
     return total_loss / total_len
-    # return total_loss / total_len, avg_nnzs, act_hist, co_act_hist
 
 
 def train():

@@ -471,11 +432,6 @@ def train():
         mems = [tuple() for _ in range(args.batch_chunk)]
     else:
         mems = tuple()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None
 
     train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
     for batch, (data, target, seq_len) in enumerate(train_iter):
         model.zero_grad()

@@ -487,7 +443,6 @@ def train():
                 target_i = target_chunks[i].contiguous()
                 ret = para_model(data_i, target_i, *mems[i])
                 loss, mems[i] = ret[0], ret[1:]
-                # relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
                 loss = loss.float().mean().type_as(loss) / args.batch_chunk
                 if args.fp16:
                     optimizer.backward(loss)

@@ -497,28 +452,12 @@ def train():
         else:
             ret = para_model(data, target, *mems)
             loss, mems = ret[0], ret[1:]
-            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.float().mean().type_as(loss)
             if args.fp16:
                 optimizer.backward(loss)
             else:
                 loss.backward()
             train_loss += loss.float().item()
-            # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-            # # nnzs = [act.sum().item() / act.numel() for act in acts]
-            # if avg_nnzs is None:
-            #     n_layer = len(acts)
-            #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-            #     d_inner = acts[0].size(-1)
-            #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-            #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-            # for i, act in enumerate(acts):
-            #     nnz = act.sum().item() / act.numel()
-            #     avg_nnzs[i].update(nnz)
-            #     act_hist[i] += torch.sum(act, dim=[0, 1])
-            #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-            #     co_act_hist[i] += co_act
 
         if args.fp16:
             optimizer.clip_master_grads(args.clip)

@@ -557,39 +496,12 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
             else:
                 log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
-            # final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
-            # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-            #     sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
-            #     max(final_avg_nnzs)*100,
-            # )
            logging(log_str)
-            # co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
-            # co_act_dir.mkdir(parents=True, exist_ok=True)
-            # co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
-            # torch.save(co_act_hist, co_act_path)
-            # for i in range(len(avg_nnzs)):
-            #     avg_nnzs[i].reset()
-            #     act_hist[i] /= act_hist[i].sum()
-            #     prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
-            #     log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
-            #         i+1,
-            #         prob[:64].sum().item(),
-            #         prob[:128].sum().item(),
-            #         prob[:256].sum().item(),
-            #         prob[:512].sum().item(),
-            #         prob[:1024].sum().item()
-            #     )
-            #     logging(log_str)
-            #     act_hist[i] = 0.
-            #     co_act_hist[i] = 0.
             train_loss = 0
             log_start_time = time.time()
 
         if train_step % args.eval_interval == 0:
             val_loss = evaluate(va_iter)
-            # val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
             logging('-' * 100)
             log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                       '| valid loss {:5.2f}'.format(

@@ -599,11 +511,6 @@ def train():
                 log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
             else:
                 log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
-            # final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-            # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-            #     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-            #     max(final_eval_avg_nnzs)*100
-            # )
             logging(log_str)
             logging('-' * 100)
 
             # Save the model if the validation loss is the best we've seen so far.

@@ -653,7 +560,6 @@ para_model = model.to(device)
 # Run on test data.
 test_loss = evaluate(te_iter)
-# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
 logging('=' * 100)
 if args.dataset in ['enwik8', 'text8']:
     logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(

@@ -661,11 +567,4 @@ if args.dataset in ['enwik8', 'text8']:
 else:
     logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
         test_loss, math.exp(test_loss)))
-# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-# log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-#     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-#     max(final_eval_avg_nnzs)*100
-# )
-# logging(log_str)
 logging('=' * 100)
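
For readers unfamiliar with the Transformer-XL training loop, the `ret = model(...)`; `loss, mems = ret[0], ret[1:]` pattern kept above relies on the model returning a list of the form [loss, *new_mems]. The toy module below only illustrates that calling convention; it is not the actual MemTransformerLM (whose diff is collapsed above), and all sizes are made up.

# Illustrative sketch of the [loss, *mems] return convention used by the
# training and evaluation loops above. ToyMemModel is hypothetical.
import torch

class ToyMemModel(torch.nn.Module):
    def __init__(self, n_layer=2):
        super().__init__()
        self.n_layer = n_layer
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, data, target, *mems):
        # Per-position loss plus fresh "memory" tensors, mimicking the shape
        # of the real model's output list.
        loss = (self.proj(data) - target).pow(2).mean(dim=-1)
        new_mems = [torch.zeros(4, 8) for _ in range(self.n_layer)]
        return [loss] + new_mems

model = ToyMemModel()
data, target = torch.randn(4, 8), torch.randn(4, 8)
ret = model(data, target)            # first step: no memories yet
loss, mems = ret[0], ret[1:]         # same unpacking as train()/evaluate()
loss = loss.mean()
ret = model(data, target, *mems)     # later steps pass the memories back in
print(loss.item(), len(mems))        # scalar loss, 2 memory tensors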
examples/transformer-xl/utils/proj_adaptive_softmax.py

@@ -9,6 +9,10 @@ import torch.nn.functional as F
 CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
 CUDA_MINOR = int(torch.version.cuda.split('.')[1])
 
+class Projection(nn.Module):
+    def __init__(self, out_feat, in_feat):
+        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))
+
 class ProjectedAdaptiveLogSoftmax(nn.Module):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                  keep_order=False):

@@ -31,13 +35,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
             self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
 
         self.out_layers = nn.ModuleList()
-        self.out_projs = nn.ParameterList()
+        self.out_projs = nn.ModuleList()
 
         if div_val == 1:
             for i in range(len(self.cutoffs)):
                 if d_proj != d_embed:
                     self.out_projs.append(
-                        nn.Parameter(torch.Tensor(d_proj, d_embed))
+                        Projection(d_proj, d_embed)
                     )
                 else:
                     self.out_projs.append(None)

@@ -49,7 +53,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 d_emb_i = d_embed // (div_val ** i)
 
                 self.out_projs.append(
-                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
+                    Projection(d_proj, d_emb_i)
                 )
 
                 self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))

@@ -82,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         if self.n_clusters == 0:
             logit = self._compute_logit(hidden, self.out_layers[0].weight,
-                                        self.out_layers[0].bias, self.out_projs[0])
+                                        self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None)
             nll = -F.log_softmax(logit, dim=-1) \
                     .gather(1, target.unsqueeze(1)).squeeze(1)
         else:

@@ -106,7 +110,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 weights.append(weight_i)
                 biases.append(bias_i)
 
-            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
+            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight
 
             head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
             head_logprob = F.log_softmax(head_logit, dim=1)

@@ -131,7 +135,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 if i == 0:
                     logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
                 else:
-                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
+                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight
 
                     hidden_i = hidden.index_select(0, indices_i)
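
The change above replaces bare nn.Parameter projections (kept in an nn.ParameterList) with a thin Projection module held in an nn.ModuleList, and the call sites now read the tensor via `.weight`. The following standalone sketch illustrates that wrapper pattern with made-up sizes; note that a free-standing version of the wrapper must call super().__init__() before a parameter can be registered, and the initialization shown here is illustrative only.

# Standalone sketch of the Parameter-in-a-Module wrapper pattern used above.
import torch
import torch.nn as nn

class Projection(nn.Module):
    def __init__(self, out_feat, in_feat):
        super().__init__()  # required before assigning an nn.Parameter
        self.weight = nn.Parameter(torch.empty(out_feat, in_feat))
        nn.init.normal_(self.weight, std=0.02)  # illustrative init

projs = nn.ModuleList()
projs.append(Projection(16, 32))   # (d_proj, d_embed), as in the diff
projs.append(None)                 # ModuleList entries may also be None

hidden = torch.randn(4, 32)
# Call sites now read the tensor through .weight, guarding the None case.
proj_w = projs[0].weight if projs[0] is not None else None
if proj_w is not None:
    hidden = torch.nn.functional.linear(hidden, proj_w)  # project 32 -> 16
print(hidden.shape)  # torch.Size([4, 16])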
fmoe/transformer.py

@@ -49,10 +49,12 @@ class FMoETransformerMLP(FMoE):
         top_k=2,
         do_lnorm=False,
         pre_lnorm=False,
-        expert_dp_comm='none'
+        expert_dp_comm='none',
+        dropout=0.1
     ):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                 top_k=top_k, world_size=world_size, mp_group=mp_group)
+        self.dropout = nn.Dropout(dropout)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm

@@ -72,7 +74,9 @@ class FMoETransformerMLP(FMoE):
         inp = inp.reshape(-1, self.d_model)
         if self.pre_lnorm is not None and self.pre_lnorm:
             inp = self.layer_norm(inp)
-        output = super().forward(inp) + inp
+        output = super().forward(inp)
+        output = self.dropout(output)
+        output += inp
         if self.pre_lnorm is not None and not self.pre_lnorm:
             output = self.layer_norm(output)
         return output.reshape(original_shape)
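
In words: FMoETransformerMLP now accepts a dropout rate (default 0.1) and applies dropout to the expert output before the residual add, instead of folding the residual into a single expression. The sketch below reproduces that ordering in a self-contained module; the plain nn.Linear is only a stand-in for the MoE experts so the example runs without the fmoe package, and all sizes are illustrative.

# Sketch of the revised dropout/residual/layer-norm ordering, under the
# assumption that a plain linear layer can stand in for the MoE experts.
import torch
import torch.nn as nn

class ResidualDropoutFFN(nn.Module):
    def __init__(self, d_model=16, dropout=0.1, pre_lnorm=False):
        super().__init__()
        self.d_model = d_model
        self.experts = nn.Linear(d_model, d_model)   # stand-in for the MoE experts
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        output = self.experts(inp)      # super().forward(inp) in the real class
        output = self.dropout(output)   # new: dropout before the residual add
        output = output + inp           # residual connection
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output.reshape(original_shape)

x = torch.randn(2, 5, 16)
print(ResidualDropoutFFN()(x).shape)  # torch.Size([2, 5, 16])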