Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
29d81c43
Unverified
Commit
29d81c43
authored
May 08, 2021
by
msbaines
Committed by
GitHub
May 08, 2021
Browse files
[perf] nn.moe: replace einsum with faster equivalent code (#667)
Co-authored-by: @myleott
parent
a9156260
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
8 deletions
+6
-8
fairscale/nn/moe/top2gate.py
fairscale/nn/moe/top2gate.py
+6
-8
No files found.
fairscale/nn/moe/top2gate.py
View file @
29d81c43
...
@@ -77,10 +77,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
...
@@ -77,10 +77,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
locations2_s
=
torch
.
sum
(
locations2
*
mask2
,
dim
=
1
)
locations2_s
=
torch
.
sum
(
locations2
*
mask2
,
dim
=
1
)
# Normalize gate probabilities
# Normalize gate probabilities
mask1_float
=
mask1
.
float
()
gates1_s
=
(
gates
*
mask1
).
sum
(
dim
=
1
)
# einsum("se,se->s")
mask2_float
=
mask2
.
float
()
gates2_s
=
(
gates
*
mask2
).
sum
(
dim
=
1
)
# einsum("se,se->s")
gates1_s
=
torch
.
einsum
(
"se,se->s"
,
gates
,
mask1_float
)
gates2_s
=
torch
.
einsum
(
"se,se->s"
,
gates
,
mask2_float
)
denom_s
=
gates1_s
+
gates2_s
denom_s
=
gates1_s
+
gates2_s
# Avoid divide-by-zero
# Avoid divide-by-zero
denom_s
=
torch
.
clamp
(
denom_s
,
min
=
torch
.
finfo
(
denom_s
.
dtype
).
eps
)
denom_s
=
torch
.
clamp
(
denom_s
,
min
=
torch
.
finfo
(
denom_s
.
dtype
).
eps
)
...
@@ -88,12 +86,12 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
...
@@ -88,12 +86,12 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
gates2_s
/=
denom_s
gates2_s
/=
denom_s
# Calculate combine_weights and dispatch_mask
# Calculate combine_weights and dispatch_mask
gates1
=
torch
.
einsum
(
"s,se->se"
,
gates1_s
,
mask1_float
)
gates1
=
gates1_s
.
unsqueeze
(
-
1
)
*
mask1
# einsum("s,se->se"
)
gates2
=
torch
.
einsum
(
"s,se->se"
,
gates2_s
,
mask2_float
)
gates2
=
gates2_s
.
unsqueeze
(
-
1
)
*
mask2
# einsum("s,se->se"
)
locations1_sc
=
one_hot
(
locations1_s
,
num_classes
=
capacity
)
locations1_sc
=
one_hot
(
locations1_s
,
num_classes
=
capacity
)
locations2_sc
=
one_hot
(
locations2_s
,
num_classes
=
capacity
)
locations2_sc
=
one_hot
(
locations2_s
,
num_classes
=
capacity
)
combine1_sec
=
torch
.
einsum
(
"se,sc->sec"
,
gates1
,
locations1_sc
)
combine1_sec
=
gates1
.
unsqueeze
(
2
)
*
locations1_sc
.
unsqueeze
(
1
)
# einsum("se,sc->sec"
)
combine2_sec
=
torch
.
einsum
(
"se,sc->sec"
,
gates2
,
locations2_sc
)
combine2_sec
=
gates2
.
unsqueeze
(
2
)
*
locations2_sc
.
unsqueeze
(
1
)
# einsum("se,sc->sec"
)
combine_weights
=
combine1_sec
+
combine2_sec
combine_weights
=
combine1_sec
+
combine2_sec
dispatch_mask
=
combine_weights
.
bool
()
dispatch_mask
=
combine_weights
.
bool
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment