Commit ea66e5e5 authored by Rick Ho's avatar Rick Ho
Browse files

fix ensure device index bug

parent ae2c434e
...@@ -57,6 +57,10 @@ void CudaStreamManager::sync(int idx) { ...@@ -57,6 +57,10 @@ void CudaStreamManager::sync(int idx) {
} }
void CudaStreamManager::setup(const int device) { void CudaStreamManager::setup(const int device) {
#ifdef MOE_USE_NCCL
this->ncclgood = 0;
#endif
this->device = device;
checkCudaErrors(cudaSetDevice(device)); checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[SMGR_N_STREAMS]; streams = new cudaStream_t[SMGR_N_STREAMS];
handles = new cublasHandle_t[SMGR_N_STREAMS]; handles = new cublasHandle_t[SMGR_N_STREAMS];
......
...@@ -29,7 +29,7 @@ public: ...@@ -29,7 +29,7 @@ public:
#endif #endif
public: public:
CudaStreamManager(int device_): device(device_), ncclgood(0) { CudaStreamManager(int device_): device(device_) {
this->setup(device); this->setup(device);
} }
......
...@@ -197,7 +197,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather( ...@@ -197,7 +197,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
} }
void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) { void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
auto smgr = getCudaStreamManager(0); auto smgr = getCudaStreamManager(t.device().index());
smgr->ensure((void*)&p, t.device()); smgr->ensure((void*)&p, t.device());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment