OpenDAS / FastMoE

Commit 27d8beaa authored Jan 09, 2021 by Rick Ho

make backward pass test

parent 92f1774a

Showing 4 changed files with 14 additions and 5 deletions (+14, -5)
pytorch/cuda/cuda_stream_manager.cpp   +8  -1
pytorch/cuda/cuda_stream_manager.h     +1  -0
pytorch/cuda/moe.py                    +0  -1
pytorch/cuda/moe_cuda_kernel.cu        +5  -3
pytorch/cuda/cuda_stream_manager.cpp

@@ -4,6 +4,13 @@
 #include "cuda_stream_manager.h"
 
+cudaStream_t CudaStreamManager::stream(size_t idx) {
+    if (num_expert <= idx) {
+        this->setup(idx + 1);
+    }
+    return this->streams[idx];
+}
+
 void CudaStreamManager::sync(int i) {
     if (i > -1) {
         cudaStreamSynchronize(streams[i]);

@@ -28,7 +35,7 @@ void CudaStreamManager::setup(const size_t num_expert, const int device) {
     streams = new cudaStream_t[num_expert];
     handles = new cublasHandle_t[num_expert];
     for (size_t i = 0; i < num_expert; ++i) {
         checkCudaErrors(cudaStreamCreate(streams + i));
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
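The new stream(size_t idx) accessor grows the pool on demand: if a caller asks for a stream index the manager has not yet allocated, setup(idx + 1) runs before the stream is returned, so callers never index past the end of the streams array. Below is a minimal, self-contained sketch of that grow-on-demand pattern, not the FastMoE source; the StreamPool name, the release() helper, and the main() driver are invented for illustration, and the real setup() additionally creates a cuBLAS handle per stream and binds it with cublasSetStream.

    // Minimal sketch of the grow-on-demand stream pool pattern; not the FastMoE source.
    #include <cuda_runtime.h>
    #include <cstdio>

    struct StreamPool {
        size_t num = 0;
        cudaStream_t* streams = nullptr;

        // Stand-in for CudaStreamManager::setup(): (re)allocate n streams.
        // The real setup() also creates a cuBLAS handle per stream.
        void setup(size_t n) {
            release();
            streams = new cudaStream_t[n];
            for (size_t i = 0; i < n; ++i) {
                cudaStreamCreate(streams + i);
            }
            num = n;
        }

        // Mirrors the accessor added in this commit: grow first, then index.
        cudaStream_t stream(size_t idx = 0) {
            if (num <= idx) {
                setup(idx + 1);
            }
            return streams[idx];
        }

        // Mirrors sync(): synchronize a single stream when i is non-negative.
        void sync(int i) {
            if (i > -1) {
                cudaStreamSynchronize(streams[i]);
            }
        }

        void release() {
            for (size_t i = 0; i < num; ++i) {
                cudaStreamDestroy(streams[i]);
            }
            delete[] streams;
            streams = nullptr;
            num = 0;
        }

        ~StreamPool() { release(); }
    };

    int main() {
        StreamPool pool;
        cudaStream_t s = pool.stream(2);   // pool grows to 3 streams on demand
        (void)s;
        pool.sync(2);
        std::printf("pool holds %zu streams\n", pool.num);
        return 0;
    }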
pytorch/cuda/cuda_stream_manager.h

@@ -25,6 +25,7 @@ public:
     }
 
     void setup(const size_t num_expert, const int device=-1);
+    cudaStream_t stream(size_t=0);
 
     ~CudaStreamManager() {
 #ifdef MOE_DEBUG
pytorch/cuda/moe.py

@@ -108,7 +108,6 @@ def test():
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
 
     names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
-    names = ['Out']
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
         print('{} abs err {}'.format(name, err))
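With the hard-coded names = ['Out'] override removed, the comparison loop iterates over all four entries of the preceding list, so the test presumably checks the backward pass as well: the gradient tensors 'Moe wei', 'Linear wei', and 'Linear bias' are now compared between the CUDA MoE module and the reference implementation, which matches the commit message "make backward pass test".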
pytorch/cuda/moe_cuda_kernel.cu

@@ -21,7 +21,9 @@
 // #define MOE_BREAKDOWN
 // #define MOE_DEBUG
 
-thread_local CudaStreamManager smgr;
+// thread_local CudaStreamManager smgr;
+// TODO: handle stream manager faults with torch threads
+CudaStreamManager smgr;
 
 template <typename scalar_t>
 __global__

@@ -87,7 +89,7 @@ void moe_cuda_local_scatter_impl(
         const size_t batch_size,
         const size_t in_feat) {
     batch_scatter_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
+        <<<batch_size, 256, 0, smgr.stream(0)>>>(in_feat, d_pos, input,
             input_buf);
     smgr.sync(0);
 }

@@ -111,7 +113,7 @@ void moe_cuda_local_gather_impl(
         const size_t batch_size,
         const size_t out_feat) {
     batch_gather_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
+        <<<batch_size, 256, 0, smgr.stream(0)>>>(out_feat, d_pos, output_buf,
             output);
     smgr.sync(0);
 }
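The two kernel-launch hunks change only the fourth launch parameter: instead of indexing the raw streams array (smgr.streams[0]), the kernels are launched on the stream returned by the new smgr.stream(0) accessor, which lazily sets the manager up if needed, and smgr.sync(0) then waits on that single stream. A minimal, self-contained CUDA sketch of that launch-then-sync pattern follows; it is not the FastMoE source, and the MiniStreamManager and scale_kernel names are invented for illustration.

    // Sketch of launching a kernel on a managed stream and syncing that stream only.
    #include <cuda_runtime.h>
    #include <cstdio>

    // Trivial stand-in for the manager: hands out one lazily created stream.
    struct MiniStreamManager {
        cudaStream_t s = nullptr;
        cudaStream_t stream(size_t /*idx*/ = 0) {
            if (s == nullptr) {
                cudaStreamCreate(&s);
            }
            return s;
        }
        void sync(int i) {
            if (i > -1 && s != nullptr) {
                cudaStreamSynchronize(s);
            }
        }
        ~MiniStreamManager() {
            if (s != nullptr) {
                cudaStreamDestroy(s);
            }
        }
    };

    // Placeholder kernel standing in for batch_scatter_kernel / batch_gather_kernel.
    __global__ void scale_kernel(float* data, float factor, size_t n) {
        size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
        if (i < n) {
            data[i] *= factor;
        }
    }

    MiniStreamManager smgr;  // a plain global, as in the change above (no longer thread_local)

    int main() {
        const size_t n = 1 << 20;
        float* d_data = nullptr;
        cudaMalloc(&d_data, n * sizeof(float));
        cudaMemset(d_data, 0, n * sizeof(float));

        // Same shape as the scatter/gather launches: accessor instead of raw array index.
        scale_kernel<<<(n + 255) / 256, 256, 0, smgr.stream(0)>>>(d_data, 2.0f, n);
        smgr.sync(0);  // wait only for the stream the kernel ran on

        cudaFree(d_data);
        std::printf("done\n");
        return 0;
    }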