Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
72e9bc9e
Commit
72e9bc9e
authored
Jan 29, 2021
by
Rick Ho
Browse files
scatter gather kernel support for non-equal shapes and fix tests
parent
49c97411
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
14 deletions
+25
-14
cuda/moe_compute_kernel.cu
cuda/moe_compute_kernel.cu
+10
-4
tests/moe_test.py
tests/moe_test.py
+15
-10
No files found.
cuda/moe_compute_kernel.cu
View file @
72e9bc9e
...
@@ -228,10 +228,13 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
...
@@ -228,10 +228,13 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
torch
::
Tensor
input
,
torch
::
Tensor
input
,
torch
::
Tensor
pos
)
{
torch
::
Tensor
pos
)
{
auto
smgr
=
getCudaStreamManager
(
input
.
device
().
index
());
auto
smgr
=
getCudaStreamManager
(
input
.
device
().
index
());
const
auto
batch_size
=
input
.
size
(
0
);
const
auto
batch_size
=
pos
.
size
(
0
);
const
auto
in_feat
=
input
.
size
(
1
);
const
auto
in_feat
=
input
.
size
(
1
);
auto
input_buf
=
torch
::
empty_like
(
input
);
auto
opt
=
torch
::
TensorOptions
()
.
dtype
(
input
.
dtype
())
.
device
(
input
.
device
());
auto
input_buf
=
torch
::
empty
({
batch_size
,
in_feat
},
opt
);
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
scalar_type
(),
"moe_local_scatter_cuda"
,
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
scalar_type
(),
"moe_local_scatter_cuda"
,
([
&
]
{
([
&
]
{
...
@@ -250,10 +253,13 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
...
@@ -250,10 +253,13 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
torch
::
Tensor
output_buf
,
torch
::
Tensor
output_buf
,
torch
::
Tensor
pos
)
{
torch
::
Tensor
pos
)
{
auto
smgr
=
getCudaStreamManager
(
output_buf
.
device
().
index
());
auto
smgr
=
getCudaStreamManager
(
output_buf
.
device
().
index
());
const
auto
batch_size
=
output_buf
.
size
(
0
);
const
auto
batch_size
=
pos
.
size
(
0
);
const
auto
out_feat
=
output_buf
.
size
(
1
);
const
auto
out_feat
=
output_buf
.
size
(
1
);
auto
output
=
torch
::
empty_like
(
output_buf
);
auto
opt
=
torch
::
TensorOptions
()
.
dtype
(
output_buf
.
dtype
())
.
device
(
output_buf
.
device
());
auto
output
=
torch
::
empty
({
batch_size
,
out_feat
},
opt
);
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
output_buf
.
scalar_type
(),
"moe_local_gather_cuda"
,
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
output_buf
.
scalar_type
(),
"moe_local_gather_cuda"
,
([
&
]
{
([
&
]
{
...
...
tests/moe_test.py
View file @
72e9bc9e
...
@@ -88,9 +88,13 @@ def test_module(moe, linear, inp, gate):
...
@@ -88,9 +88,13 @@ def test_module(moe, linear, inp, gate):
return
output
,
moe
.
weight
.
grad
,
linear
.
weight
.
grad
,
linear
.
bias
.
grad
return
output
,
moe
.
weight
.
grad
,
linear
.
weight
.
grad
,
linear
.
bias
.
grad
rank
=
None
world_size
=
None
def
test
():
def
test
():
torch
.
manual_seed
(
42
+
torch
.
distributed
.
get_
rank
()
)
torch
.
manual_seed
(
42
+
rank
)
torch
.
cuda
.
manual_seed
(
42
+
torch
.
distributed
.
get_
rank
()
)
torch
.
cuda
.
manual_seed
(
42
+
rank
)
batch_size
=
4
batch_size
=
4
num_expert
=
2
num_expert
=
2
in_feat
=
6
in_feat
=
6
...
@@ -123,13 +127,10 @@ def test():
...
@@ -123,13 +127,10 @@ def test():
names
=
[
'Out'
,
'Moe wei'
,
'Linear wei'
,
'Linear bias'
]
names
=
[
'Out'
,
'Moe wei'
,
'Linear wei'
,
'Linear bias'
]
if
world_size
>
1
:
if
world_size
>
1
:
rank
=
torch
.
distributed
.
get_rank
()
ou
,
wg
,
lwg
,
lbg
=
raw_out
ou
,
wg
,
lwg
,
lbg
=
raw_out
torch
.
distributed
.
all_reduce
(
wg
)
torch
.
distributed
.
all_reduce
(
wg
)
wg
=
wg
[
rank
*
num_expert
:(
rank
+
1
)
*
num_expert
]
wg
=
wg
[
rank
*
num_expert
:(
rank
+
1
)
*
num_expert
]
raw_out
=
ou
,
wg
,
lwg
,
lbg
raw_out
=
ou
,
wg
,
lwg
,
lbg
else
:
rank
=
0
for
name
,
mo
,
ro
in
zip
(
names
,
moe_out
,
raw_out
):
for
name
,
mo
,
ro
in
zip
(
names
,
moe_out
,
raw_out
):
err
=
(
mo
-
ro
).
abs
().
sum
()
err
=
(
mo
-
ro
).
abs
().
sum
()
print
(
'Rank {} {} abs err {}'
.
format
(
rank
,
name
,
err
))
print
(
'Rank {} {} abs err {}'
.
format
(
rank
,
name
,
err
))
...
@@ -166,11 +167,15 @@ def test_dp():
...
@@ -166,11 +167,15 @@ def test_dp():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
os
.
environ
[
'RANK'
]
=
os
.
environ
.
get
(
'OMPI_COMM_WORLD_RANK'
,
0
)
os
.
environ
[
'RANK'
]
=
os
.
environ
.
get
(
'OMPI_COMM_WORLD_RANK'
,
'0'
)
os
.
environ
[
'WORLD_SIZE'
]
=
os
.
environ
.
get
(
'OMPI_COMM_WORLD_SIZE'
,
1
)
os
.
environ
[
'WORLD_SIZE'
]
=
os
.
environ
.
get
(
'OMPI_COMM_WORLD_SIZE'
,
'1'
)
torch
.
distributed
.
init_process_group
(
backend
=
'nccl'
)
if
int
(
os
.
environ
[
'WORLD_SIZE'
])
>
1
:
rank
=
torch
.
distributed
.
get_rank
()
torch
.
distributed
.
init_process_group
(
backend
=
'nccl'
)
world_size
=
torch
.
distributed
.
get_world_size
()
rank
=
torch
.
distributed
.
get_rank
()
world_size
=
torch
.
distributed
.
get_world_size
()
else
:
rank
=
0
world_size
=
1
if
len
(
sys
.
argv
)
>=
2
:
if
len
(
sys
.
argv
)
>=
2
:
task
=
sys
.
argv
[
1
]
task
=
sys
.
argv
[
1
]
print
(
'Specificed task {}'
.
format
(
task
))
print
(
'Specificed task {}'
.
format
(
task
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment