Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
8a56481b
Commit
8a56481b
authored
Oct 26, 2021
by
Rick Ho
Browse files
swipe pass test
parent
c5cbd64b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
15 deletions
+7
-15
cuda/balancing.cu
cuda/balancing.cu
+2
-9
tests/test_swipe.py
tests/test_swipe.py
+5
-6
No files found.
cuda/balancing.cu
View file @
8a56481b
...
@@ -91,7 +91,6 @@ std::vector<torch::Tensor> _swipe_once(
...
@@ -91,7 +91,6 @@ std::vector<torch::Tensor> _swipe_once(
auto
capacity_new
=
capacity
.
clone
();
auto
capacity_new
=
capacity
.
clone
();
auto
cap
=
capacity_new
.
item
<
long
>
();
auto
cap
=
capacity_new
.
item
<
long
>
();
// fprintf(stderr, "%d initial cap %ld ws %ld ne %ld\n", rank, cap, n_worker, n_expert);
long
batch_size
=
gate_idx
.
size
(
0
);
long
batch_size
=
gate_idx
.
size
(
0
);
auto
gate_idx_cpu
=
gate_idx
.
cpu
();
auto
gate_idx_cpu
=
gate_idx
.
cpu
();
...
@@ -106,7 +105,6 @@ std::vector<torch::Tensor> _swipe_once(
...
@@ -106,7 +105,6 @@ std::vector<torch::Tensor> _swipe_once(
long
*
d_lec
=
_h2d
(
lec
,
n_worker
),
*
d_gec
=
_cudamalloc
<
long
>
(
n_worker
);
long
*
d_lec
=
_h2d
(
lec
,
n_worker
),
*
d_gec
=
_cudamalloc
<
long
>
(
n_worker
);
fmoe_cuda_expert_exchange_impl
(
d_lec
,
d_gec
,
1
,
n_worker
,
smgr
);
fmoe_cuda_expert_exchange_impl
(
d_lec
,
d_gec
,
1
,
n_worker
,
smgr
);
long
*
gec
=
_d2h
(
d_gec
,
n_worker
);
long
*
gec
=
_d2h
(
d_gec
,
n_worker
);
// fprintf(stderr, "%d initial ec, lec %ld %ld, gec %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1]);
/* Limit number of incoming samples */
/* Limit number of incoming samples */
long
*
drop_count
=
new
long
[
n_worker
];
long
*
drop_count
=
new
long
[
n_worker
];
...
@@ -122,7 +120,6 @@ std::vector<torch::Tensor> _swipe_once(
...
@@ -122,7 +120,6 @@ std::vector<torch::Tensor> _swipe_once(
}
}
}
}
// fprintf(stderr, "%d before exchange cap %ld, drop count %ld %ld, lgec %ld %ld\n", rank, cap, drop_count[0], drop_count[1], gec[0], gec[1]);
/* Send limit information back */
/* Send limit information back */
_h2d
(
gec
,
d_gec
,
n_worker
);
_h2d
(
gec
,
d_gec
,
n_worker
);
fmoe_cuda_expert_exchange_impl
(
d_gec
,
d_lec
,
1
,
n_worker
,
smgr
);
fmoe_cuda_expert_exchange_impl
(
d_gec
,
d_lec
,
1
,
n_worker
,
smgr
);
...
@@ -138,9 +135,7 @@ std::vector<torch::Tensor> _swipe_once(
...
@@ -138,9 +135,7 @@ std::vector<torch::Tensor> _swipe_once(
ncclAllGather
(
d_gcap
+
rank
,
d_gcap
,
1
,
ncclInt64
,
ncclAllGather
(
d_gcap
+
rank
,
d_gcap
,
1
,
ncclInt64
,
smgr
->
ncclcomm
,
smgr
->
stream
());
smgr
->
ncclcomm
,
smgr
->
stream
());
auto
gcap
=
_d2h
(
d_gcap
,
n_worker
);
auto
gcap
=
_d2h
(
d_gcap
,
n_worker
);
cudaDeviceSynchronize
();
// fprintf(stderr, "%d exchange fin, drop count %ld %ld, nlec %ld %ld, gcap %ld %ld\n", rank, drop_count[0], drop_count[1], lec[0], lec[1], gcap[0], gcap[1]);
/* Re-assign and update counters */
/* Re-assign and update counters */
for
(
long
i
=
0
,
j
=
0
;
i
<
n_worker
;
++
i
)
{
for
(
long
i
=
0
,
j
=
0
;
i
<
n_worker
;
++
i
)
{
while
(
drop_count
[
i
]
>
0
)
{
while
(
drop_count
[
i
]
>
0
)
{
...
@@ -155,7 +150,6 @@ std::vector<torch::Tensor> _swipe_once(
...
@@ -155,7 +150,6 @@ std::vector<torch::Tensor> _swipe_once(
}
}
}
}
}
}
// fprintf(stderr, "%d update done, lec %ld %ld, gec %ld %ld, gcap %ld %ld\n", rank, lec[0], lec[1], gec[0], gec[1], gcap[0], gcap[1]);
for
(
long
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
long
i
=
0
;
i
<
batch_size
;
++
i
)
{
auto
widx
=
gidx
[
i
]
/
n_expert
;
auto
widx
=
gidx
[
i
]
/
n_expert
;
if
(
lec
[
widx
]
>
0
)
{
if
(
lec
[
widx
]
>
0
)
{
...
@@ -169,11 +163,10 @@ std::vector<torch::Tensor> _swipe_once(
...
@@ -169,11 +163,10 @@ std::vector<torch::Tensor> _swipe_once(
continue
;
continue
;
}
}
for
(;
lec
[
k
]
==
0
;
++
k
);
for
(;
lec
[
k
]
==
0
;
++
k
);
--
lec
[
gidx
[
i
]
=
k
*
n_expert
+
bias
];
--
lec
[
k
];
// fprintf(stderr, "%d: assign %ld to %ld\n", rank, i, k)
;
gidx
[
i
]
=
k
*
n_expert
+
bias
;
}
}
*
capacity_new
.
data_ptr
<
long
>
()
=
cap
;
*
capacity_new
.
data_ptr
<
long
>
()
=
cap
;
// fprintf(stderr, "%d all done\n", rank);
delete
[]
drop_count
;
delete
[]
drop_count
;
delete
[]
lec
;
delete
[]
lec
;
...
...
tests/test_swipe.py
View file @
8a56481b
...
@@ -42,7 +42,6 @@ def _test_swipe_gate(d_model, batch_size, n_expert, top_k):
...
@@ -42,7 +42,6 @@ def _test_swipe_gate(d_model, batch_size, n_expert, top_k):
topk_idx
,
topk_val
=
gate
(
x
)
topk_idx
,
topk_val
=
gate
(
x
)
@
pytest
.
mark
.
parametrize
(
"d_model"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"n_expert"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"n_expert"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
,
4
,
8
])
...
@@ -65,16 +64,16 @@ def _test_swipe_once(batch_size, n_expert):
...
@@ -65,16 +64,16 @@ def _test_swipe_once(batch_size, n_expert):
world_size
=
dist
.
get_world_size
()
world_size
=
dist
.
get_world_size
()
gate
=
SwipeGate
(
4
,
n_expert
,
dist
.
get_world_size
()).
cuda
()
gate
=
SwipeGate
(
4
,
n_expert
,
dist
.
get_world_size
()).
cuda
()
idx
=
torch
.
randint
(
0
,
n_expert
*
world_size
,
(
batch_size
,)).
cuda
()
idx
=
torch
.
randint
(
0
,
n_expert
*
world_size
,
(
batch_size
,)).
cuda
()
capacity
=
torch
.
scalar_tensor
(
batch_size
,
dtype
=
torch
.
long
)
capacity
=
torch
.
scalar_tensor
(
batch_size
*
2
,
dtype
=
torch
.
long
)
ensure_comm
(
idx
,
None
)
ensure_comm
(
idx
,
None
)
sys
.
stderr
.
write
(
'{} Before swipe gate {}, capacity {}
\n
'
.
format
(
rank
,
idx
,
capacity
))
new_idx
,
new_cap
=
gate
.
swipe_once
(
idx
,
capacity
,
0
)
new_idx
,
new_cap
=
gate
.
swipe_once
(
idx
,
capacity
,
0
)
sys
.
stderr
.
write
(
'{} final gte {}, cap {}
\n
'
.
format
(
rank
,
new_idx
,
new_cap
))
idx
=
torch
.
randint
(
0
,
n_expert
*
world_size
,
(
batch_size
,)).
cuda
()
new_idx
,
new_cap
=
gate
.
swipe_once
(
idx
,
new_cap
,
0
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
if
len
(
sys
.
argv
)
>=
3
:
if
len
(
sys
.
argv
)
>=
3
:
args
=
json
.
loads
(
sys
.
argv
[
2
])
args
=
json
.
loads
(
sys
.
argv
[
2
])
locals
()[
sys
.
argv
[
1
]](
**
args
)
locals
()[
sys
.
argv
[
1
]](
**
args
)
else
:
else
:
#
test_swipe_gate(8, 4, 8, 4, 2)
test_swipe_gate
(
8
,
4
,
8
,
4
,
2
)
test_swipe_once
(
8
,
8
,
4
)
#
test_swipe_once(8, 8
00
, 4)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment