Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
34df6069
Unverified
Commit
34df6069
authored
Nov 11, 2020
by
msbaines
Committed by
GitHub
Nov 11, 2020
Browse files
[refactor] moe: cleanup code to be more readable (#186)
parent
317c0945
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
13 deletions
+7
-13
fairscale/nn/moe/moe_layer.py
fairscale/nn/moe/moe_layer.py
+5
-11
fairscale/nn/moe/top2gate.py
fairscale/nn/moe/top2gate.py
+2
-2
No files found.
fairscale/nn/moe/moe_layer.py
View file @
34df6069
...
@@ -66,14 +66,6 @@ class MOELayer(Base):
...
@@ -66,14 +66,6 @@ class MOELayer(Base):
self
.
world_size
=
dist
.
get_world_size
(
self
.
group
)
self
.
world_size
=
dist
.
get_world_size
(
self
.
group
)
self
.
num_local_experts
=
len
(
self
.
experts
)
self
.
num_local_experts
=
len
(
self
.
experts
)
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
dispatched_input
=
torch
.
einsum
(
"sec,sm->ecm"
,
dispatch_mask
.
float
(),
input
)
return
_AllToAll
.
apply
(
self
.
group
,
dispatched_input
)
def
all_to_all_combine
(
self
,
combine_weights
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
expert_output
=
_AllToAll
.
apply
(
self
.
group
,
input
)
return
torch
.
einsum
(
"sec,ecm->sm"
,
combine_weights
,
expert_output
)
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
[
0
].
shape
)
==
3
,
"input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
assert
len
(
input
[
0
].
shape
)
==
3
,
"input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
...
@@ -83,8 +75,9 @@ class MOELayer(Base):
...
@@ -83,8 +75,9 @@ class MOELayer(Base):
d_model
=
input
[
0
].
shape
[
2
]
d_model
=
input
[
0
].
shape
[
2
]
# Reshape into S tokens by dropping sequence dimension.
# Reshape into S tokens by dropping sequence dimension.
reshaped_input
=
input
[
0
].
reshape
(
-
1
,
d_model
)
reshaped_input
=
input
[
0
].
reshape
(
-
1
,
d_model
)
self
.
l_aux
,
combine_weights
,
dispatching_mask
=
self
.
gate
(
reshaped_input
)
self
.
l_aux
,
combine_weights
,
dispatch_mask
=
self
.
gate
(
reshaped_input
)
dispatched_input
=
self
.
all_to_all_dispatch
(
dispatching_mask
,
reshaped_input
)
dispatched_input
=
torch
.
einsum
(
"sec,sm->ecm"
,
dispatch_mask
.
float
(),
reshaped_input
)
dispatched_input
=
_AllToAll
.
apply
(
self
.
group
,
dispatched_input
)
# Re-shape after all-to-all: ecm -> gecm
# Re-shape after all-to-all: ecm -> gecm
dispatched_input
=
dispatched_input
.
reshape
(
self
.
world_size
,
self
.
num_local_experts
,
-
1
,
d_model
)
dispatched_input
=
dispatched_input
.
reshape
(
self
.
world_size
,
self
.
num_local_experts
,
-
1
,
d_model
)
chunks
=
dispatched_input
.
chunk
(
self
.
num_local_experts
,
dim
=
1
)
chunks
=
dispatched_input
.
chunk
(
self
.
num_local_experts
,
dim
=
1
)
...
@@ -92,7 +85,8 @@ class MOELayer(Base):
...
@@ -92,7 +85,8 @@ class MOELayer(Base):
for
chunk
,
expert
in
zip
(
chunks
,
self
.
experts
):
for
chunk
,
expert
in
zip
(
chunks
,
self
.
experts
):
expert_outputs
+=
[
expert
(
chunk
)]
expert_outputs
+=
[
expert
(
chunk
)]
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
1
)
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
1
)
expert_output
=
_AllToAll
.
apply
(
self
.
group
,
expert_output
)
# Re-shape back: gecm -> ecm
# Re-shape back: gecm -> ecm
expert_output
=
expert_output
.
reshape
(
self
.
world_size
*
self
.
num_local_experts
,
-
1
,
d_model
)
expert_output
=
expert_output
.
reshape
(
self
.
world_size
*
self
.
num_local_experts
,
-
1
,
d_model
)
combined_output
=
self
.
all_to_all_combine
(
combine_weights
,
expert_output
)
combined_output
=
torch
.
einsum
(
"sec,ecm->sm"
,
combine_weights
,
expert_output
)
return
combined_output
.
reshape
(
input
[
0
].
shape
)
return
combined_output
.
reshape
(
input
[
0
].
shape
)
fairscale/nn/moe/top2gate.py
View file @
34df6069
...
@@ -112,8 +112,8 @@ class Top2Gate(torch.nn.Module):
...
@@ -112,8 +112,8 @@ class Top2Gate(torch.nn.Module):
def
__init__
(
self
,
model_dim
:
int
,
num_experts
:
int
,)
->
None
:
def
__init__
(
self
,
model_dim
:
int
,
num_experts
:
int
,)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
wg
=
torch
.
nn
.
Linear
(
num_experts
,
model_dim
,
bias
=
False
)
self
.
wg
=
torch
.
nn
.
Linear
(
model_dim
,
num_experts
,
bias
=
False
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
# type: ignore
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
# type: ignore
logits
=
torch
.
einsum
(
"sm,me -> se"
,
input
,
self
.
wg
.
weigh
t
)
logits
=
self
.
wg
(
inpu
t
)
return
top2gating
(
logits
)
return
top2gating
(
logits
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment