Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
89176e34
Unverified
Commit
89176e34
authored
Nov 11, 2020
by
msbaines
Committed by
GitHub
Nov 11, 2020
Browse files
[refactor] moe: remove G dimension (#183)
parent
5d4f50fb
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
61 deletions
+63
-61
fairscale/nn/moe/moe_layer.py
fairscale/nn/moe/moe_layer.py
+8
-9
fairscale/nn/moe/top2gate.py
fairscale/nn/moe/top2gate.py
+29
-29
stubs/torch/__init__.pyi
stubs/torch/__init__.pyi
+4
-0
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_moe_layer.py
+12
-13
tests/nn/moe/test_top2gating.py
tests/nn/moe/test_top2gating.py
+10
-10
No files found.
fairscale/nn/moe/moe_layer.py
View file @
89176e34
...
...
@@ -66,29 +66,28 @@ class MOELayer(Base):
p
.
expert
=
True
# type: ignore
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
dispatched_input
=
torch
.
einsum
(
"
g
sec,
g
sm->e
g
cm"
,
dispatch_mask
.
float
(),
input
)
dispatched_input
=
torch
.
einsum
(
"sec,sm->ecm"
,
dispatch_mask
.
float
(),
input
)
return
_AllToAll
.
apply
(
self
.
group
,
dispatched_input
)
def
all_to_all_combine
(
self
,
combine_weights
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
expert_output
=
_AllToAll
.
apply
(
self
.
group
,
input
)
return
torch
.
einsum
(
"
g
sec,e
g
cm->
g
sm"
,
combine_weights
,
expert_output
)
return
torch
.
einsum
(
"sec,ecm->sm"
,
combine_weights
,
expert_output
)
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
[
0
].
shape
)
==
4
,
"input Tensor must have dimensions:
(g)roup,
(s)equence, (t)oken, (m)odel"
assert
input
[
0
].
shape
[
0
]
==
len
(
self
.
experts
)
,
"group dimension size
must be
equal to
number of local experts"
assert
len
(
input
[
0
].
shape
)
==
3
,
"input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
assert
input
[
0
].
shape
[
0
]
%
len
(
self
.
experts
)
==
0
,
"num tokens
must be
order of
number of local experts"
# Implement Algorithm 2 from GShard paper.
shape
=
input
[
0
].
shape
# Reshape into S tokens
per group
.
reshaped_input
=
input
[
0
].
reshape
(
shape
[
0
],
-
1
,
shape
[
3
])
# Reshape into S tokens
by dropping sequence dimension
.
reshaped_input
=
input
[
0
].
reshape
(
-
1
,
shape
[
2
])
self
.
l_aux
,
combine_weights
,
dispatching_mask
=
self
.
gate
(
reshaped_input
)
dispatched_input
=
self
.
all_to_all_dispatch
(
dispatching_mask
,
reshaped_input
)
assert
dispatched_input
.
shape
[
1
]
==
len
(
self
.
experts
)
chunks
=
dispatched_input
.
chunk
(
len
(
self
.
experts
),
dim
=
1
)
chunks
=
dispatched_input
.
chunk
(
len
(
self
.
experts
),
dim
=
0
)
expert_outputs
=
[]
for
chunk
,
expert
in
zip
(
chunks
,
self
.
experts
):
expert_outputs
+=
[
expert
(
chunk
)]
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
1
)
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
0
)
combined_output
=
self
.
all_to_all_combine
(
combine_weights
,
expert_output
)
return
combined_output
.
reshape
(
shape
)
fairscale/nn/moe/top2gate.py
View file @
89176e34
...
...
@@ -28,36 +28,36 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
def
top2gating
(
logits
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
"""Implements Top2Gating on logits."""
gates
=
F
.
softmax
(
logits
,
dim
=
2
)
gates
=
F
.
softmax
(
logits
,
dim
=
1
)
# gates has shape of
G
SE
num_tokens
=
gates
.
shape
[
1
]
num_experts
=
gates
.
shape
[
2
]
# gates has shape of SE
num_tokens
=
gates
.
shape
[
0
]
num_experts
=
gates
.
shape
[
1
]
# capacity = 2S/E
capacity
=
2
*
num_tokens
//
num_experts
assert
num_tokens
%
num_experts
==
0
# Create a mask for 1st's expert per token
indices1_
g
s
=
torch
.
argmax
(
gates
,
dim
=
2
)
mask1
=
F
.
one_hot
(
indices1_
g
s
,
num_classes
=
num_experts
)
indices1_s
=
torch
.
argmax
(
gates
,
dim
=
1
)
mask1
=
F
.
one_hot
(
indices1_s
,
num_classes
=
num_experts
)
# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise
=
logits
+
gumbel_rsample
(
logits
.
shape
,
device
=
logits
.
device
)
# Replace top-expert with min value
logits_except1
=
logits_w_noise
.
masked_fill
(
mask1
.
bool
(),
float
(
"-inf"
))
indices2_
g
s
=
torch
.
argmax
(
logits_except1
,
dim
=
2
)
mask2
=
F
.
one_hot
(
indices2_
g
s
,
num_classes
=
num_experts
)
indices2_s
=
torch
.
argmax
(
logits_except1
,
dim
=
1
)
mask2
=
F
.
one_hot
(
indices2_s
,
num_classes
=
num_experts
)
# Compute locations in capacity buffer
locations1
=
torch
.
cumsum
(
mask1
,
dim
=
1
)
-
1
locations2
=
torch
.
cumsum
(
mask2
,
dim
=
1
)
-
1
locations1
=
torch
.
cumsum
(
mask1
,
dim
=
0
)
-
1
locations2
=
torch
.
cumsum
(
mask2
,
dim
=
0
)
-
1
# Update 2nd's location by accounting for locations of 1st
locations2
+=
torch
.
sum
(
mask1
,
dim
=
1
,
keepdim
=
True
)
locations2
+=
torch
.
sum
(
mask1
,
dim
=
0
,
keepdim
=
True
)
# Compute l_aux
me
=
torch
.
mean
(
gates
,
dim
=
1
)
ce
=
torch
.
mean
(
mask1
.
float
(),
dim
=
1
)
me
=
torch
.
mean
(
gates
,
dim
=
0
)
ce
=
torch
.
mean
(
mask1
.
float
(),
dim
=
0
)
l_aux
=
torch
.
mean
(
me
*
ce
)
# Remove locations outside capacity from mask
...
...
@@ -65,28 +65,28 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
mask2
*=
torch
.
lt
(
locations2
,
capacity
)
# Store the capacity location for each token
locations1_
g
s
=
torch
.
sum
(
locations1
*
mask1
,
dim
=
2
)
locations2_
g
s
=
torch
.
sum
(
locations2
*
mask2
,
dim
=
2
)
locations1_s
=
torch
.
sum
(
locations1
*
mask1
,
dim
=
1
)
locations2_s
=
torch
.
sum
(
locations2
*
mask2
,
dim
=
1
)
# Normalize gate probabilities
mask1_float
=
mask1
.
float
()
mask2_float
=
mask2
.
float
()
gates1_
g
s
=
torch
.
einsum
(
"
g
se,
g
se->
g
s"
,
gates
,
mask1_float
)
gates2_
g
s
=
torch
.
einsum
(
"
g
se,
g
se->
g
s"
,
gates
,
mask2_float
)
denom_
g
s
=
gates1_
g
s
+
gates2_
g
s
gates1_s
=
torch
.
einsum
(
"se,se->s"
,
gates
,
mask1_float
)
gates2_s
=
torch
.
einsum
(
"se,se->s"
,
gates
,
mask2_float
)
denom_s
=
gates1_s
+
gates2_s
# Avoid divide-by-zero
denom_
g
s
=
torch
.
where
(
denom_
gs
>
0
,
denom_gs
,
torch
.
ones_like
(
denom_gs
)
)
gates1_
g
s
/=
denom_
g
s
gates2_
g
s
/=
denom_
g
s
denom_s
=
torch
.
clamp
(
denom_
s
,
min
=
torch
.
finfo
(
denom_s
.
dtype
).
eps
)
gates1_s
/=
denom_s
gates2_s
/=
denom_s
# Calculate combine_weights and dispatch_mask
gates1
=
torch
.
einsum
(
"
g
s,
g
se->
g
se"
,
gates1_
g
s
,
mask1_float
)
gates2
=
torch
.
einsum
(
"
g
s,
g
se->
g
se"
,
gates2_
g
s
,
mask2_float
)
locations1_
g
sc
=
F
.
one_hot
(
locations1_
g
s
,
num_classes
=
capacity
)
locations2_
g
sc
=
F
.
one_hot
(
locations2_
g
s
,
num_classes
=
capacity
)
combine1_
g
sec
=
torch
.
einsum
(
"
g
se,
g
sc->
g
sec"
,
gates1
,
locations1_
g
sc
)
combine2_
g
sec
=
torch
.
einsum
(
"
g
se,
g
sc->
g
sec"
,
gates2
,
locations2_
g
sc
)
combine_weights
=
combine1_
g
sec
+
combine2_
g
sec
gates1
=
torch
.
einsum
(
"s,se->se"
,
gates1_s
,
mask1_float
)
gates2
=
torch
.
einsum
(
"s,se->se"
,
gates2_s
,
mask2_float
)
locations1_sc
=
F
.
one_hot
(
locations1_s
,
num_classes
=
capacity
)
locations2_sc
=
F
.
one_hot
(
locations2_s
,
num_classes
=
capacity
)
combine1_sec
=
torch
.
einsum
(
"se,sc->sec"
,
gates1
,
locations1_sc
)
combine2_sec
=
torch
.
einsum
(
"se,sc->sec"
,
gates2
,
locations2_sc
)
combine_weights
=
combine1_sec
+
combine2_sec
dispatch_mask
=
combine_weights
.
bool
()
return
l_aux
,
combine_weights
,
dispatch_mask
...
...
@@ -115,5 +115,5 @@ class Top2Gate(torch.nn.Module):
self
.
wg
=
torch
.
nn
.
Linear
(
num_experts
,
model_dim
,
bias
=
False
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
# type: ignore
logits
=
torch
.
einsum
(
"
g
sm,me ->
g
se"
,
input
,
self
.
wg
.
weight
)
logits
=
torch
.
einsum
(
"sm,me -> se"
,
input
,
self
.
wg
.
weight
)
return
top2gating
(
logits
)
stubs/torch/__init__.pyi
View file @
89176e34
...
...
@@ -37,6 +37,10 @@ from . import version
class dtype:
is_floating_point: builtins.bool
class finfo:
def __init__(self, dtype: dtype): ...
eps: float
class layout: ...
strided : layout = ...
...
...
tests/nn/moe/test_moe_layer.py
View file @
89176e34
...
...
@@ -23,17 +23,16 @@ else:
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
if
"OMPI_COMM_WORLD_SIZE"
in
os
.
environ
:
pass
#
dist.init_process_group(backend=dist.Backend.MPI)
dist
.
init_process_group
(
backend
=
dist
.
Backend
.
MPI
)
def
setup_module
(
module
):
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
0
,
world_size
=
1
)
else
:
dist
.
init_process_group
(
backend
=
dist
.
Backend
.
MPI
)
def
teardown_module
(
module
):
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
torch
.
distributed
.
destroy_process_group
()
...
...
@@ -62,7 +61,7 @@ def test_expert_params(device):
def
test_forward
(
device
):
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
...
...
@@ -81,7 +80,7 @@ def test_forward_multi(device):
num_local_experts
=
4
model_dim
=
4
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
*
num_local_experts
input
=
torch
.
randn
(
num_local_experts
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
*
num_local_experts
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
experts
=
[]
for
i
in
range
(
num_local_experts
):
...
...
@@ -106,12 +105,12 @@ class RoundRobinGate(torch.nn.Module):
self
.
num_experts
=
num_experts
def
forward
(
self
,
input
):
g
,
s
,
_
=
input
.
shape
s
=
input
.
shape
[
0
]
assert
s
%
self
.
num_experts
==
0
capacity
=
2
*
s
//
self
.
num_experts
output
=
torch
.
zeros
(
g
,
s
,
self
.
num_experts
,
capacity
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
output
=
torch
.
zeros
(
s
,
self
.
num_experts
,
capacity
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
for
i
in
range
(
s
):
output
[
:,
i
,
i
%
self
.
num_experts
,
i
//
self
.
num_experts
]
=
1.0
output
[
i
,
i
%
self
.
num_experts
,
i
//
self
.
num_experts
]
=
1.0
return
0.0
,
output
,
output
.
bool
()
...
...
@@ -120,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
def
test_forward_routing
(
device
):
model_dim
=
8
num_experts
=
dist
.
get_world_size
()
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
gate
=
RoundRobinGate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use scaling matrix (each rank has a different scale)
...
...
@@ -130,10 +129,10 @@ def test_forward_routing(device):
output
=
moe
(
input
)
assert
output
.
shape
==
input
.
shape
# Verify that each token was sent to the correct expert by checking its scale.
t
=
input
.
shape
[
2
]
t
=
input
.
shape
[
1
]
for
i
in
range
(
t
):
expert
=
i
%
num_experts
assert
torch
.
allclose
(
input
[:,
:,
i
]
*
(
expert
+
1
),
output
[:,
:,
i
])
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
@
pytest
.
mark
.
mpi
...
...
@@ -142,7 +141,7 @@ def test_backward(device):
loss
=
torch
.
nn
.
MSELoss
()
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
...
...
tests/nn/moe/test_top2gating.py
View file @
89176e34
...
...
@@ -23,21 +23,21 @@ def test_create_cuda():
def
do_test_forward
(
device
):
torch
.
manual_seed
(
3
)
input
=
torch
.
randn
(
3
,
12
,
4
).
to
(
device
)
input
=
torch
.
randn
(
12
,
4
).
to
(
device
)
gate
=
Top2Gate
(
4
,
6
).
to
(
device
)
capacity
=
2
*
12
//
6
l_aux
,
combine_weights
,
dispatch_mask
=
gate
(
input
)
assert
pytest
.
approx
(
l_aux
.
item
(),
0.0283
)
assert
combine_weights
.
shape
==
(
3
,
12
,
6
,
4
)
assert
dispatch_mask
.
shape
==
(
3
,
12
,
6
,
4
)
assert
combine_weights
.
shape
==
(
12
,
6
,
4
)
assert
dispatch_mask
.
shape
==
(
12
,
6
,
4
)
assert
torch
.
equal
(
combine_weights
.
bool
(),
dispatch_mask
)
assert
torch
.
all
(
torch
.
sum
(
dispatch_mask
,
axis
=
(
1
,
3
))
<=
capacity
)
assert
torch
.
all
(
torch
.
sum
(
dispatch_mask
,
axis
=
(
0
,
2
))
<=
capacity
)
assert
torch
.
all
(
combine_weights
>=
0.0
)
assert
torch
.
all
(
combine_weights
<=
1.0
)
weights_sum
=
torch
.
sum
(
combine_weights
).
item
()
assert
round
(
weights_sum
)
==
pytest
.
approx
(
weights_sum
)
# For this random seed, we get
36
slots filled.
assert
weights_sum
==
pytest
.
approx
(
36
.0
)
# For this random seed, we get
12
slots filled.
assert
weights_sum
==
pytest
.
approx
(
12
.0
)
def
test_forward_cpu
():
...
...
@@ -53,15 +53,15 @@ def test_forward_cuda():
def
test_top1s
():
num_tokens
=
8
num_experts
=
4
logits
=
torch
.
randn
(
1
,
num_tokens
,
num_experts
)
logits
=
torch
.
randn
(
num_tokens
,
num_experts
)
l_aux
,
_
,
dispatch_mask
=
top2gating
(
logits
)
top1s
=
torch
.
argmax
(
logits
,
dim
=
2
)
top1s
=
torch
.
argmax
(
logits
,
dim
=
1
)
capacity
=
2
*
num_tokens
//
num_experts
ce
=
[
0
]
*
num_experts
locations
=
[
0
]
*
num_tokens
for
i
,
s
in
enumerate
(
top1s
[
0
]
):
for
i
,
s
in
enumerate
(
top1s
):
e
=
s
.
item
()
loc
=
ce
[
e
]
ce
[
e
]
=
loc
+
1
if
ce
[
e
]
<
capacity
:
assert
dispatch_mask
[
0
][
i
][
e
][
loc
]
assert
dispatch_mask
[
i
][
e
][
loc
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment