Commit 046455a8 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

set device according to input

parent 2338a26e
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
// Process-wide singleton stream manager shared by all MoE CUDA ops.
CudaStreamManager* smgr = NULL;

// Returns the process-wide CudaStreamManager singleton, creating it on first
// use for the given expert count and CUDA device index.
//
// num_expert: number of experts (one cudaStream_t is kept per expert).
// device:     CUDA device ordinal the manager (and its cuBLAS handle /
//             streams) is bound to.
//
// NOTE(review): creation is not thread-safe (unguarded check-then-create on
// a global) — confirm all callers run on a single host thread.
CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device) {
    if (!smgr) {
        smgr = new CudaStreamManager(num_expert, device);
    }
    // The singleton is created exactly once, so every later caller must ask
    // for the same configuration. A mismatch (e.g. a second model with a
    // different expert count or on another GPU) is a programming error and
    // trips these asserts in debug builds.
    assert(smgr->num_expert == num_expert);
    assert(smgr->device == device);
    return smgr;
}
...@@ -8,7 +8,8 @@ ...@@ -8,7 +8,8 @@
class CudaStreamManager { class CudaStreamManager {
public: public:
CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) { CudaStreamManager(const size_t num_expert_, const int device_) : num_expert(num_expert_), device(device_) {
checkCudaErrors(cudaSetDevice(device));
streams = new cudaStream_t[num_expert]; streams = new cudaStream_t[num_expert];
checkCudaErrors(cublasCreate(&handle)); checkCudaErrors(cublasCreate(&handle));
for (size_t i=0; i<num_expert; ++i) { for (size_t i=0; i<num_expert; ++i) {
...@@ -22,10 +23,11 @@ public: ...@@ -22,10 +23,11 @@ public:
checkCudaErrors(cublasDestroy(handle)); checkCudaErrors(cublasDestroy(handle));
} }
const size_t num_expert; const size_t num_expert;
const int device;
cublasHandle_t handle; cublasHandle_t handle;
cudaStream_t* streams; cudaStream_t* streams;
}; };
CudaStreamManager* getCudaStreamManager(const size_t num_expert); CudaStreamManager* getCudaStreamManager(const size_t num_expert, const int device);
#endif // CUDA_STREAM_MANAGER #endif // CUDA_STREAM_MANAGER
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cublas_v2.h> #include <cublas_v2.h>
#include <helper_cuda.h> #include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>
// #include "timer.hh" // #include "timer.hh"
...@@ -38,9 +39,10 @@ void moe_cuda_forward_impl( ...@@ -38,9 +39,10 @@ void moe_cuda_forward_impl(
const size_t in_feat, const size_t in_feat,
const size_t out_feat, const size_t out_feat,
const size_t num_expert, const size_t num_expert,
cublasOperation_t transb) { cublasOperation_t transb,
const int device) {
auto* h = getCudaStreamManager(num_expert); auto* h = getCudaStreamManager(num_expert, device);
checkCudaErrors(cublasSetStream(h->handle, *(h->streams))); checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));
...@@ -95,9 +97,10 @@ void moe_cuda_grad_weight( ...@@ -95,9 +97,10 @@ void moe_cuda_grad_weight(
const size_t batch_size, const size_t batch_size,
const size_t in_feat, const size_t in_feat,
const size_t out_feat, const size_t out_feat,
const size_t num_expert) { const size_t num_expert,
const int device) {
auto h = getCudaStreamManager(num_expert); auto h = getCudaStreamManager(num_expert, device);
int* gate_host = new int[batch_size]; int* gate_host = new int[batch_size];
scalar_t alpha = 1, beta = 1; scalar_t alpha = 1, beta = 1;
...@@ -137,6 +140,7 @@ std::vector<torch::Tensor> moe_cuda_forward( ...@@ -137,6 +140,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
#ifdef MOE_DEBUG #ifdef MOE_DEBUG
printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat); printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif #endif
int device = device_of(input).value().index();
auto output = input.new_zeros({batch_size, out_feat}); auto output = input.new_zeros({batch_size, out_feat});
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
...@@ -149,7 +153,8 @@ std::vector<torch::Tensor> moe_cuda_forward( ...@@ -149,7 +153,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
in_feat, in_feat,
out_feat, out_feat,
num_expert, num_expert,
CUBLAS_OP_T CUBLAS_OP_T,
device
); );
})); }));
...@@ -166,10 +171,11 @@ std::vector<torch::Tensor> moe_cuda_backward( ...@@ -166,10 +171,11 @@ std::vector<torch::Tensor> moe_cuda_backward(
const auto num_expert = weight.size(0); const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1); const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2); const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG #ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat); printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif #endif
int device = device_of(input).value().index();
auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat
auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
...@@ -184,7 +190,8 @@ std::vector<torch::Tensor> moe_cuda_backward( ...@@ -184,7 +190,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
out_feat, out_feat,
in_feat, in_feat,
num_expert, num_expert,
CUBLAS_OP_N CUBLAS_OP_N,
device
); );
})); }));
...@@ -197,7 +204,8 @@ std::vector<torch::Tensor> moe_cuda_backward( ...@@ -197,7 +204,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
batch_size, batch_size,
in_feat, in_feat,
out_feat, out_feat,
num_expert num_expert,
device
); );
})); }));
......
...@@ -10,10 +10,10 @@ def perf(): ...@@ -10,10 +10,10 @@ def perf():
out_feat = int(sys.argv[3]) out_feat = int(sys.argv[3])
num_expert = int(sys.argv[4]) num_expert = int(sys.argv[4])
inp = torch.rand(batch_size, in_feat).cuda() inp = torch.rand(batch_size, in_feat).cuda("cuda:1")
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda() gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda("cuda:1")
moe = MOELayer(num_expert, in_feat, out_feat).cuda() moe = MOELayer(num_expert, in_feat, out_feat).cuda("cuda:1")
o = moe(inp, gate) o = moe(inp, gate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment