OpenDAS / FastMoE · Commits
"server/text_generation_server/models/flash_starcoder2.py" did not exist on "e3e487dc711449c23826cfe1d74786f71309d6bd"
Commit 5e0af68d, authored Jan 29, 2021 by Rick Ho

split fmoe functions

Parent: ad07f07a
Showing 1 changed file with 102 additions and 0 deletions.

fmoe/fmoe_functions.py  (new file, mode 100644)  +102 −0
import torch
from torch.autograd import Function

import fmoe_cuda


def moe_prepare_forward(gate, num_expert, world_size):
    # Make sure the NCCL communicator of the default process group is
    # ready before any cross-worker exchange.
    fmoe_cuda.ensure_nccl(
        torch.distributed.distributed_c10d._default_pg, gate)

    with torch.no_grad():
        # Sorting the gate yields the permutation that groups tokens by
        # their assigned expert.
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        # One counter slot per (worker, expert) pair.
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long)
        local_expert_count.index_put_((gate_idx.long(),), gate_count)
        # Exchange per-expert token counts with the other workers.
        global_expert_count, = fmoe_cuda.expert_exchange(
            local_expert_count, num_expert, world_size)
    # Number of tokens each local expert will process, summed over workers.
    fwd_expert_count = global_expert_count.view(
        world_size, num_expert).sum(dim=0).cpu()
    fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
        local_expert_count.cpu(),
        global_expert_count.cpu(),
        fwd_expert_count,
        fwd_batch_size,
    )
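For intuition, here is a minimal single-process sketch of the counting and permutation this function computes, in plain PyTorch. `prepare_forward_reference` is a hypothetical name, and the NCCL exchange is omitted (world_size == 1), so this is an illustration rather than the library's API:

    import torch

    def prepare_forward_reference(gate, num_expert):
        # Permutation that groups token indices by their assigned expert.
        _, pos = torch.sort(gate)
        # Tokens-per-expert histogram; zero for experts with no tokens.
        local_expert_count = torch.zeros(num_expert, dtype=torch.long)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count.index_put_((gate_idx.long(),), gate_count)
        fwd_batch_size = int(local_expert_count.sum().item())
        return pos, local_expert_count, fwd_batch_size

    gate = torch.tensor([2, 0, 1, 0, 2, 2])  # 6 tokens routed among 3 experts
    pos, count, batch = prepare_forward_reference(gate, num_expert=3)
    print(pos)    # a permutation grouping tokens by expert, e.g. [1, 3, 2, 0, 4, 5]
    print(count)  # tensor([2, 1, 3])
    print(batch)  # 6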
class MOEScatter(Function):
    @staticmethod
    def forward(ctx, inp, pos, local_expert_count, global_expert_count,
                fwd_batch_size, world_size):
        # Permute tokens into expert-contiguous order.
        local_input_buf, = fmoe_cuda.local_gather(inp, pos)
        if world_size > 1:
            # Ship each expert's tokens to the worker that hosts it.
            global_input_buf, = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count, global_expert_count,
                fwd_batch_size, world_size)
        else:
            global_input_buf = local_input_buf
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            # Reverse the communication pattern of the forward pass.
            local_grad_in, = fmoe_cuda.global_gather(
                global_grad_in,
                local_expert_count, global_expert_count,
                local_batch_size, world_size)
        else:
            local_grad_in = global_grad_in
        # Undo the expert-order permutation.
        grad_in, = fmoe_cuda.local_scatter(local_grad_in, pos)
        return grad_in, None, None, None, None, None
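The `local_gather`/`local_scatter` kernels are opaque in this diff. Assuming they implement a row permutation and its inverse, which is consistent with how MOEScatter pairs them between forward and backward, a pure-PyTorch reference (hypothetical `_ref` names) would be:

    import torch

    def local_gather_ref(inp, pos):
        # buf[i] = inp[pos[i]]: reorder rows into expert-contiguous order.
        return inp[pos]

    def local_scatter_ref(buf, pos):
        # Inverse permutation: out[pos[i]] = buf[i].
        out = torch.empty_like(buf)
        out[pos] = buf
        return out

    inp = torch.arange(12.0).view(6, 2)
    pos = torch.tensor([1, 3, 2, 0, 4, 5])
    buf = local_gather_ref(inp, pos)
    assert torch.equal(local_scatter_ref(buf, pos), inp)  # round-trips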
class MOELinear(Function):
    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        # Batched per-expert linear: expert e applies its own weight to
        # its fwd_expert_count[e] contiguous rows of the input buffer.
        global_output_buf, = fmoe_cuda.forward(
            global_input_buf, weight, fwd_expert_count)
        variables = (global_input_buf, weight, fwd_expert_count)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
            grad_out, input_buf, weight, fwd_expert_count)
        return grad_inp_buf, grad_weight, None
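`fmoe_cuda.forward` is presumably a batched per-expert GEMM over expert-grouped rows. A hedged pure-PyTorch equivalent, assuming a `[num_expert, out_feat, in_feat]` weight layout (not confirmed by this diff), could look like:

    import torch

    def expert_linear_ref(input_buf, weight, fwd_expert_count):
        # Rows of input_buf are grouped by expert; expert e owns the
        # next fwd_expert_count[e] rows and applies its weight matrix.
        chunks = input_buf.split(fwd_expert_count.tolist())
        return torch.cat([chunk @ weight[e].t()
                          for e, chunk in enumerate(chunks)])

    input_buf = torch.randn(6, 4)
    weight = torch.randn(3, 8, 4)           # [num_expert, out_feat, in_feat]
    fwd_expert_count = torch.tensor([2, 1, 3])
    out = expert_linear_ref(input_buf, weight, fwd_expert_count)
    print(out.shape)  # torch.Size([6, 8])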
class MOEGather(Function):
    @staticmethod
    def forward(ctx, global_output_buf, pos, local_expert_count,
                global_expert_count, local_batch_size, world_size):
        if world_size > 1:
            # Collect the outputs of remote experts back to this worker.
            local_output_buf, = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count, global_expert_count,
                local_batch_size, world_size)
        else:
            local_output_buf = global_output_buf
        # Restore the original token order.
        output, = fmoe_cuda.local_scatter(local_output_buf, pos)
        # global_output_buf.shape[0] is the forward batch size, needed to
        # scatter gradients back out in backward.
        ctx.moe_args = global_output_buf.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args
        # Regroup gradients into expert-contiguous order.
        grad_out_buf, = fmoe_cuda.local_gather(grad_out.contiguous(), pos)
        if world_size > 1:
            global_grad_out_buf, = fmoe_cuda.global_scatter(
                grad_out_buf,
                local_expert_count, global_expert_count,
                fwd_batch_size, world_size)
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None, None, None
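Putting the pieces together: a self-contained single-process (world_size == 1) reference of the scatter → expert linear → gather pipeline, under the same assumed kernel semantics and weight layout as the sketches above (`moe_forward_ref` is a hypothetical name, not part of this file):

    import torch

    def moe_forward_ref(inp, gate, weight):
        num_expert = weight.shape[0]
        _, pos = torch.sort(gate)
        count = torch.zeros(num_expert, dtype=torch.long)
        idx, c = torch.unique(gate, return_counts=True)
        count.index_put_((idx.long(),), c)
        buf = inp[pos]                       # MOEScatter, world_size == 1
        out_buf = torch.cat(                 # MOELinear
            [chunk @ weight[e].t()
             for e, chunk in enumerate(buf.split(count.tolist()))])
        out = torch.empty_like(out_buf)
        out[pos] = out_buf                   # MOEGather
        return out

    inp = torch.randn(6, 4)
    gate = torch.tensor([2, 0, 1, 0, 2, 2])
    weight = torch.randn(3, 4, 4)            # square weights for simplicity
    print(moe_forward_ref(inp, gate, weight).shape)  # torch.Size([6, 4])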