Commit 27af1828 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

make cudaStreamManager thread local

parent f5cc759c
#include <cassert> #include <cassert>
#include <thread>
#include "cuda_stream_manager.h" #include "cuda_stream_manager.h"
CudaStreamManager* smgr = NULL; thread_local CudaStreamManager* smgr = NULL;
CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) { CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
if (!smgr) { if (!smgr) {
......
...@@ -11,8 +11,16 @@ ...@@ -11,8 +11,16 @@
class CudaStreamManager { class CudaStreamManager {
public: public:
CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) { CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) {
/*
Actually, we will see current_device == device,
which means pytorch always sets the correct device for us.
But for safety, we still manually set device to the desired one.
*/
int current_device;
checkCudaErrors(cudaGetDevice(&current_device));
printf("CudaStreamManager construnctor called, get device %d, set device %d\n", current_device, device);
checkCudaErrors(cudaSetDevice(device)); checkCudaErrors(cudaSetDevice(device));
printf("set device %d\n", device);
streams = new cudaStream_t[num_expert]; streams = new cudaStream_t[num_expert];
checkCudaErrors(cublasCreate(&handle)); checkCudaErrors(cublasCreate(&handle));
for (size_t i=0; i<num_expert; ++i) { for (size_t i=0; i<num_expert; ++i) {
......
...@@ -115,7 +115,7 @@ def test(): ...@@ -115,7 +115,7 @@ def test():
def test_dp(): def test_dp():
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed(42) torch.cuda.manual_seed(42)
batch_size = 4 batch_size = 6
num_expert = 4 num_expert = 4
in_feat = 2 in_feat = 2
out_feat = 3 out_feat = 3
...@@ -125,14 +125,16 @@ def test_dp(): ...@@ -125,14 +125,16 @@ def test_dp():
print("data parallel of a nn.Linear model") print("data parallel of a nn.Linear model")
linear = nn.Linear(in_feat, in_feat).cuda() linear = nn.Linear(in_feat, in_feat).cuda()
moe_linear = torch.nn.DataParallel(linear, device_ids=[0, 1]) linear_dp = torch.nn.DataParallel(linear, device_ids=[0,1,2])
output = moe_linear(inp) output = linear_dp(inp)
print("successful!") print("successful!")
print("data parallel of our MoE model") print("data parallel of our MoE model")
moe = MOELayer(num_expert, in_feat, out_feat).cuda() moe = MOELayer(num_expert, in_feat, out_feat).cuda()
moe_dp = torch.nn.DataParallel(moe, device_ids=[0, 1]) moe_dp = torch.nn.DataParallel(moe, device_ids=[0,1,2])
output = moe_dp(inp, gate) for i in range(5):
print(i, "forward")
output = moe_dp(inp, gate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment