OpenDAS / fairscale · Commits

Commit 34df6069 (unverified)
[refactor] moe: cleanup code to be more readable (#186)
Authored Nov 11, 2020 by msbaines; committed by GitHub, Nov 11, 2020
Parent: 317c0945
Showing 2 changed files with 7 additions and 13 deletions:

- fairscale/nn/moe/moe_layer.py: +5, -11
- fairscale/nn/moe/top2gate.py: +2, -2
fairscale/nn/moe/moe_layer.py

@@ -66,14 +66,6 @@ class MOELayer(Base):
         self.world_size = dist.get_world_size(self.group)
         self.num_local_experts = len(self.experts)
 
-    def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
-        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), input)
-        return _AllToAll.apply(self.group, dispatched_input)
-
-    def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
-        expert_output = _AllToAll.apply(self.group, input)
-        return torch.einsum("sec,ecm->sm", combine_weights, expert_output)
-
     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         assert len(input) == 1, "only single input Tensor supported"
         assert len(input[0].shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
@@ -83,8 +75,9 @@ class MOELayer(Base):
         d_model = input[0].shape[2]
         # Reshape into S tokens by dropping sequence dimension.
         reshaped_input = input[0].reshape(-1, d_model)
-        self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
-        dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
+        self.l_aux, combine_weights, dispatch_mask = self.gate(reshaped_input)
+        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), reshaped_input)
+        dispatched_input = _AllToAll.apply(self.group, dispatched_input)
         # Re-shape after all-to-all: ecm -> gecm
         dispatched_input = dispatched_input.reshape(self.world_size, self.num_local_experts, -1, d_model)
         chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
@@ -92,7 +85,8 @@ class MOELayer(Base):
         for chunk, expert in zip(chunks, self.experts):
             expert_outputs += [expert(chunk)]
         expert_output = torch.cat(expert_outputs, dim=1)
+        expert_output = _AllToAll.apply(self.group, expert_output)
         # Re-shape back: gecm -> ecm
         expert_output = expert_output.reshape(self.world_size * self.num_local_experts, -1, d_model)
-        combined_output = self.all_to_all_combine(combine_weights, expert_output)
+        combined_output = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
         return combined_output.reshape(input[0].shape)
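The refactor inlines the former all_to_all_dispatch and all_to_all_combine helpers into forward, so the dispatch einsum, the all-to-all exchange, and the combine einsum now read top to bottom. To make the two einsums concrete, here is a minimal single-process sketch: the distributed _AllToAll exchange is elided, identity functions stand in for the experts, and (unlike the real layer, which combines with the gate's soft combine_weights) the binary mask is reused on the combine side so the round trip is exact. All names and sizes below are illustrative, not from the diff.

    import torch

    # Toy sizes: s tokens, e experts, c capacity slots per expert, m model dim.
    s, e, c, m = 4, 2, 2, 8
    tokens = torch.randn(s, m)

    # One-hot routing: token i -> (expert, slot); real masks come from the gate.
    dispatch_mask = torch.zeros(s, e, c)
    for token, (expert, slot) in enumerate([(0, 0), (0, 1), (1, 0), (1, 1)]):
        dispatch_mask[token, expert, slot] = 1.0

    # Dispatch: scatter each token into its (expert, capacity) bucket.
    dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)  # (e, c, m)

    expert_output = dispatched  # identity "experts" keep the round trip exact

    # Combine: gather buckets back into token order.
    combined = torch.einsum("sec,ecm->sm", dispatch_mask, expert_output)
    assert torch.allclose(combined, tokens)

With one-hot weights, the dispatch/combine pair is a permutation and its inverse, which is why the round trip recovers the input exactly; with soft combine_weights the combine becomes a weighted sum over the slots each token was routed to.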
fairscale/nn/moe/top2gate.py

@@ -112,8 +112,8 @@ class Top2Gate(torch.nn.Module):
     def __init__(self, model_dim: int, num_experts: int,) -> None:
         super().__init__()
-        self.wg = torch.nn.Linear(num_experts, model_dim, bias=False)
+        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
 
     def forward(self, input: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
-        logits = torch.einsum("sm,me -> se", input, self.wg.weight)
+        logits = self.wg(input)
         return top2gating(logits)
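The Top2Gate change puts the Linear constructor arguments in the conventional (in_features, out_features) order. The old code had them swapped, which made the (out, in) weight come out as (model_dim, num_experts) and forced forward to spell out the matmul via einsum against that layout; with the conventional order, the module's own forward does the same thing. A quick sketch of the equivalence (sizes and names here are illustrative, not from the diff):

    import torch

    model_dim, num_experts, s = 8, 4, 3
    wg = torch.nn.Linear(model_dim, num_experts, bias=False)  # weight: (e, m)
    x = torch.randn(s, model_dim)

    # nn.Linear computes x @ weight.T: (s, m) @ (m, e) -> (s, e) logits,
    # so the explicit einsum against the weight is redundant.
    assert torch.allclose(wg(x), torch.einsum("sm,em->se", x, wg.weight))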
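For completeness: both inlined paths in moe_layer.py go through _AllToAll.apply(self.group, ...), whose implementation is not shown in this diff. The sketch below is an assumption about what such an autograd-aware wrapper typically looks like, built on torch.distributed.all_to_all_single with the gradient routed back by a second all-to-all; fairscale's actual code may differ.

    from typing import Any

    import torch
    import torch.distributed as dist
    from torch import Tensor

    class AllToAllSketch(torch.autograd.Function):
        """Hypothetical autograd-aware all-to-all: the backward of an
        all-to-all is another all-to-all, which returns each gradient
        shard to the rank that produced the corresponding input."""

        @staticmethod
        def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:
            ctx.group = group
            input = input.contiguous()
            output = torch.empty_like(input)
            dist.all_to_all_single(output, input, group=group)
            return output

        @staticmethod
        def backward(ctx: Any, grad_output: Tensor):
            # No gradient for the group argument; exchange the grads back.
            return None, AllToAllSketch.apply(ctx.group, grad_output)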