Commit ea66e5e5 authored by Rick Ho's avatar Rick Ho
Browse files

fix ensure device index bug

parent ae2c434e
......@@ -57,6 +57,10 @@ void CudaStreamManager::sync(int idx) {
}
void CudaStreamManager::setup(const int device) {
#ifdef MOE_USE_NCCL
this->ncclgood = 0;
#endif
this->device = device;
checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[SMGR_N_STREAMS];
handles = new cublasHandle_t[SMGR_N_STREAMS];
......
......@@ -29,7 +29,7 @@ public:
#endif
public:
CudaStreamManager(int device_): device(device_), ncclgood(0) {
CudaStreamManager(int device_): device(device_) {
this->setup(device);
}
......
......@@ -197,7 +197,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
}
void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
auto smgr = getCudaStreamManager(0);
auto smgr = getCudaStreamManager(t.device().index());
smgr->ensure((void*)&p, t.device());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment