OpenDAS / FastMoE

Commit 1aced6d8, authored Apr 27, 2021 by Rick Ho
Parent: 6cb6bbe4

    balancing cuda code

Showing 4 changed files with 86 additions and 15 deletions:

    cuda/balancing.cu          +17  -0
    cuda/balancing.cuh         +31  -0
    fmoe/functions.py          +23  -14
    fmoe/gates/gshard_gate.py  +15  -1
cuda/balancing.cu (new file, 0 → 100644)

#include "balancing.cuh"
#include <torch/extension.h>

/*
 * Note: due to the limits of CUDA atomic operators, capacity should be int32.
 */
std::vector<torch::Tensor> _limit_by_capacity(
        torch::Tensor expert_count, torch::Tensor capacity,
        long n_expert, long n_workers) {
    // One acknowledged count per (worker, expert) pair.
    auto expert_count_ack = torch::empty_like(expert_count);
    auto smgr = getCudaStreamManager(expert_count.device().index());
    fmoe_cuda_limit_by_capacity_impl(
            expert_count.data_ptr<long>(),
            capacity.data_ptr<int>(),
            expert_count_ack.data_ptr<long>(),
            n_expert, n_workers, smgr);
    return {expert_count_ack};
}
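The int32 note above matters at the call site. Here is a minimal usage sketch, assuming the extension has been built and exposes this entry point as fmoe_cuda.limit_by_capacity (the name used by fmoe/gates/gshard_gate.py below); the counts are made up:

import torch
import fmoe_cuda  # assumes the FastMoE CUDA extension is built

n_expert, n_worker = 4, 2

# Per-(worker, expert) token counts: int64, matching data_ptr<long> above.
expert_count = torch.tensor([3, 7, 4, 2, 5, 1, 0, 6],
                            dtype=torch.long, device="cuda")
# Per-expert capacity: must be int32, since the kernel updates it via atomicSub.
capacity = torch.full((n_expert,), 5, dtype=torch.int32, device="cuda")

expert_count_ack = fmoe_cuda.limit_by_capacity(
        expert_count, capacity, n_expert, n_worker)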
cuda/balancing.cuh (new file, 0 → 100644)

#include "stream_manager.h"
#include "utils/fmoe_utils.h"

#include <cuda.h>

__global__
void limit_by_capacity_kernel(const long* ec, int* cap, long* eca,
        const long n_expert, const long n_worker) {
    int eid = blockIdx.y;
    int wid = blockIdx.x * blockDim.x + threadIdx.x;
    if (wid < n_worker) {
        int proposal = ec[wid * n_expert + eid];
        // atomicSub returns the expert's remaining capacity *before* the
        // subtraction, so cap_left is what this worker can draw from.
        int cap_left = atomicSub(cap + eid, proposal);
        if (cap_left >= proposal) {
            // The full proposal fits within the remaining capacity.
            eca[wid * n_expert + eid] = proposal;
        } else if (cap_left >= 0) {
            // Capacity only partially covers the proposal: grant what is left.
            eca[wid * n_expert + eid] = cap_left;
        } else {
            // Capacity was already exhausted by other workers.
            eca[wid * n_expert + eid] = 0;
        }
    }
}

void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap, long* eca,
        const long n_expert, const long n_worker, CudaStreamManager* smgr) {
    // One thread per (worker, expert) pair: experts along grid.y,
    // workers split into 1024-thread blocks along grid.x.
    dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
    dim3 block_dim(1024);
    limit_by_capacity_kernel
        <<<grid_dim, block_dim, 0, smgr->stream(0)>>>
        (ec, cap, eca, n_expert, n_worker);
    smgr->sync(1);
}
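To make the grant rule concrete, here is a pure-PyTorch reference sketch of the same per-expert bookkeeping. It processes workers sequentially, whereas the kernel's atomicSub lets workers claim capacity in a nondeterministic order, so per-worker grants can differ from the kernel's under contention even though no expert ever exceeds its capacity. The function name limit_by_capacity_ref is illustrative only:

import torch

def limit_by_capacity_ref(ec: torch.Tensor, cap: torch.Tensor,
                          n_expert: int, n_worker: int) -> torch.Tensor:
    """Reference semantics of limit_by_capacity_kernel (sequential order)."""
    eca = torch.zeros_like(ec)
    cap = cap.clone().long()
    for wid in range(n_worker):
        for eid in range(n_expert):
            proposal = int(ec[wid * n_expert + eid])
            # Grant as much of the proposal as the remaining capacity allows.
            granted = max(0, min(proposal, int(cap[eid])))
            cap[eid] -= proposal          # mirrors atomicSub: always subtract
            eca[wid * n_expert + eid] = granted
    return eca

# Toy check: 2 workers, 2 experts, capacity 5 per expert.
ec = torch.tensor([3, 7, 4, 2], dtype=torch.long)   # worker-major layout
cap = torch.tensor([5, 5], dtype=torch.int32)
print(limit_by_capacity_ref(ec, cap, n_expert=2, n_worker=2))
# tensor([3, 5, 2, 0]): expert 0 grants 3 then 2; expert 1 grants 5 then 0.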
fmoe/functions.py

@@ -10,7 +10,27 @@ import fmoe_cuda
 from .utils import get_torch_default_comm
 
 
-def moe_prepare_forward(gate, num_expert, world_size, comm=None):
+def count_by_gate(gate, num_expert, world_size, comm=None):
+    # TODO: support -1 in gate, which means ignore this input
+    with torch.no_grad():
+        _, pos = torch.sort(gate)
+        gate_idx, gate_count = torch.unique(gate, return_counts=True)
+        local_expert_count = torch.zeros(num_expert * world_size,
+                device=gate.device, dtype=torch.long)
+        local_expert_count.index_put_((gate_idx.long(),), gate_count)
+        if world_size > 1:
+            global_expert_count, = fmoe_cuda.expert_exchange(
+                    local_expert_count, num_expert, world_size)
+        else:
+            global_expert_count = local_expert_count
+    return pos, local_expert_count, global_expert_count
+
+
+def prepare_forward(gate, num_expert, world_size, comm=None):
     r"""
     Prepare necessary information from gate output for MoE computation.
@@ -26,20 +46,9 @@ def moe_prepare_forward(gate, num_expert, world_size, comm=None):
         comm = get_torch_default_comm()
     fmoe_cuda.ensure_nccl(comm, gate)
 
-    with torch.no_grad():
-        _, pos = torch.sort(gate)
-        gate_idx, gate_count = torch.unique(gate, return_counts=True)
-        local_expert_count = torch.zeros(num_expert * world_size,
-                device=gate.device, dtype=torch.long)
-        local_expert_count.index_put_((gate_idx.long(),), gate_count)
-        if world_size > 1:
-            global_expert_count, = fmoe_cuda.expert_exchange(
-                local_expert_count, num_expert, world_size)
-        else:
-            global_expert_count = local_expert_count
+    pos, local_expert_count, global_expert_count = count_by_gate(gate,
+            num_expert, world_size)
 
     fwd_expert_count = global_expert_count.view(world_size,
             num_expert).sum(dim=0)
     fwd_batch_size = int(fwd_expert_count.sum().item())
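The counting step in count_by_gate is a dense histogram over expert indices; experts that receive no tokens keep a zero entry. A small single-worker sketch of the same torch.unique / index_put_ pattern, with made-up gate values:

import torch

num_expert, world_size = 4, 1
# Hypothetical top-1 gate output: the chosen expert index for each of 6 tokens.
gate = torch.tensor([2, 0, 2, 3, 0, 2])

gate_idx, gate_count = torch.unique(gate, return_counts=True)
local_expert_count = torch.zeros(num_expert * world_size, dtype=torch.long)
local_expert_count.index_put_((gate_idx.long(),), gate_count)

print(local_expert_count)  # tensor([2, 0, 3, 1]): experts 0, 2, 3 got 2, 3, 1 tokens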
fmoe/gates/gshard_gate.py

@@ -4,6 +4,9 @@ Balanced gate with GShard's policy (Google, 2020)
 import torch
 import torch.nn.functional as F
 from .naive_gate import NaiveGate
+from fmoe.functions import count_by_gate
+import fmoe_cuda as fmoe_native
+import math
 
 
 class GShardGate(NaiveGate):
@@ -27,6 +29,18 @@ class GShardGate(NaiveGate):
         loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
         self.set_loss(loss)
 
-        # TODO: capacity limit
+        cap_rate = self.capacity[0 if self.training else 1]
+        capacity = torch.ones(self.num_expert, dtype=torch.int32,
+                device=x.device)
+        capacity *= math.ceil(cap_rate * x.shape[0])
+        pos, lec, gec = count_by_gate(gate, self.num_expert, self.world_size)
+        new_gec = fmoe_native.limit_by_capacity(gec, capacity,
+                self.num_expert, self.world_size)
+        if self.world_size > 1:
+            new_lec = fmoe_native.expert_exchange(new_gec,
+                    self.num_expert, self.world_size)
+        else:
+            new_lec = new_gec
+        # TODO: re-assign gate
 
         return topk_idx, topk_val
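For a feel of the arithmetic, a tiny sketch of the capacity computation with made-up numbers (the capacity factor, batch size, and expert count are all hypothetical):

import math
import torch

cap_rate, batch_size, num_expert = 1.2, 100, 4  # hypothetical values

capacity = torch.ones(num_expert, dtype=torch.int32)
capacity *= math.ceil(cap_rate * batch_size)
print(capacity)  # tensor([120, 120, 120, 120], dtype=torch.int32)
# Each expert acknowledges at most 120 tokens; limit_by_capacity truncates
# or zeroes any per-worker counts beyond that budget, as in the kernel above.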