OpenDAS / FastMoE · Commit c5cbd64b

Authored Oct 26, 2021 by Rick Ho
Parent: fe5a9cda

    swipe test not passing

Showing 6 changed files with 133 additions and 32 deletions (+133 -32):

    cuda/balancing.cu         +31 -10
    fmoe/gates/swipe_gate.py  +10  -8
    fmoe/layers.py             +0  -3
    tests/test_ddp.py         +11  -0
    tests/test_gates.py        +1 -11
    tests/test_swipe.py       +80  -0
cuda/balancing.cu  (+31 -10)

 #include <cstdio>
 #include "balancing.cuh"
 #include "global_exchange.h"
 #include <torch/extension.h>
@@ -88,7 +89,9 @@ std::vector<torch::Tensor> _swipe_once(
     ncclCommUserRank(smgr->ncclcomm, &rank);
     cudaSetDevice(device_idx);
-    auto cap = capacity.item<long>();
+    auto capacity_new = capacity.clone();
+    auto cap = capacity_new.item<long>();
     // fprintf(stderr, "%d initial cap %ld ws %ld ne %ld\n", rank, cap, n_worker, n_expert);
     long batch_size = gate_idx.size(0);
     auto gate_idx_cpu = gate_idx.cpu();
@@ -98,16 +101,17 @@ std::vector<torch::Tensor> _swipe_once(
     long *lec = new long[n_worker];
     memset(lec, 0, n_worker * sizeof(long));
     for (long i = 0; i < batch_size; ++i) {
-        ++lec[gidx[i] % n_expert];
+        ++lec[gidx[i] / n_expert];
     }
     long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
     fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
-    long *gec = _d2h(d_gec, n_expert);
+    long *gec = _d2h(d_gec, n_worker);
     // fprintf(stderr, "%d initial ec, lec %ld %ld, gec %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1]);
     /* Limit number of incoming samples */
     long *drop_count = new long[n_worker];
     memset(drop_count, 0, n_worker * sizeof(long));
-    for (long i = 0; i < n_expert; ++i) {
+    for (long i = 0; i < n_worker; ++i) {
         if (cap >= gec[i]) {
             drop_count[i] = 0;
             cap -= gec[i];
@@ -118,10 +122,11 @@ std::vector<torch::Tensor> _swipe_once(
         }
     }
     // fprintf(stderr, "%d before exchange cap %ld, drop count %ld %ld, lgec %ld %ld\n", rank, cap, drop_count[0], drop_count[1], gec[0], gec[1]);
     /* Send limit information back */
     _h2d(gec, d_gec, n_worker);
-    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_expert, smgr);
-    _d2h(d_lec, lec, n_expert);
+    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
+    _d2h(d_lec, lec, n_worker);
     auto d_dropcount = _h2d(drop_count, n_worker);
     ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
@@ -129,12 +134,14 @@ std::vector<torch::Tensor> _swipe_once(
     _d2h(d_dropcount, drop_count, n_worker);
     auto d_gcap = _cudamalloc<long>(n_worker);
-    _h2d(d_gcap + rank, &cap, n_worker);
+    _h2d(&cap, d_gcap + rank, 1);
     ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
             smgr->ncclcomm, smgr->stream());
     auto gcap = _d2h(d_gcap, n_worker);
     cudaDeviceSynchronize();
-    /* Re-assign counts */
     // fprintf(stderr, "%d exchange fin, drop count %ld %ld, nlec %ld %ld, gcap %ld %ld\n", rank, drop_count[0], drop_count[1], lec[0], lec[1], gcap[0], gcap[1]);
+    /* Re-assign and update counters */
     for (long i = 0, j = 0; i < n_worker; ++i) {
         while (drop_count[i] > 0) {
             if (drop_count[i] > gcap[j]) {
@@ -148,8 +155,9 @@ std::vector<torch::Tensor> _swipe_once(
             }
         }
     }
     // fprintf(stderr, "%d update done, lec %ld %ld, gec %ld %ld, gcap %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1], gcap[0], gcap[1]);
     for (long i = 0; i < batch_size; ++i) {
-        auto widx = gidx[i] % n_expert;
+        auto widx = gidx[i] / n_expert;
         if (lec[widx] > 0) {
             --lec[widx];
         } else {
@@ -162,9 +170,22 @@ std::vector<torch::Tensor> _swipe_once(
         }
         for (; lec[k] == 0; ++k);
         --lec[gidx[i] = k * n_expert + bias];
         // fprintf(stderr, "%d: assign %ld to %ld\n", rank, i, k);
     }
+    *capacity_new.data_ptr<long>() = cap;
     // fprintf(stderr, "%d all done\n", rank);
-    return {gate_idx_cpu, capacity};
+    delete [] drop_count;
+    delete [] lec;
+    delete [] gec;
+    delete [] gcap;
+    cudaFree(d_dropcount);
+    cudaFree(d_lec);
+    cudaFree(d_gec);
+    cudaFree(d_gcap);
+    return {gate_idx_cpu, capacity_new};
 }
 #undef UPDATE_COUNTERS
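For readers skimming the diff, the following is a minimal, single-process Python sketch of the capacity-limited drop-and-reassign idea that _swipe_once implements on the GPU. It is illustrative only, not FastMoE API: the NCCL all-to-all, all-reduce and all-gather steps are replaced by plain indexing over per-worker lists, and it assumes total capacity is at least the total number of samples, so every dropped sample can be re-assigned somewhere.

import torch


def swipe_once_sketch(gate_idx, capacity, n_expert, n_worker, bias):
    # gate_idx: list of 1-D LongTensors (global expert ids), one per worker
    # capacity: list of per-worker capacities (plain ints)
    assert sum(capacity) >= sum(g.numel() for g in gate_idx)

    # lec[src][dst]: how many samples worker `src` routes to worker `dst`.
    lec = [[0] * n_worker for _ in range(n_worker)]
    for src in range(n_worker):
        for g in gate_idx[src].tolist():
            lec[src][g // n_expert] += 1

    # "Expert exchange": gec[dst][src] = lec[src][dst] (incoming counts).
    gec = [[lec[src][dst] for src in range(n_worker)] for dst in range(n_worker)]

    # Each destination accepts at most `capacity` incoming samples; the rest
    # are recorded as drops.
    cap = list(capacity)
    drop = [[0] * n_worker for _ in range(n_worker)]
    for dst in range(n_worker):
        for src in range(n_worker):
            if cap[dst] >= gec[dst][src]:
                cap[dst] -= gec[dst][src]
            else:
                drop[dst][src] = gec[dst][src] - cap[dst]
                gec[dst][src] = cap[dst]
                cap[dst] = 0

    # Send accepted quotas back, sum drops per source, gather spare capacities.
    quota = [[gec[dst][src] for dst in range(n_worker)] for src in range(n_worker)]
    total_drop = [sum(drop[dst][src] for dst in range(n_worker)) for src in range(n_worker)]
    gcap = list(cap)

    # Re-assign each source worker's dropped samples to workers that still
    # have spare capacity, raising that source's quota towards them.
    j = 0
    for src in range(n_worker):
        while total_drop[src] > 0:
            take = min(total_drop[src], gcap[j])
            gcap[j] -= take
            quota[src][j] += take
            total_drop[src] -= take
            if gcap[j] == 0 and total_drop[src] > 0:
                j += 1

    # Rewrite gate indices: samples beyond a destination's quota are moved to
    # expert `bias` on the next worker that still has quota left.
    new_idx = []
    for src in range(n_worker):
        gidx = gate_idx[src].clone()
        for i in range(gidx.numel()):
            dst = int(gidx[i]) // n_expert
            if quota[src][dst] > 0:
                quota[src][dst] -= 1
            else:
                k = 0
                while quota[src][k] == 0:
                    k += 1
                quota[src][k] -= 1
                gidx[i] = k * n_expert + bias
        new_idx.append(gidx)
    return new_idx, gcap

In the real kernel each rank computes only its own slice of this picture through the NCCL calls, returns the updated scalar capacity as capacity_new, and frees its temporaries, as the last hunk above shows.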
fmoe/gates/swipe_gate.py  (+10 -8)

@@ -13,22 +13,20 @@ import fmoe_cuda as fmoe_native
 class SwipeGate(NaiveGate):
     requires_moe_group = True

-    def __init__(self, d_model, num_expert, world_size, topk=2):
+    def __init__(self, d_model, num_expert, world_size, top_k=2):
         super().__init__(d_model, num_expert, world_size, top_k)

-    def swipe_once(self, idx, capacity):
+    def swipe_once(self, idx, capacity, bias):
         with torch.no_grad():
             idx_new, capacity = fmoe_native.swipe_once(idx, capacity,
-                    self.num_expert, self.world_size)
+                    self.num_expert, self.world_size, bias)
             idx_new = idx_new.to(idx.device)
         return idx_new, capacity

     def forward(self, inp):
         score = self.gate(inp)
-        _, orig_idx = torch.topk(gate_score, k=self.top_k, dim=-1)
+        _, orig_idx = torch.topk(score, k=self.top_k, dim=-1)
         if not self.training:
             topk_val = F.softmax(topk_val, dim=-1)
@@ -38,10 +36,14 @@ class SwipeGate(NaiveGate):
                 dtype=torch.long)
         topk_idxs = []
+        topk_vals = []
+        idx_x = torch.arange(inp.shape[0], device=inp.device)
         for k in range(self.top_k):
-            idx, capacity = self.swipe_once(orig_idx[:, k], capacity)
+            idx, capacity = self.swipe_once(orig_idx[:, k], capacity,
+                    k % self.num_expert)
+            topk_vals.append(score[idx_x, idx])
             topk_idxs.append(idx)
         topk_idx = torch.stack(topk_idxs).transpose(0, 1)
-        topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k)
+        topk_val = torch.stack(topk_vals).transpose(0, 1)
         topk_val = F.softmax(topk_val, dim=-1)
         return topk_idx, topk_val
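The forward change is easiest to see in isolation: instead of indexing a gate_score tensor after the loop, the new code gathers the score of whatever expert swipe_once finally assigned in each round and re-normalises across the top-k picks. Below is a standalone sketch with made-up shapes and a random stand-in for swipe_once; everything here is illustrative, not FastMoE API.

import torch
import torch.nn.functional as F

batch_size, n_global_expert, top_k = 4, 6, 2
score = torch.rand(batch_size, n_global_expert)     # gate scores per sample
idx_x = torch.arange(batch_size)                     # row indices for gathering

topk_idxs, topk_vals = [], []
for k in range(top_k):
    # Stand-in for the (possibly re-assigned) expert index from swipe_once.
    idx = torch.randint(0, n_global_expert, (batch_size,))
    topk_vals.append(score[idx_x, idx])              # score of the assigned expert
    topk_idxs.append(idx)

topk_idx = torch.stack(topk_idxs).transpose(0, 1)                    # (batch, top_k)
topk_val = F.softmax(torch.stack(topk_vals).transpose(0, 1), dim=-1)  # re-normalised weights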
fmoe/layers.py  (+0 -3)

@@ -128,9 +128,6 @@ class FMoE(nn.Module):
         self.mask_dict = mask_dict
         self.moe_group = moe_group

-        if hasattr(self.gate, 'requires_moe_group'):
-            setattr(self.gate, 'moe_gruop', self.moe_group)
-
     def expert_fn(self, inp, fwd_expert_count):
         r"""
         The default expert function which either calls the experts as a whole
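The three deleted lines implemented a small hand-off: gates declaring requires_moe_group had the layer's moe_group attached to them (the removed code spelled the attribute 'moe_gruop'). A minimal standalone sketch of that pattern, with illustrative names only:

class GroupAwareGate:
    # A gate opts into receiving the communication group via this flag.
    requires_moe_group = True
    moe_group = None


def attach_moe_group(gate, moe_group):
    # Mirrors the removed layers.py logic, with the attribute spelled
    # 'moe_group' rather than the original 'moe_gruop'.
    if hasattr(gate, 'requires_moe_group'):
        setattr(gate, 'moe_group', moe_group)


gate = GroupAwareGate()
attach_moe_group(gate, moe_group="a process-group handle")
print(gate.moe_group)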
tests/test_ddp.py  (+11 -0)

@@ -5,12 +5,23 @@ from typing import Dict
 import pytest
 import torch
+import torch.distributed as dist

 from test_numerical import test_fmoe as _test_fmoe
 from test_numerical import test_fmoe_linear as _test_fmoe_linear
 from test_numerical import _test_fmoe_local_ddp


+def _ensure_initialized():
+    if not dist.is_initialized():
+        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
+        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
+        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
+        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
+        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
+        dist.init_process_group(backend="nccl")
+
+
 def _run_distributed(func, world_size, args: Dict, script=__file__):
     if torch.cuda.device_count() < world_size:
         pytest.skip("No enough GPU")
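As a quick sanity check, a minimal sketch of exercising the new _ensure_initialized helper without mpirun: with none of the OMPI_* variables set, the defaults above ("0", "1", localhost:12211) initialise a single-process NCCL group; a visible GPU is assumed because the backend is NCCL.

import torch.distributed as dist
from test_ddp import _ensure_initialized

_ensure_initialized()
assert dist.is_initialized()
print(dist.get_rank(), dist.get_world_size())   # expected: 0 1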
tests/test_gates.py  (+1 -11)

@@ -9,17 +9,7 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from fmoe.gates import GShardGate, SwitchGate
-from test_ddp import _run_distributed
-
-
-def _ensure_initialized():
-    if not dist.is_initialized():
-        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
-        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
-        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
-        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
-        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
-        dist.init_process_group(backend="nccl")
+from test_ddp import _ensure_initialized, _run_distributed


 @pytest.mark.parametrize("d_model", [1024])
tests/test_swipe.py  (new file mode 100644, +80 -0)

import pytest

import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.functions import ensure_comm
from fmoe.gates.swipe_gate import SwipeGate

from test_ddp import _ensure_initialized, _run_distributed


@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("top_k", [2, 4])
@pytest.mark.parametrize("world_size", [2, 4, 8])
def test_swipe_gate(world_size, d_model, batch_size, n_expert, top_k):
    if world_size * n_expert < 2:
        pytest.skip("No enough experts")
    _run_distributed('_test_swipe_gate',
            world_size,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'top_k': top_k
            },
            script=__file__)


def _test_swipe_gate(d_model, batch_size, n_expert, top_k):
    _ensure_initialized()
    gate = SwipeGate(d_model, n_expert, dist.get_world_size()).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    ensure_comm(x, None)
    topk_idx, topk_val = gate(x)


@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("world_size", [2, 4, 8])
def test_swipe_once(world_size, batch_size, n_expert):
    if world_size * n_expert < 2:
        pytest.skip("No enough experts")
    _run_distributed('_test_swipe_once',
            world_size,
            {
                'batch_size': batch_size,
                'n_expert': n_expert
            },
            script=__file__)


def _test_swipe_once(batch_size, n_expert):
    _ensure_initialized()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    gate = SwipeGate(4, n_expert, dist.get_world_size()).cuda()
    idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
    capacity = torch.scalar_tensor(batch_size, dtype=torch.long)
    ensure_comm(idx, None)
    sys.stderr.write('{} Before swipe gate {}, capacity {}\n'.format(rank, idx, capacity))
    new_idx, new_cap = gate.swipe_once(idx, capacity, 0)
    sys.stderr.write('{} final gte {}, cap {}\n'.format(rank, new_idx, new_cap))


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        # test_swipe_gate(8, 4, 8, 4, 2)
        test_swipe_once(8, 8, 4)