Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
89176e34
Unverified
Commit
89176e34
authored
Nov 11, 2020
by
msbaines
Committed by
GitHub
Nov 11, 2020
Browse files
[refactor] moe: remove G dimension (#183)
parent
5d4f50fb
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
61 deletions
+63
-61
fairscale/nn/moe/moe_layer.py
fairscale/nn/moe/moe_layer.py
+8
-9
fairscale/nn/moe/top2gate.py
fairscale/nn/moe/top2gate.py
+29
-29
stubs/torch/__init__.pyi
stubs/torch/__init__.pyi
+4
-0
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_moe_layer.py
+12
-13
tests/nn/moe/test_top2gating.py
tests/nn/moe/test_top2gating.py
+10
-10
No files found.
fairscale/nn/moe/moe_layer.py
View file @
89176e34
...
@@ -66,29 +66,28 @@ class MOELayer(Base):
...
@@ -66,29 +66,28 @@ class MOELayer(Base):
p
.
expert
=
True
# type: ignore
p
.
expert
=
True
# type: ignore
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
dispatched_input
=
torch
.
einsum
(
"
g
sec,
g
sm->e
g
cm"
,
dispatch_mask
.
float
(),
input
)
dispatched_input
=
torch
.
einsum
(
"sec,sm->ecm"
,
dispatch_mask
.
float
(),
input
)
return
_AllToAll
.
apply
(
self
.
group
,
dispatched_input
)
return
_AllToAll
.
apply
(
self
.
group
,
dispatched_input
)
def
all_to_all_combine
(
self
,
combine_weights
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
def
all_to_all_combine
(
self
,
combine_weights
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
expert_output
=
_AllToAll
.
apply
(
self
.
group
,
input
)
expert_output
=
_AllToAll
.
apply
(
self
.
group
,
input
)
return
torch
.
einsum
(
"
g
sec,e
g
cm->
g
sm"
,
combine_weights
,
expert_output
)
return
torch
.
einsum
(
"sec,ecm->sm"
,
combine_weights
,
expert_output
)
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
[
0
].
shape
)
==
4
,
"input Tensor must have dimensions:
(g)roup,
(s)equence, (t)oken, (m)odel"
assert
len
(
input
[
0
].
shape
)
==
3
,
"input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
assert
input
[
0
].
shape
[
0
]
==
len
(
self
.
experts
)
,
"group dimension size
must be
equal to
number of local experts"
assert
input
[
0
].
shape
[
0
]
%
len
(
self
.
experts
)
==
0
,
"num tokens
must be
order of
number of local experts"
# Implement Algorithm 2 from GShard paper.
# Implement Algorithm 2 from GShard paper.
shape
=
input
[
0
].
shape
shape
=
input
[
0
].
shape
# Reshape into S tokens
per group
.
# Reshape into S tokens
by dropping sequence dimension
.
reshaped_input
=
input
[
0
].
reshape
(
shape
[
0
],
-
1
,
shape
[
3
])
reshaped_input
=
input
[
0
].
reshape
(
-
1
,
shape
[
2
])
self
.
l_aux
,
combine_weights
,
dispatching_mask
=
self
.
gate
(
reshaped_input
)
self
.
l_aux
,
combine_weights
,
dispatching_mask
=
self
.
gate
(
reshaped_input
)
dispatched_input
=
self
.
all_to_all_dispatch
(
dispatching_mask
,
reshaped_input
)
dispatched_input
=
self
.
all_to_all_dispatch
(
dispatching_mask
,
reshaped_input
)
assert
dispatched_input
.
shape
[
1
]
==
len
(
self
.
experts
)
chunks
=
dispatched_input
.
chunk
(
len
(
self
.
experts
),
dim
=
0
)
chunks
=
dispatched_input
.
chunk
(
len
(
self
.
experts
),
dim
=
1
)
expert_outputs
=
[]
expert_outputs
=
[]
for
chunk
,
expert
in
zip
(
chunks
,
self
.
experts
):
for
chunk
,
expert
in
zip
(
chunks
,
self
.
experts
):
expert_outputs
+=
[
expert
(
chunk
)]
expert_outputs
+=
[
expert
(
chunk
)]
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
1
)
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
0
)
combined_output
=
self
.
all_to_all_combine
(
combine_weights
,
expert_output
)
combined_output
=
self
.
all_to_all_combine
(
combine_weights
,
expert_output
)
return
combined_output
.
reshape
(
shape
)
return
combined_output
.
reshape
(
shape
)
fairscale/nn/moe/top2gate.py
View file @
89176e34
...
@@ -28,36 +28,36 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
...
@@ -28,36 +28,36 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
def
top2gating
(
logits
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
def
top2gating
(
logits
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
"""Implements Top2Gating on logits."""
"""Implements Top2Gating on logits."""
gates
=
F
.
softmax
(
logits
,
dim
=
2
)
gates
=
F
.
softmax
(
logits
,
dim
=
1
)
# gates has shape of
G
SE
# gates has shape of SE
num_tokens
=
gates
.
shape
[
1
]
num_tokens
=
gates
.
shape
[
0
]
num_experts
=
gates
.
shape
[
2
]
num_experts
=
gates
.
shape
[
1
]
# capacity = 2S/E
# capacity = 2S/E
capacity
=
2
*
num_tokens
//
num_experts
capacity
=
2
*
num_tokens
//
num_experts
assert
num_tokens
%
num_experts
==
0
assert
num_tokens
%
num_experts
==
0
# Create a mask for 1st's expert per token
# Create a mask for 1st's expert per token
indices1_
g
s
=
torch
.
argmax
(
gates
,
dim
=
2
)
indices1_s
=
torch
.
argmax
(
gates
,
dim
=
1
)
mask1
=
F
.
one_hot
(
indices1_
g
s
,
num_classes
=
num_experts
)
mask1
=
F
.
one_hot
(
indices1_s
,
num_classes
=
num_experts
)
# Create a mask for 2nd's expert per token using Gumbel-max trick
# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise
=
logits
+
gumbel_rsample
(
logits
.
shape
,
device
=
logits
.
device
)
logits_w_noise
=
logits
+
gumbel_rsample
(
logits
.
shape
,
device
=
logits
.
device
)
# Replace top-expert with min value
# Replace top-expert with min value
logits_except1
=
logits_w_noise
.
masked_fill
(
mask1
.
bool
(),
float
(
"-inf"
))
logits_except1
=
logits_w_noise
.
masked_fill
(
mask1
.
bool
(),
float
(
"-inf"
))
indices2_
g
s
=
torch
.
argmax
(
logits_except1
,
dim
=
2
)
indices2_s
=
torch
.
argmax
(
logits_except1
,
dim
=
1
)
mask2
=
F
.
one_hot
(
indices2_
g
s
,
num_classes
=
num_experts
)
mask2
=
F
.
one_hot
(
indices2_s
,
num_classes
=
num_experts
)
# Compute locations in capacity buffer
# Compute locations in capacity buffer
locations1
=
torch
.
cumsum
(
mask1
,
dim
=
1
)
-
1
locations1
=
torch
.
cumsum
(
mask1
,
dim
=
0
)
-
1
locations2
=
torch
.
cumsum
(
mask2
,
dim
=
1
)
-
1
locations2
=
torch
.
cumsum
(
mask2
,
dim
=
0
)
-
1
# Update 2nd's location by accounting for locations of 1st
# Update 2nd's location by accounting for locations of 1st
locations2
+=
torch
.
sum
(
mask1
,
dim
=
1
,
keepdim
=
True
)
locations2
+=
torch
.
sum
(
mask1
,
dim
=
0
,
keepdim
=
True
)
# Compute l_aux
# Compute l_aux
me
=
torch
.
mean
(
gates
,
dim
=
1
)
me
=
torch
.
mean
(
gates
,
dim
=
0
)
ce
=
torch
.
mean
(
mask1
.
float
(),
dim
=
1
)
ce
=
torch
.
mean
(
mask1
.
float
(),
dim
=
0
)
l_aux
=
torch
.
mean
(
me
*
ce
)
l_aux
=
torch
.
mean
(
me
*
ce
)
# Remove locations outside capacity from mask
# Remove locations outside capacity from mask
...
@@ -65,28 +65,28 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
...
@@ -65,28 +65,28 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
mask2
*=
torch
.
lt
(
locations2
,
capacity
)
mask2
*=
torch
.
lt
(
locations2
,
capacity
)
# Store the capacity location for each token
# Store the capacity location for each token
locations1_
g
s
=
torch
.
sum
(
locations1
*
mask1
,
dim
=
2
)
locations1_s
=
torch
.
sum
(
locations1
*
mask1
,
dim
=
1
)
locations2_
g
s
=
torch
.
sum
(
locations2
*
mask2
,
dim
=
2
)
locations2_s
=
torch
.
sum
(
locations2
*
mask2
,
dim
=
1
)
# Normalize gate probabilities
# Normalize gate probabilities
mask1_float
=
mask1
.
float
()
mask1_float
=
mask1
.
float
()
mask2_float
=
mask2
.
float
()
mask2_float
=
mask2
.
float
()
gates1_
g
s
=
torch
.
einsum
(
"
g
se,
g
se->
g
s"
,
gates
,
mask1_float
)
gates1_s
=
torch
.
einsum
(
"se,se->s"
,
gates
,
mask1_float
)
gates2_
g
s
=
torch
.
einsum
(
"
g
se,
g
se->
g
s"
,
gates
,
mask2_float
)
gates2_s
=
torch
.
einsum
(
"se,se->s"
,
gates
,
mask2_float
)
denom_
g
s
=
gates1_
g
s
+
gates2_
g
s
denom_s
=
gates1_s
+
gates2_s
# Avoid divide-by-zero
# Avoid divide-by-zero
denom_
g
s
=
torch
.
where
(
denom_
gs
>
0
,
denom_gs
,
torch
.
ones_like
(
denom_gs
)
)
denom_s
=
torch
.
clamp
(
denom_
s
,
min
=
torch
.
finfo
(
denom_s
.
dtype
).
eps
)
gates1_
g
s
/=
denom_
g
s
gates1_s
/=
denom_s
gates2_
g
s
/=
denom_
g
s
gates2_s
/=
denom_s
# Calculate combine_weights and dispatch_mask
# Calculate combine_weights and dispatch_mask
gates1
=
torch
.
einsum
(
"
g
s,
g
se->
g
se"
,
gates1_
g
s
,
mask1_float
)
gates1
=
torch
.
einsum
(
"s,se->se"
,
gates1_s
,
mask1_float
)
gates2
=
torch
.
einsum
(
"
g
s,
g
se->
g
se"
,
gates2_
g
s
,
mask2_float
)
gates2
=
torch
.
einsum
(
"s,se->se"
,
gates2_s
,
mask2_float
)
locations1_
g
sc
=
F
.
one_hot
(
locations1_
g
s
,
num_classes
=
capacity
)
locations1_sc
=
F
.
one_hot
(
locations1_s
,
num_classes
=
capacity
)
locations2_
g
sc
=
F
.
one_hot
(
locations2_
g
s
,
num_classes
=
capacity
)
locations2_sc
=
F
.
one_hot
(
locations2_s
,
num_classes
=
capacity
)
combine1_
g
sec
=
torch
.
einsum
(
"
g
se,
g
sc->
g
sec"
,
gates1
,
locations1_
g
sc
)
combine1_sec
=
torch
.
einsum
(
"se,sc->sec"
,
gates1
,
locations1_sc
)
combine2_
g
sec
=
torch
.
einsum
(
"
g
se,
g
sc->
g
sec"
,
gates2
,
locations2_
g
sc
)
combine2_sec
=
torch
.
einsum
(
"se,sc->sec"
,
gates2
,
locations2_sc
)
combine_weights
=
combine1_
g
sec
+
combine2_
g
sec
combine_weights
=
combine1_sec
+
combine2_sec
dispatch_mask
=
combine_weights
.
bool
()
dispatch_mask
=
combine_weights
.
bool
()
return
l_aux
,
combine_weights
,
dispatch_mask
return
l_aux
,
combine_weights
,
dispatch_mask
...
@@ -115,5 +115,5 @@ class Top2Gate(torch.nn.Module):
...
@@ -115,5 +115,5 @@ class Top2Gate(torch.nn.Module):
self
.
wg
=
torch
.
nn
.
Linear
(
num_experts
,
model_dim
,
bias
=
False
)
self
.
wg
=
torch
.
nn
.
Linear
(
num_experts
,
model_dim
,
bias
=
False
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
# type: ignore
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
# type: ignore
logits
=
torch
.
einsum
(
"
g
sm,me ->
g
se"
,
input
,
self
.
wg
.
weight
)
logits
=
torch
.
einsum
(
"sm,me -> se"
,
input
,
self
.
wg
.
weight
)
return
top2gating
(
logits
)
return
top2gating
(
logits
)
stubs/torch/__init__.pyi
View file @
89176e34
...
@@ -37,6 +37,10 @@ from . import version
...
@@ -37,6 +37,10 @@ from . import version
class dtype:
class dtype:
is_floating_point: builtins.bool
is_floating_point: builtins.bool
class finfo:
def __init__(self, dtype: dtype): ...
eps: float
class layout: ...
class layout: ...
strided : layout = ...
strided : layout = ...
...
...
tests/nn/moe/test_moe_layer.py
View file @
89176e34
...
@@ -23,18 +23,17 @@ else:
...
@@ -23,18 +23,17 @@ else:
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
if
"OMPI_COMM_WORLD_SIZE"
in
os
.
environ
:
if
"OMPI_COMM_WORLD_SIZE"
in
os
.
environ
:
pass
#
dist.init_process_group(backend=dist.Backend.MPI)
dist
.
init_process_group
(
backend
=
dist
.
Backend
.
MPI
)
def
setup_module
(
module
):
def
setup_module
(
module
):
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
0
,
world_size
=
1
)
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
0
,
world_size
=
1
)
else
:
dist
.
init_process_group
(
backend
=
dist
.
Backend
.
MPI
)
def
teardown_module
(
module
):
def
teardown_module
(
module
):
torch
.
distributed
.
destroy_process_group
()
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
torch
.
distributed
.
destroy_process_group
()
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
...
@@ -62,7 +61,7 @@ def test_expert_params(device):
...
@@ -62,7 +61,7 @@ def test_expert_params(device):
def
test_forward
(
device
):
def
test_forward
(
device
):
model_dim
=
8
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
# Use identity matrix
...
@@ -81,7 +80,7 @@ def test_forward_multi(device):
...
@@ -81,7 +80,7 @@ def test_forward_multi(device):
num_local_experts
=
4
num_local_experts
=
4
model_dim
=
4
model_dim
=
4
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
*
num_local_experts
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
*
num_local_experts
input
=
torch
.
randn
(
num_local_experts
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
*
num_local_experts
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
experts
=
[]
experts
=
[]
for
i
in
range
(
num_local_experts
):
for
i
in
range
(
num_local_experts
):
...
@@ -106,12 +105,12 @@ class RoundRobinGate(torch.nn.Module):
...
@@ -106,12 +105,12 @@ class RoundRobinGate(torch.nn.Module):
self
.
num_experts
=
num_experts
self
.
num_experts
=
num_experts
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
g
,
s
,
_
=
input
.
shape
s
=
input
.
shape
[
0
]
assert
s
%
self
.
num_experts
==
0
assert
s
%
self
.
num_experts
==
0
capacity
=
2
*
s
//
self
.
num_experts
capacity
=
2
*
s
//
self
.
num_experts
output
=
torch
.
zeros
(
g
,
s
,
self
.
num_experts
,
capacity
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
output
=
torch
.
zeros
(
s
,
self
.
num_experts
,
capacity
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
for
i
in
range
(
s
):
for
i
in
range
(
s
):
output
[
:,
i
,
i
%
self
.
num_experts
,
i
//
self
.
num_experts
]
=
1.0
output
[
i
,
i
%
self
.
num_experts
,
i
//
self
.
num_experts
]
=
1.0
return
0.0
,
output
,
output
.
bool
()
return
0.0
,
output
,
output
.
bool
()
...
@@ -120,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
...
@@ -120,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
def
test_forward_routing
(
device
):
def
test_forward_routing
(
device
):
model_dim
=
8
model_dim
=
8
num_experts
=
dist
.
get_world_size
()
num_experts
=
dist
.
get_world_size
()
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
gate
=
RoundRobinGate
(
model_dim
,
num_experts
)
gate
=
RoundRobinGate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use scaling matrix (each rank has a different scale)
# Use scaling matrix (each rank has a different scale)
...
@@ -130,10 +129,10 @@ def test_forward_routing(device):
...
@@ -130,10 +129,10 @@ def test_forward_routing(device):
output
=
moe
(
input
)
output
=
moe
(
input
)
assert
output
.
shape
==
input
.
shape
assert
output
.
shape
==
input
.
shape
# Verify that each token was sent to the correct expert by checking its scale.
# Verify that each token was sent to the correct expert by checking its scale.
t
=
input
.
shape
[
2
]
t
=
input
.
shape
[
1
]
for
i
in
range
(
t
):
for
i
in
range
(
t
):
expert
=
i
%
num_experts
expert
=
i
%
num_experts
assert
torch
.
allclose
(
input
[:,
:,
i
]
*
(
expert
+
1
),
output
[:,
:,
i
])
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
@
pytest
.
mark
.
mpi
@
pytest
.
mark
.
mpi
...
@@ -142,7 +141,7 @@ def test_backward(device):
...
@@ -142,7 +141,7 @@ def test_backward(device):
loss
=
torch
.
nn
.
MSELoss
()
loss
=
torch
.
nn
.
MSELoss
()
model_dim
=
8
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
# Use identity matrix
...
...
tests/nn/moe/test_top2gating.py
View file @
89176e34
...
@@ -23,21 +23,21 @@ def test_create_cuda():
...
@@ -23,21 +23,21 @@ def test_create_cuda():
def
do_test_forward
(
device
):
def
do_test_forward
(
device
):
torch
.
manual_seed
(
3
)
torch
.
manual_seed
(
3
)
input
=
torch
.
randn
(
3
,
12
,
4
).
to
(
device
)
input
=
torch
.
randn
(
12
,
4
).
to
(
device
)
gate
=
Top2Gate
(
4
,
6
).
to
(
device
)
gate
=
Top2Gate
(
4
,
6
).
to
(
device
)
capacity
=
2
*
12
//
6
capacity
=
2
*
12
//
6
l_aux
,
combine_weights
,
dispatch_mask
=
gate
(
input
)
l_aux
,
combine_weights
,
dispatch_mask
=
gate
(
input
)
assert
pytest
.
approx
(
l_aux
.
item
(),
0.0283
)
assert
pytest
.
approx
(
l_aux
.
item
(),
0.0283
)
assert
combine_weights
.
shape
==
(
3
,
12
,
6
,
4
)
assert
combine_weights
.
shape
==
(
12
,
6
,
4
)
assert
dispatch_mask
.
shape
==
(
3
,
12
,
6
,
4
)
assert
dispatch_mask
.
shape
==
(
12
,
6
,
4
)
assert
torch
.
equal
(
combine_weights
.
bool
(),
dispatch_mask
)
assert
torch
.
equal
(
combine_weights
.
bool
(),
dispatch_mask
)
assert
torch
.
all
(
torch
.
sum
(
dispatch_mask
,
axis
=
(
1
,
3
))
<=
capacity
)
assert
torch
.
all
(
torch
.
sum
(
dispatch_mask
,
axis
=
(
0
,
2
))
<=
capacity
)
assert
torch
.
all
(
combine_weights
>=
0.0
)
assert
torch
.
all
(
combine_weights
>=
0.0
)
assert
torch
.
all
(
combine_weights
<=
1.0
)
assert
torch
.
all
(
combine_weights
<=
1.0
)
weights_sum
=
torch
.
sum
(
combine_weights
).
item
()
weights_sum
=
torch
.
sum
(
combine_weights
).
item
()
assert
round
(
weights_sum
)
==
pytest
.
approx
(
weights_sum
)
assert
round
(
weights_sum
)
==
pytest
.
approx
(
weights_sum
)
# For this random seed, we get
36
slots filled.
# For this random seed, we get
12
slots filled.
assert
weights_sum
==
pytest
.
approx
(
36
.0
)
assert
weights_sum
==
pytest
.
approx
(
12
.0
)
def
test_forward_cpu
():
def
test_forward_cpu
():
...
@@ -53,15 +53,15 @@ def test_forward_cuda():
...
@@ -53,15 +53,15 @@ def test_forward_cuda():
def
test_top1s
():
def
test_top1s
():
num_tokens
=
8
num_tokens
=
8
num_experts
=
4
num_experts
=
4
logits
=
torch
.
randn
(
1
,
num_tokens
,
num_experts
)
logits
=
torch
.
randn
(
num_tokens
,
num_experts
)
l_aux
,
_
,
dispatch_mask
=
top2gating
(
logits
)
l_aux
,
_
,
dispatch_mask
=
top2gating
(
logits
)
top1s
=
torch
.
argmax
(
logits
,
dim
=
2
)
top1s
=
torch
.
argmax
(
logits
,
dim
=
1
)
capacity
=
2
*
num_tokens
//
num_experts
capacity
=
2
*
num_tokens
//
num_experts
ce
=
[
0
]
*
num_experts
ce
=
[
0
]
*
num_experts
locations
=
[
0
]
*
num_tokens
locations
=
[
0
]
*
num_tokens
for
i
,
s
in
enumerate
(
top1s
[
0
]
):
for
i
,
s
in
enumerate
(
top1s
):
e
=
s
.
item
()
e
=
s
.
item
()
loc
=
ce
[
e
]
loc
=
ce
[
e
]
ce
[
e
]
=
loc
+
1
ce
[
e
]
=
loc
+
1
if
ce
[
e
]
<
capacity
:
if
ce
[
e
]
<
capacity
:
assert
dispatch_mask
[
0
][
i
][
e
][
loc
]
assert
dispatch_mask
[
i
][
e
][
loc
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment