OpenDAS / FastMoE · Commits

Commit b0990e4b, authored Feb 26, 2021 by Rick Ho
Merge branch 'master' into laekov/accfix
Parents: 89de2153, 1cfc5462
Showing 8 changed files with 133 additions and 643 deletions (+133, -643)
examples/.gitignore (+3, -0)
examples/transformer-xl/mem_transformer.py (+78, -503)
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh (+1, -0)
examples/transformer-xl/train.py (+9, -110)
examples/transformer-xl/utils/proj_adaptive_softmax.py (+10, -6)
fmoe/layers.py (+27, -2)
fmoe/megatron.py (+4, -5)
fmoe/transformer.py (+1, -17)
examples/.gitignore (new file, 0 → 100644)

+transformer-xl/data
+transformer-xl/LM-TFM-enwik8
+data
examples/transformer-xl/mem_transformer.py

(This diff is collapsed in the web view and not shown here.)
examples/transformer-xl/scripts/run_enwik8_base_gshard.sh

@@ -25,6 +25,7 @@ if [[ $1 == 'train' ]]; then
         --batch_size 22 \
         --multi_gpu \
         --gpu0_bsz 4 \
+        --moe --moe-num-expert 64 --moe-top-k 2 \
         ${@:2}
 elif [[ $1 == 'eval' ]]; then
     echo 'Run evaluation...'
examples/transformer-xl/train.py

@@ -4,7 +4,6 @@ import time
 import math
 import os, sys
 import itertools
-import pathlib
 import numpy as np
...
@@ -17,31 +16,6 @@ from mem_transformer import MemTransformerLM
 from utils.exp_utils import create_exp_dir
 from utils.data_parallel import BalancedDataParallel
 
-class AverageMeter(object):
-    """Computes and stores the average and current value.
-       Examples::
-           >>> # Initialize a meter to record loss
-           >>> losses = AverageMeter()
-           >>> # Update meter after every minibatch update
-           >>> losses.update(loss_value, batch_size)
-    """
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
-
-    def update(self, val, n=1):
-        self.val = val
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
 parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
 parser.add_argument('--data', type=str, default='../data/wikitext-103',
                     help='location of the data corpus')
...
@@ -167,8 +141,15 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
 parser.add_argument('--dynamic-loss-scale', action='store_true',
                     help='Use dynamic loss scaling. If supplied, this argument'
                     ' supersedes --static-loss-scale.')
+parser.add_argument('--moe', action='store_true',
+                    help='replace position-wise ffn with moe position-wise ffn')
+parser.add_argument('--moe-num-expert', type=int, default=64,
+                    help='number of experts in MoE')
+parser.add_argument('--moe-top-k', type=int, default=2,
+                    help='top_k experts in hard gate of moe')
 args = parser.parse_args()
 args.tied = not args.not_tied
+assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
 
 if args.d_embed < 0:
     args.d_embed = args.d_model
...
@@ -305,7 +286,8 @@ else:
         tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
         same_length=args.same_length, attn_type=args.attn_type,
-        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
+        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax,
+        moe=args.moe, moe_num_expert=args.moe_num_expert, moe_top_k=args.moe_top_k)
     model.apply(weights_init)
     model.word_emb.apply(weights_init)  # ensure embedding init is not overridden by out_layer in case of weight sharing
 args.n_all_param = sum([p.nelement() for p in model.parameters()])
...
@@ -412,9 +394,6 @@ logging('#non emb params = {}'.format(args.n_nonemb_param))
 def evaluate(eval_iter):
     # Turn on evaluation mode which disables dropout.
     model.eval()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None
 
     # If the model does not use memory at all, make the ext_len longer.
     # Otherwise, make the mem_len longer and keep the ext_len the same.
...
@@ -434,33 +413,15 @@ def evaluate(eval_iter):
                 break
             ret = model(data, target, *mems)
             loss, mems = ret[0], ret[1:]
-            # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
             loss = loss.mean()
             total_loss += seq_len * loss.float().item()
             total_len += seq_len
-            # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-            # if avg_nnzs is None:
-            #     n_layer = len(acts)
-            #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-            #     d_inner = acts[0].size(-1)
-            #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-            #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-            # for i, act in enumerate(acts):
-            #     nnz = act.sum().item() / act.numel()
-            #     avg_nnzs[i].update(nnz)
-            #     act_hist[i] += torch.sum(act, dim=[0, 1])
-            #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-            #     co_act_hist[i] += co_act
 
     # Switch back to the training mode
     model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
     model.train()
 
     return total_loss / total_len
-    # return total_loss / total_len, avg_nnzs, act_hist, co_act_hist
 
 def train():
...
@@ -471,11 +432,6 @@ def train():
         mems = [tuple() for _ in range(args.batch_chunk)]
     else:
         mems = tuple()
-    # avg_nnzs = None
-    # act_hist = None
-    # co_act_hist = None
     train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
     for batch, (data, target, seq_len) in enumerate(train_iter):
         model.zero_grad()
...
@@ -487,7 +443,6 @@ def train():
             target_i = target_chunks[i].contiguous()
             ret = para_model(data_i, target_i, *mems[i])
             loss, mems[i] = ret[0], ret[1:]
-            # relu_outs, loss, mems[i] = ret[0], ret[1], ret[2:]
             loss = loss.float().mean().type_as(loss) / args.batch_chunk
             if args.fp16:
                 optimizer.backward(loss)
...
@@ -497,28 +452,12 @@ def train():
     else:
         ret = para_model(data, target, *mems)
         loss, mems = ret[0], ret[1:]
-        # relu_outs, loss, mems = ret[0], ret[1], ret[2:]
         loss = loss.float().mean().type_as(loss)
         if args.fp16:
             optimizer.backward(loss)
         else:
             loss.backward()
         train_loss += loss.float().item()
-        # acts = [(relu_out > 0).float().cpu() for relu_out in relu_outs]
-        # # nnzs = [act.sum().item() / act.numel() for act in acts]
-        # if avg_nnzs is None:
-        #     n_layer = len(acts)
-        #     avg_nnzs = [AverageMeter() for i in range(n_layer)]
-        #     d_inner = acts[0].size(-1)
-        #     act_hist = [torch.zeros(d_inner) for i in range(n_layer)]
-        #     co_act_hist = [torch.zeros(d_inner, d_inner) for i in range(n_layer)]
-        # for i, act in enumerate(acts):
-        #     nnz = act.sum().item() / act.numel()
-        #     avg_nnzs[i].update(nnz)
-        #     act_hist[i] += torch.sum(act, dim=[0, 1])
-        #     co_act = torch.einsum("ija,ijb->ab", (act, act))
-        #     co_act_hist[i] += co_act
 
     if args.fp16:
         optimizer.clip_master_grads(args.clip)
...
@@ -557,39 +496,12 @@ def train():
             log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
         else:
             log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
-        # final_avg_nnzs = [avg_nnzs[i].avg for i in range(len(avg_nnzs))]
-        # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-        #     sum(final_avg_nnzs)/len(final_avg_nnzs)*100,
-        #     max(final_avg_nnzs)*100,
-        # )
         logging(log_str)
-        # co_act_dir = pathlib.Path(logging.keywords['log_path']).parent.joinpath("co_act")
-        # co_act_dir.mkdir(parents=True, exist_ok=True)
-        # co_act_path = co_act_dir.joinpath('epoch_%d_train_step_%d.pt' % (epoch, train_step))
-        # torch.save(co_act_hist, co_act_path)
-        # for i in range(len(avg_nnzs)):
-        #     avg_nnzs[i].reset()
-        #     act_hist[i] /= act_hist[i].sum()
-        #     prob, index = torch.topk(act_hist[i], min(1024, act_hist[i].size(-1)))
-        #     log_str = '| layer {:2d} | top 64 prob {:3.2f} | top 128 prob {:3.2f} | top 256 prob {:3.2f} | top 512 prob {:3.2f} | top 1024 prob {:3.2f}'.format(
-        #         i+1,
-        #         prob[:64].sum().item(),
-        #         prob[:128].sum().item(),
-        #         prob[:256].sum().item(),
-        #         prob[:512].sum().item(),
-        #         prob[:1024].sum().item()
-        #     )
-        #     logging(log_str)
-        #     act_hist[i] = 0.
-        #     co_act_hist[i] = 0.
         train_loss = 0
         log_start_time = time.time()
 
     if train_step % args.eval_interval == 0:
         val_loss = evaluate(va_iter)
-        # val_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(va_iter)
         logging('-' * 100)
         log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                   '| valid loss {:5.2f}'.format(
...
@@ -599,11 +511,6 @@ def train():
             log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
         else:
             log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
-        # final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-        # log_str += ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-        #     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-        #     max(final_eval_avg_nnzs)*100
-        # )
         logging(log_str)
         logging('-' * 100)
 
         # Save the model if the validation loss is the best we've seen so far.
...
@@ -653,7 +560,6 @@ para_model = model.to(device)
 # Run on test data.
 test_loss = evaluate(te_iter)
-# test_loss, eval_avg_nnzs, eval_act_hist, eval_co_act_hist = evaluate(te_iter)
 logging('=' * 100)
 if args.dataset in ['enwik8', 'text8']:
     logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
...
@@ -661,11 +567,4 @@ if args.dataset in ['enwik8', 'text8']:
 else:
     logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
         test_loss, math.exp(test_loss)))
-# final_eval_avg_nnzs = [eval_avg_nnzs[i].avg for i in range(len(eval_avg_nnzs))]
-# log_str = ' | avgnnz {:5.2f} | maxnnz {:5.2f}'.format(
-#     sum(final_eval_avg_nnzs)/len(final_eval_avg_nnzs)*100,
-#     max(final_eval_avg_nnzs)*100
-# )
-# logging(log_str)
 logging('=' * 100)
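The three new options above are plain argparse flags; a self-contained sketch that re-creates just the added arguments and the added assertion (argparse turns the dashes into underscores, so they read back as args.moe_num_expert and args.moe_top_k):

import argparse

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--moe', action='store_true',
                    help='replace position-wise ffn with moe position-wise ffn')
parser.add_argument('--moe-num-expert', type=int, default=64,
                    help='number of experts in MoE')
parser.add_argument('--moe-top-k', type=int, default=2,
                    help='top_k experts in hard gate of moe')

# Parse an example command line rather than sys.argv, just for illustration.
args = parser.parse_args(['--moe', '--moe-num-expert', '64', '--moe-top-k', '2'])
assert args.moe_num_expert >= args.moe_top_k, "must have moe-num-expert >= moe-top_k"
print(args.moe, args.moe_num_expert, args.moe_top_k)   # True 64 2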
examples/transformer-xl/utils/proj_adaptive_softmax.py

@@ -9,6 +9,10 @@ import torch.nn.functional as F
 CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
 CUDA_MINOR = int(torch.version.cuda.split('.')[1])
 
+class Projection(nn.Module):
+    def __init__(self, out_feat, in_feat):
+        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))
+
 class ProjectedAdaptiveLogSoftmax(nn.Module):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                  keep_order=False):
...
@@ -31,13 +35,13 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
             self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
 
         self.out_layers = nn.ModuleList()
-        self.out_projs = nn.ParameterList()
+        self.out_projs = nn.ModuleList()
 
         if div_val == 1:
             for i in range(len(self.cutoffs)):
                 if d_proj != d_embed:
                     self.out_projs.append(
-                        nn.Parameter(torch.Tensor(d_proj, d_embed))
+                        Projection(d_proj, d_embed)
                     )
                 else:
                     self.out_projs.append(None)
...
@@ -49,7 +53,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 d_emb_i = d_embed // (div_val ** i)
 
                 self.out_projs.append(
-                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
+                    Projection(d_proj, d_emb_i)
                 )
 
                 self.out_layers.append(nn.Linear(d_emb_i, r_idx - l_idx))
...
@@ -82,7 +86,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         if self.n_clusters == 0:
             logit = self._compute_logit(hidden, self.out_layers[0].weight,
-                                        self.out_layers[0].bias, self.out_projs[0])
+                                        self.out_layers[0].bias, self.out_projs[0].weight if self.out_projs[0] is not None else None)
             nll = -F.log_softmax(logit, dim=-1) \
                     .gather(1, target.unsqueeze(1)).squeeze(1)
         else:
...
@@ -106,7 +110,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 weights.append(weight_i)
                 biases.append(bias_i)
 
-            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
+            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0].weight
 
             head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
             head_logprob = F.log_softmax(head_logit, dim=1)
...
@@ -131,7 +135,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 if i == 0:
                     logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                 else:
-                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
+                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i].weight
 
                     hidden_i = hidden.index_select(0, indices_i)
...
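The change above swaps nn.ParameterList for nn.ModuleList and wraps each projection matrix in a small Projection module, so call sites now read the tensor through .weight. One observation on the hunk as shown: Projection.__init__ never calls super().__init__(), and stock PyTorch refuses to register an nn.Parameter on a module whose nn.Module.__init__ has not run. A standalone sketch of the wrapper pattern, written here for illustration rather than copied from the repository, would look like this:

import torch
import torch.nn as nn

class Projection(nn.Module):
    """Illustrative wrapper that holds one projection matrix as a submodule."""
    def __init__(self, out_feat, in_feat):
        super().__init__()  # must run before any parameter is registered
        self.weight = nn.Parameter(torch.Tensor(out_feat, in_feat))

# ModuleList registers each Projection (and therefore its .weight) on the parent.
projs = nn.ModuleList([Projection(16, 32), Projection(16, 8)])
print([p.weight.shape for p in projs])  # [torch.Size([16, 32]), torch.Size([16, 8])]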
fmoe/layers.py

@@ -36,14 +36,39 @@ class FMoELinear(nn.Module):
         '''
         x = MOELinear.apply(inp, self.weight, fwd_expert_count)
         if self.bias is not None:
+            # TODO: torch.repeat_interleave seems have numerical
+            # instability in backward, leading to incorrect
+            # gradient computation for solution 1 and 2.
+            # Solution 3 uses a for-loop to expand the bias,
+            # but is 50% slower.
+            # This part should finally goes to MOELinear.apply,
+            # like MOELinear.apply(x, weight, bias, count)
+
+            # Solution 1
             bias = torch.repeat_interleave(self.bias,
                     fwd_expert_count.to(self.bias.device), dim=0)
+            # Solution 2
+            # bias_idx = torch.arange(self.num_expert)\
+            #         .repeat_interleave(fwd_expert_count)
+            # bias = self.bias[bias_idx]
+            # Solution 3
+            # bias = []
+            # for i in range(self.num_expert):
+            #     if fwd_expert_count[i] > 0:
+            #         bias.append(
+            #             self.bias[i].unsqueeze(0).expand(
+            #                 fwd_expert_count[i], -1
+            #             )
+            #         )
+            # bias = torch.cat(bias, dim=0)
             x = x + bias
         return x
 
     def extra_repr(self) -> str:
         return 'num_expert={}, in_features={}, \
 out_features={}, bias={}, rank={}'.format(
             self.num_expert, self.in_feat, self.out_feat,
             self.bias is not None, self.rank
         )
...
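The new comment block records three candidate ways of expanding the per-expert bias to one row per routed token; the committed code keeps Solution 1 (a single repeat_interleave), and the stated concern is gradient computation in backward, not the forward values. As a minimal standalone check of the forward-pass equivalence of Solution 1 and Solution 3, with tensor sizes invented for illustration:

import torch

num_expert, d_out = 4, 8
bias = torch.randn(num_expert, d_out)              # one bias row per expert
fwd_expert_count = torch.tensor([3, 0, 2, 1])      # tokens routed to each expert

# Solution 1: repeat each expert's bias row once per routed token.
bias_s1 = torch.repeat_interleave(bias, fwd_expert_count, dim=0)

# Solution 3: explicit loop, skipping experts that received no tokens.
chunks = [bias[i].unsqueeze(0).expand(int(n), -1)
          for i, n in enumerate(fwd_expert_count) if n > 0]
bias_s3 = torch.cat(chunks, dim=0)

assert torch.equal(bias_s1, bias_s3)               # same (6, 8) result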
...
fmoe/megatron.py
View file @
b0990e4b
...
...
@@ -3,18 +3,17 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
numpy
as
np
import
math
from
.transformer
import
FMoETransformerMLP
from
.distributed
import
DistributedGroupedDataParallel
from
.utils
import
get_torch_default_comm
class
_MegatronMLP
(
nn
.
Module
):
class
_FakeMegatronMLP
(
nn
.
Module
):
r
'''
A fake mlp without model parallelism for correctness testing
'''
def
__init__
(
self
,
args
,
group
):
super
().
__init__
()
self
.
fc1
=
nn
.
Linear
(
args
.
hidden_size
,
args
.
hidden_hidden_size
)
...
...
fmoe/transformer.py
View file @
b0990e4b
...
...
@@ -44,25 +44,15 @@ class FMoETransformerMLP(FMoE):
d_hidden
=
4096
,
world_size
=
1
,
mp_group
=
None
,
activation
=
torch
.
nn
.
functional
.
gelu
,
activation
=
torch
.
nn
.
GELU
()
,
gate
=
NaiveGate
,
top_k
=
2
,
do_lnorm
=
False
,
pre_lnorm
=
False
,
add_residual
=
False
,
expert_dp_comm
=
'none'
):
super
().
__init__
(
num_expert
=
num_expert
,
d_model
=
d_model
,
gate
=
gate
,
top_k
=
top_k
,
world_size
=
world_size
,
mp_group
=
mp_group
)
self
.
experts
=
_Expert
(
num_expert
,
d_model
,
d_hidden
,
activation
,
rank
=
self
.
mp_rank
)
self
.
pre_lnorm
=
pre_lnorm
if
do_lnorm
:
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
self
.
pre_lnorm
=
pre_lnorm
else
:
self
.
pre_lnorm
=
None
self
.
add_residual
=
add_residual
self
.
mark_parallel_comm
(
expert_dp_comm
)
def
forward
(
self
,
inp
:
torch
.
Tensor
):
...
...
@@ -72,11 +62,5 @@ class FMoETransformerMLP(FMoE):
'''
original_shape
=
inp
.
shape
inp
=
inp
.
reshape
(
-
1
,
self
.
d_model
)
if
self
.
pre_lnorm
is
not
None
and
self
.
pre_lnorm
:
inp
=
self
.
layer_norm
(
inp
)
output
=
super
().
forward
(
inp
)
if
self
.
pre_lnorm
is
not
None
and
not
self
.
pre_lnorm
:
output
=
self
.
layer_norm
(
output
)
if
self
.
add_residual
:
output
+=
inp
return
output
.
reshape
(
original_shape
)
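The single addition in this file is the new default activation=torch.nn.GELU() in place of torch.nn.functional.gelu. Both are callables mapping a tensor to its GELU, so either can be handed to _Expert; a quick standalone sanity check (not taken from the repository):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(4, 8)
# The module form wraps the functional form; with default settings the two
# produce identical values, so swapping the default keeps the math unchanged.
assert torch.allclose(nn.GELU()(x), F.gelu(x))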