Commit 27d8beaa authored by Rick Ho

make backward pass test

parent 92f1774a
@@ -4,6 +4,13 @@
 #include "cuda_stream_manager.h"

+cudaStream_t CudaStreamManager::stream(size_t idx) {
+    if (num_expert <= idx) {
+        this->setup(idx + 1);
+    }
+    return this->streams[idx];
+}
+
 void CudaStreamManager::sync(int i) {
     if (i > -1) {
         cudaStreamSynchronize(streams[i]);
@@ -28,7 +35,7 @@ void CudaStreamManager::setup(const size_t num_expert, const int device) {
     streams = new cudaStream_t[num_expert];
     handles = new cublasHandle_t[num_expert];
     for (size_t i=0; i<num_expert; ++i) {
-        checkCudaErrors(cudaStreamCreate(streams+i));
+        checkCudaErrors(cudaStreamCreate(streams + i));
         checkCudaErrors(cublasCreate(handles + i));
         cublasSetStream(handles[i], streams[i]);
     }
...
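For context, the new stream(idx) accessor grows the stream pool on demand instead of assuming setup() already ran. A minimal standalone sketch of the same lazy-growth pattern, using a hypothetical LazyStreams type (the real manager also keeps per-stream cuBLAS handles, as the setup() hunk above shows):

    #include <cuda_runtime.h>
    #include <vector>

    struct LazyStreams {
        std::vector<cudaStream_t> streams;

        // Grow the pool on demand so callers never index past the end.
        cudaStream_t stream(size_t idx = 0) {
            while (streams.size() <= idx) {
                cudaStream_t s;
                cudaStreamCreate(&s);
                streams.push_back(s);
            }
            return streams[idx];
        }

        ~LazyStreams() {
            for (cudaStream_t s : streams) {
                cudaStreamDestroy(s);
            }
        }
    };

The header change below adds the matching declaration with a size_t=0 default, so a bare smgr.stream() call returns the first stream.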
@@ -25,6 +25,7 @@ public:
     }
     void setup(const size_t num_expert, const int device=-1);
+    cudaStream_t stream(size_t=0);
     ~CudaStreamManager() {
 #ifdef MOE_DEBUG
...
@@ -108,7 +108,6 @@ def test():
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
     names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
-    names = ['Out']
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
         print('{} abs err {}'.format(name, err))
...
@@ -21,7 +21,9 @@
 // #define MOE_BREAKDOWN
 // #define MOE_DEBUG

-thread_local CudaStreamManager smgr;
+// thread_local CudaStreamManager smgr;
+// TODO: handle stream manager faults with torch threads
+CudaStreamManager smgr;

 template <typename scalar_t>
 __global__
@@ -87,7 +89,7 @@ void moe_cuda_local_scatter_impl(
         const size_t batch_size,
         const size_t in_feat) {
     batch_scatter_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
+        <<<batch_size, 256, 0, smgr.stream(0)>>>(in_feat, d_pos, input,
             input_buf);
     smgr.sync(0);
 }
@@ -111,7 +113,7 @@ void moe_cuda_local_gather_impl(
         const size_t batch_size,
         const size_t out_feat) {
     batch_gather_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
+        <<<batch_size, 256, 0, smgr.stream(0)>>>(out_feat, d_pos, output_buf,
             output);
     smgr.sync(0);
 }
...
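The launches above now go through smgr.stream(0) instead of indexing smgr.streams[0] directly, so a stream exists even if setup() was never called; combined with the switch from thread_local to a single process-global manager (see the TODO above), all torch threads now share one stream pool. A self-contained sketch of the same launch-on-demand pattern, with an illustrative kernel rather than the real scatter/gather kernels:

    #include <cuda_runtime.h>

    // Process-global stream, mirroring the global smgr above; created on
    // first use so callers need no explicit setup step.
    static cudaStream_t g_stream = nullptr;

    static cudaStream_t stream0() {
        if (g_stream == nullptr) {
            cudaStreamCreate(&g_stream);  // note: not thread-safe, hence the TODO
        }
        return g_stream;
    }

    __global__ void scale_kernel(float* data, float alpha, size_t n) {
        size_t i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            data[i] *= alpha;
        }
    }

    void scale(float* data, float alpha, size_t n) {
        scale_kernel<<<(n + 255) / 256, 256, 0, stream0()>>>(data, alpha, n);
        cudaStreamSynchronize(stream0());  // same pattern as smgr.sync(0)
    }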