OpenDAS / fairscale · Commits · 89176e34

Unverified commit 89176e34, authored Nov 11, 2020 by msbaines; committed via GitHub on Nov 11, 2020.
[refactor] moe: remove G dimension (#183)
Parent: 5d4f50fb

This change removes the leading (g)roup dimension from the MOE layer and its Top2 gate: inputs become 3-D (sequence, token, model) instead of 4-D, every gating einsum loses its "g" index, and local experts now receive contiguous chunks along the token dimension rather than one group slice each. Along the way, the gate's divide-by-zero guard switches from torch.where to torch.clamp with torch.finfo(dtype).eps, which is what the stubs update supports.

Showing 5 changed files with 63 additions and 61 deletions (+63, -61).
Changed files:

  fairscale/nn/moe/moe_layer.py     +8   -9
  fairscale/nn/moe/top2gate.py      +29  -29
  stubs/torch/__init__.pyi          +4   -0
  tests/nn/moe/test_moe_layer.py    +12  -13
  tests/nn/moe/test_top2gating.py   +10  -10
fairscale/nn/moe/moe_layer.py

@@ -66,29 +66,28 @@ class MOELayer(Base):
             p.expert = True  # type: ignore

     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
-        dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
+        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), input)
         return _AllToAll.apply(self.group, dispatched_input)

     def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
         expert_output = _AllToAll.apply(self.group, input)
-        return torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
+        return torch.einsum("sec,ecm->sm", combine_weights, expert_output)
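The new dispatch einsum is easy to sanity-check on tiny shapes. Below is a minimal standalone sketch (the sizes and the single routing entry are made up for illustration) of what "sec,sm->ecm" computes:

```python
import torch

s, e, c, m = 4, 2, 4, 8             # tokens, experts, capacity slots, model dim
dispatch_mask = torch.zeros(s, e, c)
dispatch_mask[0, 1, 0] = 1.0        # pretend the gate routed token 0 to expert 1, slot 0
input = torch.randn(s, m)

# dispatched[e, c, m] = sum_s dispatch_mask[s, e, c] * input[s, m]
dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, input)
assert torch.equal(dispatched[1, 0], input[0])         # token 0 sits in expert 1, slot 0
assert torch.equal(dispatched[0, 0], torch.zeros(m))   # unassigned slots stay zero
```

all_to_all_combine applies the same routing in reverse: "sec,ecm->sm" gathers each token back out of its (expert, capacity) slot, weighted by combine_weights.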
     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         assert len(input) == 1, "only single input Tensor supported"
-        assert len(input[0].shape) == 4, "input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
-        assert input[0].shape[0] == len(self.experts), "group dimension size must be equal to number of local experts"
+        assert len(input[0].shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
+        assert input[0].shape[0] % len(self.experts) == 0, "num tokens must be a multiple of number of local experts"

         # Implement Algorithm 2 from GShard paper.
         shape = input[0].shape
-        # Reshape into S tokens per group.
-        reshaped_input = input[0].reshape(shape[0], -1, shape[3])
+        # Reshape into S tokens by dropping sequence dimension.
+        reshaped_input = input[0].reshape(-1, shape[2])
         self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
         dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
-        assert dispatched_input.shape[1] == len(self.experts)
-        chunks = dispatched_input.chunk(len(self.experts), dim=1)
+        chunks = dispatched_input.chunk(len(self.experts), dim=0)
         expert_outputs = []
         for chunk, expert in zip(chunks, self.experts):
             expert_outputs += [expert(chunk)]
-        expert_output = torch.cat(expert_outputs, dim=1)
+        expert_output = torch.cat(expert_outputs, dim=0)
         combined_output = self.all_to_all_combine(combine_weights, expert_output)
         return combined_output.reshape(shape)
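To make the new calling convention concrete, here is a minimal sketch of driving the layer end to end. It is illustrative, not part of the commit: it assumes the MOELayer(gate, expert) construction exercised by the tests below, and it follows the test suite's MPI path (run under something like `mpirun -n 2` with an MPI-enabled PyTorch build):

```python
import torch
import torch.distributed as dist
from fairscale.nn.moe.moe_layer import MOELayer
from fairscale.nn.moe.top2gate import Top2Gate

dist.init_process_group(backend=dist.Backend.MPI)

model_dim = 8
num_experts = dist.get_world_size()           # one expert per rank
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim, bias=False)
moe = MOELayer(gate, expert)                  # construction assumed from the tests

# Old convention: 4-D (g)roup x (s)equence x (t)oken x (m)odel,
# e.g. torch.randn(1, 4, 16, model_dim). New convention: 3-D, no group axis.
x = torch.randn(4, 16, model_dim)
y = moe(x)
assert y.shape == x.shape
```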
fairscale/nn/moe/top2gate.py

@@ -28,36 +28,36 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
 def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
-    gates = F.softmax(logits, dim=2)
-    # gates has shape of GSE
-    num_tokens = gates.shape[1]
-    num_experts = gates.shape[2]
+    gates = F.softmax(logits, dim=1)
+    # gates has shape of SE
+    num_tokens = gates.shape[0]
+    num_experts = gates.shape[1]
     # capacity = 2S/E
     capacity = 2 * num_tokens // num_experts
     assert num_tokens % num_experts == 0

     # Create a mask for 1st's expert per token
-    indices1_gs = torch.argmax(gates, dim=2)
-    mask1 = F.one_hot(indices1_gs, num_classes=num_experts)
+    indices1_s = torch.argmax(gates, dim=1)
+    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

     # Create a mask for 2nd's expert per token using Gumbel-max trick
     # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
     logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
     # Replace top-expert with min value
     logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
-    indices2_gs = torch.argmax(logits_except1, dim=2)
-    mask2 = F.one_hot(indices2_gs, num_classes=num_experts)
+    indices2_s = torch.argmax(logits_except1, dim=1)
+    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=1) - 1
-    locations2 = torch.cumsum(mask2, dim=1) - 1
+    locations1 = torch.cumsum(mask1, dim=0) - 1
+    locations2 = torch.cumsum(mask2, dim=0) - 1
     # Update 2nd's location by accounting for locations of 1st
-    locations2 += torch.sum(mask1, dim=1, keepdim=True)
+    locations2 += torch.sum(mask1, dim=0, keepdim=True)

     # Compute l_aux
-    me = torch.mean(gates, dim=1)
-    ce = torch.mean(mask1.float(), dim=1)
+    me = torch.mean(gates, dim=0)
+    ce = torch.mean(mask1.float(), dim=0)
     l_aux = torch.mean(me * ce)

     # Remove locations outside capacity from mask
@@ -65,28 +65,28 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     mask2 *= torch.lt(locations2, capacity)

     # Store the capacity location for each token
-    locations1_gs = torch.sum(locations1 * mask1, dim=2)
-    locations2_gs = torch.sum(locations2 * mask2, dim=2)
+    locations1_s = torch.sum(locations1 * mask1, dim=1)
+    locations2_s = torch.sum(locations2 * mask2, dim=1)

     # Normalize gate probabilities
     mask1_float = mask1.float()
     mask2_float = mask2.float()
-    gates1_gs = torch.einsum("gse,gse->gs", gates, mask1_float)
-    gates2_gs = torch.einsum("gse,gse->gs", gates, mask2_float)
-    denom_gs = gates1_gs + gates2_gs
+    gates1_s = torch.einsum("se,se->s", gates, mask1_float)
+    gates2_s = torch.einsum("se,se->s", gates, mask2_float)
+    denom_s = gates1_s + gates2_s
     # Avoid divide-by-zero
-    denom_gs = torch.where(denom_gs > 0, denom_gs, torch.ones_like(denom_gs))
-    gates1_gs /= denom_gs
-    gates2_gs /= denom_gs
+    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
+    gates1_s /= denom_s
+    gates2_s /= denom_s

     # Calculate combine_weights and dispatch_mask
-    gates1 = torch.einsum("gs,gse->gse", gates1_gs, mask1_float)
-    gates2 = torch.einsum("gs,gse->gse", gates2_gs, mask2_float)
-    locations1_gsc = F.one_hot(locations1_gs, num_classes=capacity)
-    locations2_gsc = F.one_hot(locations2_gs, num_classes=capacity)
-    combine1_gsec = torch.einsum("gse,gsc->gsec", gates1, locations1_gsc)
-    combine2_gsec = torch.einsum("gse,gsc->gsec", gates2, locations2_gsc)
-    combine_weights = combine1_gsec + combine2_gsec
+    gates1 = torch.einsum("s,se->se", gates1_s, mask1_float)
+    gates2 = torch.einsum("s,se->se", gates2_s, mask2_float)
+    locations1_sc = F.one_hot(locations1_s, num_classes=capacity)
+    locations2_sc = F.one_hot(locations2_s, num_classes=capacity)
+    combine1_sec = torch.einsum("se,sc->sec", gates1, locations1_sc)
+    combine2_sec = torch.einsum("se,sc->sec", gates2, locations2_sc)
+    combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()

     return l_aux, combine_weights, dispatch_mask
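The divide-by-zero guard changes form along with shape: torch.where(denom > 0, denom, ones) becomes torch.clamp(denom, min=eps), which is the call the stubs change below exists to type-check. A standalone sketch of the difference, with illustrative values:

```python
import torch

denom_s = torch.tensor([0.0, 0.5])
eps = torch.finfo(denom_s.dtype).eps    # ~1.19e-07 for float32

new = torch.clamp(denom_s, min=eps)     # tensor([1.1921e-07, 5.0000e-01])
old = torch.where(denom_s > 0, denom_s, torch.ones_like(denom_s))  # tensor([1.0000, 0.5000])
```

Both keep the subsequent division finite, and wherever denom_s is zero the gates1_s/gates2_s numerators are zero as well, so that token's normalized weights come out 0 under either guard.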
@@ -115,5 +115,5 @@ class Top2Gate(torch.nn.Module):
         self.wg = torch.nn.Linear(num_experts, model_dim, bias=False)

     def forward(self, input: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
-        logits = torch.einsum("gsm,me -> gse", input, self.wg.weight)
+        logits = torch.einsum("sm,me -> se", input, self.wg.weight)
         return top2gating(logits)
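The net effect on the gate: logits of shape (S, E) now yield combine weights and a dispatch mask of shape (S, E, C) rather than (G, S, E, C). A standalone shape check mirroring the updated tests:

```python
import torch
from fairscale.nn.moe.top2gate import top2gating

num_tokens, num_experts = 8, 4
logits = torch.randn(num_tokens, num_experts)   # SE: no leading group axis
l_aux, combine_weights, dispatch_mask = top2gating(logits)

capacity = 2 * num_tokens // num_experts        # 2S/E = 4
assert combine_weights.shape == (num_tokens, num_experts, capacity)
assert dispatch_mask.shape == (num_tokens, num_experts, capacity)
assert torch.equal(combine_weights.bool(), dispatch_mask)
```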
stubs/torch/__init__.pyi

@@ -37,6 +37,10 @@ from . import version
 class dtype:
     is_floating_point: builtins.bool

+class finfo:
+    def __init__(self, dtype: dtype): ...
+    eps: float
+
 class layout: ...
 strided : layout = ...
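This stub addition lets the type checker accept the torch.finfo(denom_s.dtype).eps call introduced in top2gate.py; the runtime attribute it declares is standard PyTorch:

```python
import torch

eps = torch.finfo(torch.float32).eps   # machine epsilon, ~1.1920929e-07
assert isinstance(eps, float)
```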
tests/nn/moe/test_moe_layer.py

@@ -23,18 +23,17 @@ else:
 os.environ["MASTER_ADDR"] = "localhost"
 os.environ["MASTER_PORT"] = "29501"

-if "OMPI_COMM_WORLD_SIZE" in os.environ:
-    pass  # dist.init_process_group(backend=dist.Backend.MPI)


 def setup_module(module):
     if "OMPI_COMM_WORLD_SIZE" not in os.environ:
         dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
+    else:
+        dist.init_process_group(backend=dist.Backend.MPI)


 def teardown_module(module):
-    torch.distributed.destroy_process_group()
+    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
+        torch.distributed.destroy_process_group()


 @pytest.mark.parametrize("device", devices)

@@ -62,7 +61,7 @@ def test_expert_params(device):
 def test_forward(device):
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
-    input = torch.randn(1, 4, 16, model_dim).to(device)
+    input = torch.randn(4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use identity matrix

@@ -81,7 +80,7 @@ def test_forward_multi(device):
     num_local_experts = 4
     model_dim = 4
     num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
-    input = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
+    input = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     experts = []
     for i in range(num_local_experts):

@@ -106,12 +105,12 @@ class RoundRobinGate(torch.nn.Module):
         self.num_experts = num_experts

     def forward(self, input):
-        g, s, _ = input.shape
+        s = input.shape[0]
         assert s % self.num_experts == 0
         capacity = 2 * s // self.num_experts
-        output = torch.zeros(g, s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
+        output = torch.zeros(s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
         for i in range(s):
-            output[:, i, i % self.num_experts, i // self.num_experts] = 1.0
+            output[i, i % self.num_experts, i // self.num_experts] = 1.0
         return 0.0, output, output.bool()

@@ -120,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
 def test_forward_routing(device):
     model_dim = 8
     num_experts = dist.get_world_size()
-    input = torch.randn(1, 4, 16, model_dim).to(device)
+    input = torch.randn(4, 16, model_dim).to(device)
     gate = RoundRobinGate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use scaling matrix (each rank has a different scale)

@@ -130,10 +129,10 @@ def test_forward_routing(device):
     output = moe(input)
     assert output.shape == input.shape
     # Verify that each token was sent to the correct expert by checking its scale.
-    t = input.shape[2]
+    t = input.shape[1]
     for i in range(t):
         expert = i % num_experts
-        assert torch.allclose(input[:, :, i] * (expert + 1), output[:, :, i])
+        assert torch.allclose(input[:, i] * (expert + 1), output[:, i])


 @pytest.mark.mpi

@@ -142,7 +141,7 @@ def test_backward(device):
     loss = torch.nn.MSELoss()
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
-    input = torch.randn(1, 4, 16, model_dim).to(device)
+    input = torch.randn(4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use identity matrix
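For intuition about the updated RoundRobinGate, the 3-D mask it builds is fully deterministic. A standalone sketch with made-up sizes (4 tokens, 2 experts) showing the slot layout:

```python
import torch

s, num_experts = 4, 2
capacity = 2 * s // num_experts                 # 4
output = torch.zeros(s, num_experts, capacity)
for i in range(s):
    output[i, i % num_experts, i // num_experts] = 1.0

# token 0 -> expert 0, slot 0    token 1 -> expert 1, slot 0
# token 2 -> expert 0, slot 1    token 3 -> expert 1, slot 1
assert output.sum() == s                        # exactly one slot per token
```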
tests/nn/moe/test_top2gating.py

@@ -23,21 +23,21 @@ def test_create_cuda():
 def do_test_forward(device):
     torch.manual_seed(3)
-    input = torch.randn(3, 12, 4).to(device)
+    input = torch.randn(12, 4).to(device)
     gate = Top2Gate(4, 6).to(device)
     capacity = 2 * 12 // 6
     l_aux, combine_weights, dispatch_mask = gate(input)
     assert pytest.approx(l_aux.item(), 0.0283)
-    assert combine_weights.shape == (3, 12, 6, 4)
-    assert dispatch_mask.shape == (3, 12, 6, 4)
+    assert combine_weights.shape == (12, 6, 4)
+    assert dispatch_mask.shape == (12, 6, 4)
     assert torch.equal(combine_weights.bool(), dispatch_mask)
-    assert torch.all(torch.sum(dispatch_mask, axis=(1, 3)) <= capacity)
+    assert torch.all(torch.sum(dispatch_mask, axis=(0, 2)) <= capacity)
     assert torch.all(combine_weights >= 0.0)
     assert torch.all(combine_weights <= 1.0)
     weights_sum = torch.sum(combine_weights).item()
     assert round(weights_sum) == pytest.approx(weights_sum)
-    # For this random seed, we get 36 slots filled.
-    assert weights_sum == pytest.approx(36.0)
+    # For this random seed, we get 12 slots filled.
+    assert weights_sum == pytest.approx(12.0)


 def test_forward_cpu():

@@ -53,15 +53,15 @@ def test_forward_cuda():
 def test_top1s():
     num_tokens = 8
     num_experts = 4
-    logits = torch.randn(1, num_tokens, num_experts)
+    logits = torch.randn(num_tokens, num_experts)
     l_aux, _, dispatch_mask = top2gating(logits)
-    top1s = torch.argmax(logits, dim=2)
+    top1s = torch.argmax(logits, dim=1)
     capacity = 2 * num_tokens // num_experts
     ce = [0] * num_experts
     locations = [0] * num_tokens
-    for i, s in enumerate(top1s[0]):
+    for i, s in enumerate(top1s):
         e = s.item()
         loc = ce[e]
         ce[e] = loc + 1
         if ce[e] < capacity:
-            assert dispatch_mask[0][i][e][loc]
+            assert dispatch_mask[i][e][loc]
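Why the expected weight sum drops from 36.0 to 12.0: the sum of combine_weights equals the number of routed tokens, because each token's two normalized gate weights sum to 1 whenever neither assignment overflows the capacity buffer. The old 4-D test had 3 groups of 12 tokens, hence 36; the new 2-D test has 12. A standalone check of that invariant, using random stand-ins for gates1_s/gates2_s:

```python
import torch

num_tokens = 12
w1, w2 = torch.rand(num_tokens), torch.rand(num_tokens)
denom = torch.clamp(w1 + w2, min=torch.finfo(torch.float32).eps)
total = (w1 / denom + w2 / denom).sum()
assert torch.allclose(total, torch.tensor(float(num_tokens)))   # 12.0
```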