Commit 27d8beaa authored by Rick Ho
Browse files

make backward pass test

parent 92f1774a
......@@ -4,6 +4,13 @@
#include "cuda_stream_manager.h"
// Fetch the CUDA stream stored at position `idx`.
// When `idx` is not below `num_expert`, setup(idx + 1) is invoked before the
// lookup.
// NOTE(review): the visible part of setup() allocates new stream/handle
// arrays; whether previously created ones are released is not visible from
// here -- confirm.
cudaStream_t CudaStreamManager::stream(size_t idx) {
    const bool has_slot = idx < num_expert;
    if (!has_slot) {
        setup(idx + 1);
    }
    return streams[idx];
}
void CudaStreamManager::sync(int i) {
if (i > -1) {
cudaStreamSynchronize(streams[i]);
......@@ -28,7 +35,7 @@ void CudaStreamManager::setup(const size_t num_expert, const int device) {
streams = new cudaStream_t[num_expert];
handles = new cublasHandle_t[num_expert];
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamCreate(streams+i));
checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasCreate(handles + i));
cublasSetStream(handles[i], streams[i]);
}
......
......@@ -25,6 +25,7 @@ public:
}
void setup(const size_t num_expert, const int device=-1);
cudaStream_t stream(size_t=0);
~CudaStreamManager() {
#ifdef MOE_DEBUG
......
......@@ -108,7 +108,6 @@ def test():
raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
names = ['Out']
for name, mo, ro in zip(names, moe_out, raw_out):
err = (mo - ro).abs().sum()
print('{} abs err {}'.format(name, err))
......
......@@ -21,7 +21,9 @@
// #define MOE_BREAKDOWN
// #define MOE_DEBUG
thread_local CudaStreamManager smgr;
// thread_local CudaStreamManager smgr;
// TODO: handle stream manager faults with torch threads
CudaStreamManager smgr;
template <typename scalar_t>
__global__
......@@ -87,7 +89,7 @@ void moe_cuda_local_scatter_impl(
const size_t batch_size,
const size_t in_feat) {
batch_scatter_kernel<scalar_t>
<<<batch_size, 256, 0, smgr.streams[0]>>>(in_feat, d_pos, input,
<<<batch_size, 256, 0, smgr.stream(0)>>>(in_feat, d_pos, input,
input_buf);
smgr.sync(0);
}
......@@ -111,7 +113,7 @@ void moe_cuda_local_gather_impl(
const size_t batch_size,
const size_t out_feat) {
batch_gather_kernel<scalar_t>
<<<batch_size, 256, 0, smgr.streams[0]>>>(out_feat, d_pos, output_buf,
<<<batch_size, 256, 0, smgr.stream(0)>>>(out_feat, d_pos, output_buf,
output);
smgr.sync(0);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment