OpenDAS / FastMoE · Commit 414a2f86
Authored May 19, 2021 by Rick Ho
Parent: f804a121

remove repeat_interleave and local scatter

Showing 2 changed files with 31 additions and 14 deletions:
    fmoe/functions.py   +16  -4
    fmoe/layers.py      +15  -10
fmoe/functions.py
@@ -74,6 +74,18 @@ def prepare_forward(gate, num_expert, world_size, comm=None):
     )


+def _local_scatter(inp, pos):
+    inp_buf = torch.index_select(inp, 0, pos)
+    return inp_buf
+
+
+def _local_gather(inp, pos, out_batch_size):
+    inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
+            dtype=inp.dtype, device=inp.device)
+    inp_buf.index_copy_(0, pos, inp)
+    return inp_buf
+
+
 class MOEScatter(Function):
     r"""
     Scatter input samples from [batch x sequences] to contiguous alone experts.
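
These two helpers replace the fmoe_cuda.local_scatter / fmoe_cuda.local_gather kernel calls with plain PyTorch indexing. A minimal round-trip sketch (the helper bodies mirror the diff above; the toy tensors are illustrative only):

import torch

def _local_scatter(inp, pos):
    # Row i of the result is row pos[i] of the input, so samples bound
    # for the same expert end up contiguous.
    return torch.index_select(inp, 0, pos)

def _local_gather(inp, pos, out_batch_size):
    # Inverse mapping: row i of `inp` is written back to row pos[i] of a
    # zero-initialized buffer with out_batch_size rows.
    inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
                          dtype=inp.dtype, device=inp.device)
    inp_buf.index_copy_(0, pos, inp)
    return inp_buf

x = torch.arange(12.0).view(4, 3)
pos = torch.tensor([2, 0, 3, 1])
# When pos is a permutation, gather undoes scatter exactly.
assert torch.equal(_local_gather(_local_scatter(x, pos), pos, 4), x)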
@@ -91,7 +103,7 @@ class MOEScatter(Function):
         fwd_batch_size,
         world_size,
     ):
-        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
+        local_input_buf = _local_scatter(inp, pos)
         if world_size > 1:
             (global_input_buf,) = fmoe_cuda.global_scatter(
                 local_input_buf,
@@ -122,7 +134,7 @@ class MOEScatter(Function):
             )
         else:
             local_grad_in = global_grad_in
-        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
+        grad_in = _local_gather(local_grad_in, pos, local_batch_size)
         return grad_in, None, None, None, None, None
@@ -175,7 +187,7 @@ class MOEGather(Function):
             )
         else:
             local_output_buf = global_output_buf
-        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)
+        output = _local_gather(local_output_buf, pos, local_batch_size)

         ctx.moe_args = (global_output_buf.shape[0], world_size)
         variables = (pos, local_expert_count, global_expert_count)
@@ -186,7 +198,7 @@ class MOEGather(Function):
     def backward(ctx, grad_out):
         pos, local_expert_count, global_expert_count = ctx.saved_tensors
         fwd_batch_size, world_size = ctx.moe_args
-        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
+        grad_out_buf = _local_scatter(grad_out.contiguous(), pos)
         if world_size > 1:
             (global_grad_out_buf,) = fmoe_cuda.global_scatter(
                 grad_out_buf,
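
With the local scatter/gather steps now expressed as ordinary differentiable PyTorch ops, the single-process (world_size == 1) path can be sanity-checked without the custom kernels; a hypothetical check using torch.autograd.gradcheck:

import torch
from torch.autograd import gradcheck

pos = torch.tensor([1, 2, 0])
inp = torch.randn(3, 4, dtype=torch.double, requires_grad=True)

# index_select already has a well-defined autograd backward (an index_copy_
# style accumulation), which is essentially what MOEScatter.backward rebuilds
# by hand around the multi-process global_scatter step.
assert gradcheck(lambda t: torch.index_select(t, 0, pos), (inp,))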
fmoe/layers.py
@@ -114,12 +114,19 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
         fwd_batch_size,
     ) = prepare_forward(gate, num_expert, world_size)
     x = MOEScatter.apply(
-        inp, pos,
+        inp, pos % inp.shape[0],
         local_expert_count, global_expert_count, fwd_batch_size, world_size
     )
     x = expert_fn(x, fwd_expert_count)
+
+    out_batch_size = inp.shape[0]
+    if len(gate.shape) == 2:
+        out_batch_size *= gate.shape[1]
+
     x = MOEGather.apply(
-        x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size
+        x, pos, local_expert_count, global_expert_count,
+        out_batch_size, world_size
     )
     return x
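
This hunk is what makes the repeat_interleave in layers.py (next hunk) unnecessary: instead of physically expanding the input top_k times and then indexing the copy, the scatter maps each expanded position back to its source row and indexes the original tensor. A toy equivalence sketch (hypothetical shapes; shown with repeat_interleave plus floor division, where expanded index // top_k recovers the source row, the modulo in the diff plays the same role for its position layout):

import torch

batch, top_k, d_model = 4, 2, 3
inp = torch.randn(batch, d_model)
pos = torch.randperm(batch * top_k)  # a permutation of the expanded batch

# Old path: materialize batch * top_k rows, then reorder them.
old = inp.repeat_interleave(repeats=top_k, dim=0)[pos]

# New-style path: map expanded positions back to source rows and index
# the original tensor directly; the expanded copy is never materialized.
new = inp[torch.div(pos, top_k, rounding_mode='floor')]

assert torch.equal(old, new)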
@@ -216,16 +223,14 @@ class FMoE(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)

-        # to: (BxLxtop_k) x d_model
-        # TODO: remove repeat_interleave
-        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
         x = _fmoe_general_global_forward(
             inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
         )

         # to: (BxL) x top_k x d_model
-        x = x.view(-1, self.top_k, self.d_model)
+        x = x.view(inp.shape[0], self.top_k, self.d_model)
         # to: (BxL) x d_model
-        gate_score = gate_score.view(inp.shape[0], 1, self.top_k)
+        gate_score = gate_score.unsqueeze(1)
         x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

         if self.mp_size > 1:
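
For reference, the combine step at the end of FMoE.forward reduces the top_k expert outputs per token with a batched matrix product; a shape-level sketch (toy values, mirroring the view / unsqueeze / bmm sequence above):

import torch

n_tokens, top_k, d_model = 5, 2, 8    # n_tokens = B x L
x = torch.randn(n_tokens, top_k, d_model)   # per-token expert outputs
gate_score = torch.softmax(torch.randn(n_tokens, top_k), dim=-1)

# (n_tokens, 1, top_k) @ (n_tokens, top_k, d_model) -> (n_tokens, 1, d_model),
# i.e. a gate-weighted sum over the top_k expert outputs for each token.
out = torch.bmm(gate_score.unsqueeze(1), x).reshape(-1, d_model)
assert out.shape == (n_tokens, d_model)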