OpenDAS / FastMoE
Commit fa5f45f0, authored May 29, 2021 by Rick Ho (parent: 7f6463f0)

    fix bugs to run megatron with gshard gate

Showing 7 changed files with 43 additions and 28 deletions:

    cuda/balancing.cu           +9   -3
    cuda/balancing.cuh          +8   -4
    cuda/fmoe_cuda.cpp          +2   -2
    fmoe/functions.py           +1   -1
    fmoe/gates/gshard_gate.py   +4   -2
    fmoe/gates/switch_gate.py   +4   -2
    fmoe/gates/utils.py         +15  -14
cuda/balancing.cu

@@ -4,7 +4,7 @@
 /*
  * note that due to limit of cuda atomic operator, capacity should be int32
  */
-std::vector<torch::Tensor> _limit_by_capacity(
+torch::Tensor _limit_by_capacity(
         torch::Tensor expert_count, torch::Tensor capacity,
         long n_expert, long n_worker) {
     CHECK_INPUT(expert_count);
@@ -16,16 +16,22 @@ std::vector<torch::Tensor> _limit_by_capacity(
         capacity.data_ptr<int>(),
         expert_count_ack.data_ptr<long>(),
         n_expert, n_worker, smgr);
-    return {expert_count_ack};
+    return expert_count_ack;
 }
 
-void _prune_gate_by_capacity(
+torch::Tensor _prune_gate_by_capacity(
         torch::Tensor gate_idx, torch::Tensor expert_count,
         long n_expert, long n_worker) {
     auto smgr = getCudaStreamManager(expert_count.device().index());
     auto batch_size = gate_idx.numel();
+    auto opt = torch::TensorOptions()
+        .dtype(gate_idx.dtype())
+        .device(gate_idx.device());
+    auto new_gate_idx = torch::empty(gate_idx.sizes(), opt);
     fmoe_cuda_prune_gate_by_capacity_impl(
         gate_idx.data_ptr<long>(),
+        new_gate_idx.data_ptr<long>(),
         expert_count.data_ptr<int>(),
         batch_size, n_expert, n_worker, smgr);
+    return new_gate_idx;
 }
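Both wrappers now hand their result back to Python instead of mutating an argument or wrapping a single tensor in a vector: _limit_by_capacity returns expert_count_ack directly, and _prune_gate_by_capacity allocates and returns new_gate_idx, leaving the input gate_idx untouched. The call sites change accordingly (see the fmoe/gates/utils.py diff further down); a minimal sketch of the new binding-level contract, where gec, capacity, new_lec, num_expert and world_size stand in for the tensors and ints used in that file:

    # was: new_gec, = fmoe_native.limit_by_capacity(...)  (single-element vector unpack)
    new_gec = fmoe_native.limit_by_capacity(
        gec, capacity, num_expert, world_size)
    # was: called only for its in-place effect on topk_idx
    topk_idx = fmoe_native.prune_gate_by_capacity(
        topk_idx, new_lec.to(torch.int32), num_expert, world_size)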
cuda/balancing.cuh

@@ -31,24 +31,28 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
 }
 
 __global__
-void prune_gate_by_capacity_kernel(long* gate_idx, int* ec,
+void prune_gate_by_capacity_kernel(const long* gate_idx, long* new_gate_idx,
+        int* ec,
         const long batch_size, const long n_expert, const long n_worker) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i < batch_size) {
         int orig_cap = atomicSub(ec + gate_idx[i], 1);
         if (orig_cap <= 0) {
-            gate_idx[i] = -1;
+            new_gate_idx[i] = -1;
+        } else {
+            new_gate_idx[i] = gate_idx[i];
         }
     }
 }
 
-void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, int* ec,
+void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
+        int* ec,
         const long batch_size, const long n_expert, const long n_worker,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(batch_size, 1024));
     dim3 block_dim(1024);
     prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
-        gate_idx, ec, batch_size, n_expert, n_worker
+        gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
     );
     smgr->sync(1);
 }
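The kernel decides per token whether it still fits: atomicSub returns the capacity value held before the decrement, so a token keeps its expert only if some capacity remained at that moment, and is marked -1 otherwise. Instead of stamping -1 into gate_idx in place, the kernel now writes the verdict into new_gate_idx, copying the original index in the else branch. A sequential Python sketch of the same logic (illustrative only; the real kernel runs in parallel, so the order of decrements across tokens is nondeterministic):

    import torch

    def prune_gate_by_capacity_ref(gate_idx, ec):
        """Sequential stand-in for prune_gate_by_capacity_kernel.

        gate_idx: long tensor of per-token expert indices (left untouched)
        ec: int32 tensor of remaining capacity per expert slot, decremented
            in place like the kernel's ec buffer
        """
        new_gate_idx = torch.empty_like(gate_idx)
        for i in range(gate_idx.numel()):
            e = int(gate_idx[i])
            orig_cap = int(ec[e])  # atomicSub(ec + gate_idx[i], 1) returns the old value
            ec[e] -= 1
            new_gate_idx[i] = -1 if orig_cap <= 0 else e
        return new_gate_idx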
cuda/fmoe_cuda.cpp

@@ -43,10 +43,10 @@ std::vector<torch::Tensor> _linear_backward(
 );
 
 // balancing
-std::vector<torch::Tensor> _limit_by_capacity(
+torch::Tensor _limit_by_capacity(
         torch::Tensor expert_count, torch::Tensor capacity,
         long n_expert, long n_experts);
-void _prune_gate_by_capacity(
+torch::Tensor _prune_gate_by_capacity(
         torch::Tensor gate_idx, torch::Tensor expert_count,
         long n_expert, long n_worker);
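Only the declared return types change here, keeping the forward declarations in sync with the definitions in cuda/balancing.cu. Incidentally, the second long parameter of _limit_by_capacity is named n_experts in this declaration but n_worker at the definition; C++ matches declarations by type rather than parameter name, so this inconsistency is cosmetic.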
fmoe/functions.py

@@ -189,7 +189,7 @@ class MOEGather(Function):
             global_output_buf,
             local_expert_count,
             global_expert_count,
-            local_batch_size,
+            pos.shape[0],
             world_size,
         )
     else:
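A one-line fix in MOEGather: the gather is driven by pos.shape[0], the actual number of routed positions, rather than local_batch_size. Presumably this matters once capacity limiting is active, since pruned (-1) gate entries mean the number of routed positions no longer equals batch size times top-k.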
fmoe/gates/gshard_gate.py

@@ -10,7 +10,8 @@ from .utils import limit_by_capacity
 
 class GShardGate(NaiveGate):
     def __init__(self, d_model, num_expert, world_size,
-            capacity=(1.2, 2.4), random_routing=True):
+            topk=2, capacity=(1.2, 2.4), random_routing=True):
+        assert topk == 2, 'topk should be 2 in gshard'
         super().__init__(d_model, num_expert, world_size, top_k=2)
         self.capacity = capacity
         self.random_routing = True
@@ -34,7 +35,8 @@ class GShardGate(NaiveGate):
         cap_rate = self.capacity[0 if self.training else 1]
         capacity = math.ceil(cap_rate * x.shape[0])
-        limit_by_capacity(topk_idx, self.num_expert, self.world_size, capacity)
+        _new_lec, _new_gec, topk_idx = limit_by_capacity(
+                topk_idx, self.num_expert, self.world_size, capacity)
 
         if self.random_routing:
             rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
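GShardGate gains a topk keyword guarded by an assertion, presumably so that callers such as the Megatron adaptation (the commit message points at running Megatron with the GShard gate) can pass topk uniformly to any gate class. One observable quirk, present in both the old and new revisions: self.random_routing = True ignores the random_routing argument. A hedged usage sketch, with illustrative sizes and an import path assumed from the repo layout:

    from fmoe.gates import GShardGate  # import path assumed

    gate = GShardGate(d_model=1024, num_expert=4, world_size=8)  # topk defaults to 2
    gate = GShardGate(1024, 4, 8, topk=2)                        # explicit, passes the assert
    # GShardGate(1024, 4, 8, topk=1) raises: AssertionError: topk should be 2 in gshard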
fmoe/gates/switch_gate.py

@@ -14,8 +14,9 @@ class SwitchGate(NaiveGate):
     A switch gate implementation
     """
 
     def __init__(self, d_model, num_expert, world_size,
-            switch_eps=.1, capacity=(1.2, 2.4)):
+            topk=1, switch_eps=.1, capacity=(1.2, 2.4)):
+        assert topk == 1, 'topk should be 1 in switch'
         super().__init__(d_model, num_expert, world_size, top_k=1)
         self.switch_eps = switch_eps
         self.capacity = capacity
@@ -42,7 +43,8 @@ class SwitchGate(NaiveGate):
         cap_rate = self.capacity[0 if self.training else 1]
         capacity = math.ceil(cap_rate * inp.shape[0])
-        limit_by_capacity(top1_idx, self.num_expert, self.world_size, capacity)
+        _new_lec, _new_gec, top1_idx = limit_by_capacity(
+                top1_idx, self.num_expert, self.world_size, capacity)
         valid_idx = top1_idx[top1_idx > -1]
         fraction_expert = torch.scatter_add(
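The substantive fix in both gates is the rebinding: previously limit_by_capacity was called for its side effect of pruning the index tensor in place, but with the pruning kernel now writing to a fresh tensor, a gate that ignores the return value would silently apply no capacity limit at all. The new pattern, as it appears in the diff:

    # before (relied on in-place mutation that no longer happens):
    # limit_by_capacity(top1_idx, self.num_expert, self.world_size, capacity)

    # after (rebinds the pruned indices; dropped tokens become -1):
    _new_lec, _new_gec, top1_idx = limit_by_capacity(
            top1_idx, self.num_expert, self.world_size, capacity)
    valid_idx = top1_idx[top1_idx > -1]  # downstream code already filters the -1 entries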
fmoe/gates/utils.py

@@ -7,19 +7,20 @@ import fmoe_cuda as fmoe_native
 
 def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
-    capacity = torch.ones(num_expert, dtype=torch.int32,
-            device=topk_idx.device) * capacity
-    pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
-            require_pos=False)
-    new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
-            num_expert, world_size)
-    if world_size > 1:
-        new_lec, = fmoe_native.expert_exchange(new_gec, num_expert,
-                world_size)
-    else:
-        new_lec = new_gec
-    fmoe_native.prune_gate_by_capacity(topk_idx,
-            new_lec.to(torch.int32), num_expert, world_size)
-    return new_lec, new_gec
+    with torch.no_grad():
+        capacity = torch.ones(num_expert, dtype=torch.int32,
+                device=topk_idx.device) * capacity
+        pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
+                require_pos=False)
+        new_gec = fmoe_native.limit_by_capacity(gec, capacity,
+                num_expert, world_size)
+        if world_size > 1:
+            new_lec, = fmoe_native.expert_exchange(new_gec, num_expert,
+                    world_size)
+        else:
+            new_lec = new_gec
+        topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx,
+                new_lec.to(torch.int32), num_expert, world_size)
+    return new_lec, new_gec, topk_idx
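Taken together, limit_by_capacity now runs under torch.no_grad(), unpacks the native limit_by_capacity result as a plain tensor, and hands the pruned index tensor back to the caller as a third return value. A self-contained usage sketch; the sizes and the CUDA device are assumptions for illustration, and the fmoe_cuda extension must be built for this to run:

    import torch
    from fmoe.gates.utils import limit_by_capacity

    num_expert, world_size, capacity = 4, 1, 16  # illustrative values
    # one expert index per (token, k) slot, as produced by a top-k gate
    topk_idx = torch.randint(0, num_expert * world_size, (64,), device='cuda')

    new_lec, new_gec, topk_idx = limit_by_capacity(
            topk_idx, num_expert, world_size, capacity)
    # new_lec / new_gec: local and global per-expert counts after capacity limiting
    # topk_idx: pruned copy of the input, with dropped tokens marked -1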