Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
046455a8
Commit
046455a8
authored
Jan 03, 2021
by
Jiezhong Qiu
Browse files
set device according to input
parent
2338a26e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
15 deletions
+26
-15
pytorch/cuda/cuda_stream_manager.cpp
pytorch/cuda/cuda_stream_manager.cpp
+3
-2
pytorch/cuda/cuda_stream_manager.h
pytorch/cuda/cuda_stream_manager.h
+4
-2
pytorch/cuda/moe_cuda_kernel.cu
pytorch/cuda/moe_cuda_kernel.cu
+16
-8
pytorch/cuda/moe_test.py
pytorch/cuda/moe_test.py
+3
-3
No files found.
pytorch/cuda/cuda_stream_manager.cpp
View file @
046455a8
...
...
@@ -4,10 +4,11 @@
CudaStreamManager
*
smgr
=
NULL
;
CudaStreamManager
*
getCudaStreamManager
(
const
size_t
num_expert
)
{
CudaStreamManager
*
getCudaStreamManager
(
const
size_t
num_expert
,
const
int
device
)
{
if
(
!
smgr
)
{
smgr
=
new
CudaStreamManager
(
num_expert
);
smgr
=
new
CudaStreamManager
(
num_expert
,
device
);
}
assert
(
smgr
->
num_expert
==
num_expert
);
assert
(
smgr
->
device
==
device
);
return
smgr
;
}
pytorch/cuda/cuda_stream_manager.h
View file @
046455a8
...
...
@@ -8,7 +8,8 @@
class
CudaStreamManager
{
public:
CudaStreamManager
(
const
size_t
num_expert_
)
:
num_expert
(
num_expert_
)
{
CudaStreamManager
(
const
size_t
num_expert_
,
const
int
device_
)
:
num_expert
(
num_expert_
),
device
(
device_
)
{
checkCudaErrors
(
cudaSetDevice
(
device
));
streams
=
new
cudaStream_t
[
num_expert
];
checkCudaErrors
(
cublasCreate
(
&
handle
));
for
(
size_t
i
=
0
;
i
<
num_expert
;
++
i
)
{
...
...
@@ -22,10 +23,11 @@ public:
checkCudaErrors
(
cublasDestroy
(
handle
));
}
const
size_t
num_expert
;
const
int
device
;
cublasHandle_t
handle
;
cudaStream_t
*
streams
;
};
CudaStreamManager
*
getCudaStreamManager
(
const
size_t
num_expert
);
CudaStreamManager
*
getCudaStreamManager
(
const
size_t
num_expert
,
const
int
device
);
#endif // CUDA_STREAM_MANAGER
pytorch/cuda/moe_cuda_kernel.cu
View file @
046455a8
...
...
@@ -9,6 +9,7 @@
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <helper_cuda.h>
#include <c10/cuda/CUDAGuard.h>
// #include "timer.hh"
...
...
@@ -38,9 +39,10 @@ void moe_cuda_forward_impl(
const
size_t
in_feat
,
const
size_t
out_feat
,
const
size_t
num_expert
,
cublasOperation_t
transb
)
{
cublasOperation_t
transb
,
const
int
device
)
{
auto
*
h
=
getCudaStreamManager
(
num_expert
);
auto
*
h
=
getCudaStreamManager
(
num_expert
,
device
);
checkCudaErrors
(
cublasSetStream
(
h
->
handle
,
*
(
h
->
streams
)));
...
...
@@ -95,9 +97,10 @@ void moe_cuda_grad_weight(
const
size_t
batch_size
,
const
size_t
in_feat
,
const
size_t
out_feat
,
const
size_t
num_expert
)
{
const
size_t
num_expert
,
const
int
device
)
{
auto
h
=
getCudaStreamManager
(
num_expert
);
auto
h
=
getCudaStreamManager
(
num_expert
,
device
);
int
*
gate_host
=
new
int
[
batch_size
];
scalar_t
alpha
=
1
,
beta
=
1
;
...
...
@@ -137,6 +140,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
#ifdef MOE_DEBUG
printf
(
"[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld
\n
"
,
batch_size
,
num_expert
,
in_feat
,
out_feat
);
#endif
int
device
=
device_of
(
input
).
value
().
index
();
auto
output
=
input
.
new_zeros
({
batch_size
,
out_feat
});
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"moe_forward_cuda"
,
([
&
]
{
...
...
@@ -149,7 +153,8 @@ std::vector<torch::Tensor> moe_cuda_forward(
in_feat
,
out_feat
,
num_expert
,
CUBLAS_OP_T
CUBLAS_OP_T
,
device
);
}));
...
...
@@ -166,10 +171,11 @@ std::vector<torch::Tensor> moe_cuda_backward(
const
auto
num_expert
=
weight
.
size
(
0
);
const
auto
out_feat
=
weight
.
size
(
1
);
const
auto
in_feat
=
weight
.
size
(
2
);
#ifdef MOE_DEBUG
printf
(
"[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld
\n
"
,
batch_size
,
num_expert
,
in_feat
,
out_feat
);
#endif
int
device
=
device_of
(
input
).
value
().
index
();
auto
grad_input
=
grad_output
.
new_zeros
({
batch_size
,
in_feat
});
// batch_size x in_feat
auto
grad_weight
=
grad_output
.
new_zeros
({
num_expert
,
out_feat
,
in_feat
});
// num_expert x out_feat x in_feat
...
...
@@ -184,7 +190,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
out_feat
,
in_feat
,
num_expert
,
CUBLAS_OP_N
CUBLAS_OP_N
,
device
);
}));
...
...
@@ -197,7 +204,8 @@ std::vector<torch::Tensor> moe_cuda_backward(
batch_size
,
in_feat
,
out_feat
,
num_expert
num_expert
,
device
);
}));
...
...
pytorch/cuda/moe_test.py
View file @
046455a8
...
...
@@ -10,10 +10,10 @@ def perf():
out_feat
=
int
(
sys
.
argv
[
3
])
num_expert
=
int
(
sys
.
argv
[
4
])
inp
=
torch
.
rand
(
batch_size
,
in_feat
).
cuda
()
gate
=
torch
.
randint
(
low
=
0
,
high
=
num_expert
,
size
=
(
batch_size
,
),
requires_grad
=
False
).
int
().
cuda
()
inp
=
torch
.
rand
(
batch_size
,
in_feat
).
cuda
(
"cuda:1"
)
gate
=
torch
.
randint
(
low
=
0
,
high
=
num_expert
,
size
=
(
batch_size
,
),
requires_grad
=
False
).
int
().
cuda
(
"cuda:1"
)
moe
=
MOELayer
(
num_expert
,
in_feat
,
out_feat
).
cuda
()
moe
=
MOELayer
(
num_expert
,
in_feat
,
out_feat
).
cuda
(
"cuda:1"
)
o
=
moe
(
inp
,
gate
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment