Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
29d81c43
Unverified
Commit
29d81c43
authored
May 08, 2021
by
msbaines
Committed by
GitHub
May 08, 2021
Browse files
[perf] nn.moe: replace einsum with faster equivalent code (#667)
Co-authored-by: @myleott
parent
a9156260
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
8 deletions
+6
-8
fairscale/nn/moe/top2gate.py
fairscale/nn/moe/top2gate.py
+6
-8
No files found.
fairscale/nn/moe/top2gate.py
View file @
29d81c43
...
@@ -77,10 +77,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     locations2_s = torch.sum(locations2 * mask2, dim=1)
     # Normalize gate probabilities
-    mask1_float = mask1.float()
-    mask2_float = mask2.float()
-    gates1_s = torch.einsum("se,se->s", gates, mask1_float)
-    gates2_s = torch.einsum("se,se->s", gates, mask2_float)
+    gates1_s = (gates * mask1).sum(dim=1)  # einsum("se,se->s")
+    gates2_s = (gates * mask2).sum(dim=1)  # einsum("se,se->s")
     denom_s = gates1_s + gates2_s
     # Avoid divide-by-zero
     denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
...
@@ -88,12 +86,12 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     gates2_s /= denom_s
     # Calculate combine_weights and dispatch_mask
-    gates1 = torch.einsum("s,se->se", gates1_s, mask1_float)
-    gates2 = torch.einsum("s,se->se", gates2_s, mask2_float)
+    gates1 = gates1_s.unsqueeze(-1) * mask1  # einsum("s,se->se")
+    gates2 = gates2_s.unsqueeze(-1) * mask2  # einsum("s,se->se")
     locations1_sc = one_hot(locations1_s, num_classes=capacity)
     locations2_sc = one_hot(locations2_s, num_classes=capacity)
-    combine1_sec = torch.einsum("se,sc->sec", gates1, locations1_sc)
-    combine2_sec = torch.einsum("se,sc->sec", gates2, locations2_sc)
+    combine1_sec = gates1.unsqueeze(2) * locations1_sc.unsqueeze(1)  # einsum("se,sc->sec")
+    combine2_sec = gates2.unsqueeze(2) * locations2_sc.unsqueeze(1)  # einsum("se,sc->sec")
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment