OpenDAS / FastMoE

Commit 864a4522, authored Jan 11, 2021 by Rick Ho
Parent: 069cf01a

    multi-gpu forward pass test

Showing 3 changed files with 20 additions and 10 deletions (+20 / -10)
pytorch/cuda/moe.py             +4   -2
pytorch/cuda/moe_cuda_kernel.cu +3   -1
pytorch/cuda/moe_test.py        +13  -7
pytorch/cuda/moe.py

@@ -27,15 +27,17 @@ class MOELayer(nn.Module):
 class MOELayer_raw(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
+    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
+            world_size=0):
         super(MOELayer_raw, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.weight = nn.Parameter(
-                torch.Tensor(num_expert, out_feat, in_feat))
+                torch.Tensor(num_expert * world_size,
+                    out_feat, in_feat))
         self.reset_parameters()

     def reset_parameters(self):
         for i in range(self.num_expert):
             linear = nn.Linear(in_features=self.in_feat,
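The reference layer's weight tensor is widened from num_expert experts to num_expert * world_size, so one process can hold every rank's experts when checking the distributed forward pass. A minimal sketch of the resulting shape (a stand-in class is used here instead of importing the repository's MOELayer_raw):

import torch
import torch.nn as nn

class RawMoERef(nn.Module):
    # Stand-in mirroring MOELayer_raw after this commit: the single-process
    # reference holds the experts of every rank, hence num_expert * world_size.
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, world_size=1):
        super().__init__()
        self.weight = nn.Parameter(
            torch.Tensor(num_expert * world_size, out_feat, in_feat))

ref = RawMoERef(num_expert=4, in_feat=8, out_feat=8, world_size=2)
print(ref.weight.shape)  # torch.Size([8, 8, 8]): 4 experts per rank x 2 ranks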
pytorch/cuda/moe_cuda_kernel.cu

@@ -155,6 +155,7 @@ void moe_cuda_global_scatter_impl(
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
+    smgr->sync(1);
 }

 std::vector<torch::Tensor> moe_cuda_global_scatter(
@@ -224,6 +225,7 @@ void moe_cuda_global_gather_impl(
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
+    smgr->sync(1);
 }

 std::vector<torch::Tensor> moe_cuda_global_gather(
@@ -238,7 +240,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
     AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(),
             "moe_cuda_global_gather", ([&] {
-        moe_cuda_global_scatter_impl<scalar_t>(
+        moe_cuda_global_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
             local_expert_count.data_ptr<int>(),
             global_expert_count.data_ptr<int>(),
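The gather entry point previously dispatched to the scatter implementation; this hunk also adds smgr->sync(1) after each NCCL group so the exchanged buffers are complete before the function returns. A rough single-process sketch of the intended scatter/gather round-trip, with plain Python lists standing in for the NCCL exchange (the function names here are illustrative, not the extension's API):

import torch

def global_scatter(local_bufs, local_expert_count):
    # "Rank" i sends local_expert_count[i][j] rows of its buffer to rank j.
    world_size = len(local_bufs)
    received = [[] for _ in range(world_size)]
    for i, buf in enumerate(local_bufs):
        offset = 0
        for j, n in enumerate(local_expert_count[i]):
            received[j].append(buf[offset:offset + n])
            offset += n
    return [torch.cat(chunks) for chunks in received]

def global_gather(scattered, local_expert_count):
    # Inverse exchange: each rank returns the rows it received to their senders.
    world_size = len(scattered)
    returned = [[] for _ in range(world_size)]
    for j, buf in enumerate(scattered):
        offset = 0
        for i in range(world_size):
            n = local_expert_count[i][j]
            returned[i].append(buf[offset:offset + n])
            offset += n
    return [torch.cat(chunks) for chunks in returned]

counts = [[2, 1], [0, 3]]  # rows rank i sends to rank j
bufs = [torch.arange(3.).view(3, 1), torch.arange(3., 6.).view(3, 1)]
round_trip = global_gather(global_scatter(bufs, counts), counts)
assert all(torch.equal(a, b) for a, b in zip(round_trip, bufs))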
pytorch/cuda/moe_test.py

@@ -67,6 +67,7 @@ def test_module(moe, linear, inp, gate):
     moe.zero_grad()
     x = (linear(inp))
     output = moe(x, gate)
+    print('ooutput', torch.distributed.get_rank(), output)
     y = output.mean()
     y.backward()
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
@@ -86,8 +87,14 @@ def test():
         moe = MOELayer(num_expert, in_feat, out_feat, world_size).cuda()
     else:
         moe = MOELayer(num_expert, in_feat, out_feat).cuda()
-    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
-    moe_raw.weight.data = moe.weight.data.clone()
+    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat, world_size).cuda()
+    if world_size == 1:
+        moe_raw.weight.data = moe.weight.data.clone()
+    else:
+        weight_array = [torch.empty_like(moe.weight.data).cpu()
+                for _ in range(world_size)]
+        torch.distributed.all_gather(weight_array, moe.weight.data.cpu())
+        moe_raw.weight.data = torch.cat(weight_array, dim=0).cuda()

     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0,
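The test now gathers each rank's expert-weight shard and concatenates them so the single-process reference sees all experts. A self-contained sketch of the same all_gather pattern using the gloo backend and two spawned processes (the helper name and tensor sizes are illustrative):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def replicate_weights(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank owns a (num_expert, out_feat, in_feat) shard of the expert weights.
    local = torch.full((2, 4, 4), float(rank))

    # Gather every rank's shard, then concatenate along the expert dimension,
    # mirroring the moe_test.py change above.
    shards = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(shards, local)
    full = torch.cat(shards, dim=0)
    assert full.shape[0] == 2 * world_size

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(replicate_weights, args=(2,), nprocs=2)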
@@ -97,11 +104,12 @@ def test():
     # gate = torch.Tensor([0, 1, 0, 1]).int().cuda()
     moe_out = test_module(moe, linear, inp.clone(), gate.clone())
+    print('hhh')
+    return
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

-    names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
+    if world_size == 1:
+        names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
+    else:
+        names = ['Out']
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
         print('{} abs err {}'.format(name, err))
@@ -134,8 +142,6 @@ def test_dp():
 if __name__ == '__main__':
     torch.distributed.init_process_group(backend='mpi')
     world_size = torch.distributed.get_world_size()
-    if world_size == 1:
-        world_size = None
     test()
     # print('{} / {}'.format(torch.distributed.get_rank(), torch.distributed.get_world_size()))
     # perf()
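With backend='mpi', the test presumably has to be launched under an MPI launcher with one process per GPU, and it requires a torch build with MPI support. A minimal, hedged check (the launch command in the comment is an assumption, not taken from the repository):

import torch.distributed as dist

# Launched as, e.g.:  mpirun -np 2 python pytorch/cuda/moe_test.py
# (process count and path are assumptions; adjust to the local setup)
assert dist.is_mpi_available(), "this test uses backend='mpi'; torch must be built with MPI"
dist.init_process_group(backend='mpi')
print('rank {} / world size {}'.format(dist.get_rank(), dist.get_world_size()))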