OpenDAS / FastMoE, commit 79ccb7b6
Authored Feb 04, 2021 by Rick Ho
Parent: 15f98a10

    fix pytorch header compilation bug

Showing 6 changed files with 30 additions and 95 deletions (+30 / -95):

    cuda/cuda_stream_manager.cpp    +0   -32
    cuda/cuda_stream_manager.h      +0   -1
    cuda/moe.cpp                    +30  -11
    cuda/moe_comm_kernel.cu         +0   -5
    cuda/moe_compute_kernel.cu      +0   -43
    cuda/moe_cuda_kernel.h          +0   -3
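The pattern in this diff: everything that touches <c10d/ProcessGroupNCCL.hpp> or other PyTorch C++ types is removed from the .cu kernels and from headers those kernels include, and concentrated in cuda/moe.cpp instead. The PyTorch/c10d headers generally do not compile under nvcc, so they must only ever be seen by the host C++ compiler. For orientation, a hypothetical setup.py for a mixed extension like this one (file names mirror the commit; the module name and layout are illustrative, not taken from the repository):

    # Hypothetical build script for a mixed C++/CUDA PyTorch extension.
    # CUDAExtension routes .cu sources through nvcc and .cpp sources
    # through the host C++ compiler, which is why code that pulls in
    # heavy PyTorch headers (c10d, pybind11 bindings) should live only
    # in the .cpp translation units.
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name='fmoe_cuda',  # illustrative name
        ext_modules=[
            CUDAExtension(
                name='fmoe_cuda',
                sources=[
                    'cuda/moe.cpp',                  # bindings + NCCL glue (host compiler)
                    'cuda/cuda_stream_manager.cpp',  # stream pool (host compiler)
                    'cuda/moe_comm_kernel.cu',       # kernels (nvcc)
                    'cuda/moe_compute_kernel.cu',    # kernels (nvcc)
                ],
            ),
        ],
        cmdclass={'build_ext': BuildExtension},
    )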
cuda/cuda_stream_manager.cpp
@@ -4,43 +4,11 @@
 #include <thread>
 #include <iostream>
 
-#ifdef MOE_USE_NCCL
-#include <c10d/ProcessGroupNCCL.hpp>
-#endif  // MOE_USE_NCCL
-
 #include "cuda_stream_manager.h"
 #include <helper_cuda.h>
 
 #define SMGR_N_STREAMS 16
 
-#ifdef MOE_USE_NCCL
-class HackNCCLGroup: public c10d::ProcessGroupNCCL {
-public:
-    ncclComm_t getcomm(at::Device dev) {
-        auto key = std::to_string(dev.index());
-        auto v = getNCCLComm(key, {dev}, c10d::OpType::ALLTOALL);
-        if (v.size() == 0) {
-            std::cerr << "PyTorch has nothing\n";
-            return 0;
-        }
-        return v[0]->getNcclComm();
-    }
-};
-
-void CudaStreamManager::ensure(void* torchp, at::Device dev) {
-    if (this->ncclgood) {
-        return;
-    }
-    HackNCCLGroup* h = (HackNCCLGroup*)torchp;
-    this->ncclcomm = h->getcomm(dev);
-    if (this->ncclcomm != 0) {
-        this->ncclgood = 1;
-    } else {
-        std::cerr << "Nccl initialization failed\n";
-    }
-}
-#endif  // MOE_USE_NCCL
-
 cudaStream_t CudaStreamManager::stream(size_t idx) {
     return this->streams[idx % SMGR_N_STREAMS];
 }
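The surviving stream() method hands out streams from a fixed pool of SMGR_N_STREAMS round-robin, so work submitted under different indices (for example, different experts) can overlap on the GPU. A rough Python sketch of the same pattern, assuming only that torch with CUDA is available (this mirrors the C++ logic, it is not part of the commit):

    import torch

    N_STREAMS = 16  # mirrors SMGR_N_STREAMS above

    # Fixed pool of CUDA streams, indexed round-robin like
    # CudaStreamManager::stream(idx).
    _streams = [torch.cuda.Stream() for _ in range(N_STREAMS)]

    def stream(idx: int) -> torch.cuda.Stream:
        return _streams[idx % N_STREAMS]

    # Usage sketch: issue each expert's work on its own stream so
    # kernels from different experts may overlap.
    x = torch.randn(64, 64, device='cuda')
    for expert_idx in range(4):
        with torch.cuda.stream(stream(expert_idx)):
            y = x @ x  # stand-in for an expert's GEMM
    torch.cuda.synchronize()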
cuda/cuda_stream_manager.h
@@ -25,7 +25,6 @@ public:
 #ifdef MOE_USE_NCCL
     char ncclgood;
     ncclComm_t ncclcomm;
-    void ensure(void*, class at::Device);
 #endif
 public:
cuda/moe.cpp
@@ -109,18 +109,37 @@ std::vector<torch::Tensor> moe_global_fused_forward(
             global_batch_size, local_batch_size, n_workers);
 }
-#endif 
 
-/*
-int main() {
-    int device=2;
-    torch::Tensor input = torch::randn({2048, 512}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
-    torch::Tensor gate = torch::zeros({2048, 2}, torch::dtype(torch::kInt64));
-    torch::Tensor weight = torch::randn({2, 512, 2048}, torch::dtype(torch::kFloat32).device(torch::kCUDA, device));
-    checkCudaErrors(cudaSetDevice(device));
-    moe_cuda_forward(input, gate, weight);
+#include <c10d/ProcessGroupNCCL.hpp>
+#include "cuda_stream_manager.h"
+
+class HackNCCLGroup: public c10d::ProcessGroupNCCL {
+public:
+    ncclComm_t getcomm(at::Device dev) {
+        auto key = std::to_string(dev.index());
+        auto v = getNCCLComm(key, {dev}, c10d::OpType::ALLTOALL);
+        if (v.size() == 0) {
+            std::cerr << "PyTorch has nothing\n";
+            return 0;
+        }
+        return v[0]->getNcclComm();
+    }
+};
+
+void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
+    auto smgr = getCudaStreamManager(t.device().index());
+    if (smgr->ncclgood) {
+        return;
+    }
+    HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
+    smgr->ncclcomm = h->getcomm(t.device());
+    if (smgr->ncclcomm != 0) {
+        smgr->ncclgood = 1;
+    } else {
+        std::cerr << "Nccl initialization failed\n";
+    }
 }
-*/
+#endif  // MOE_USE_NCCL
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("expert_count", &moe_expert_count, "MoE expert count (CUDA)");
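moe_ensure_nccl exists so the extension can reuse the NCCL communicator that torch.distributed has already created instead of initializing its own; HackNCCLGroup subclasses c10d::ProcessGroupNCCL purely to reach the protected getNCCLComm. A hedged sketch of the Python side handing the process group over, assuming the module also binds this function under a name like ensure_nccl (that binding is not shown in this hunk, and _get_default_group is a private PyTorch API):

    import torch
    import torch.distributed as dist
    import fmoe_cuda  # the compiled extension; name is illustrative

    # Let torch.distributed create the NCCL process group as usual.
    dist.init_process_group(backend='nccl')
    group = dist.distributed_c10d._get_default_group()  # a ProcessGroupNCCL

    # Any CUDA tensor pins down which device's communicator to fetch.
    t = torch.empty(1, device='cuda')
    fmoe_cuda.ensure_nccl(group, t)  # binding name is an assumption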
cuda/moe_comm_kernel.cu
@@ -195,9 +195,4 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
     return {local_output_buf, };
 }
 
-void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
-    auto smgr = getCudaStreamManager(t.device().index());
-    smgr->ensure((void*)&p, t.device());
-}
-
 #endif 
cuda/moe_compute_kernel.cu
@@ -350,46 +350,3 @@ std::vector<torch::Tensor> moe_cuda_backward(
     return {grad_input_buf, grad_weight};
 }
 
-/*
-int main() {
-    typedef float data_t;
-    size_t batch_size = 4096;
-    size_t top_k = 2;
-    size_t num_expert = 128;
-    size_t in_feat = 1024;
-    size_t out_feat = 4096;
-    data_t *input, *weight;
-    data_t *output;
-    size_t *gate;
-
-    checkCudaErrors(cudaMalloc(&input, batch_size * in_feat * sizeof(data_t)));
-    checkCudaErrors(cudaMalloc(&weight, num_expert * in_feat * out_feat * sizeof(data_t)));
-    checkCudaErrors(cudaMalloc(&output, batch_size * top_k * out_feat * sizeof(data_t)));
-    checkCudaErrors(cudaMalloc(&gate, batch_size * top_k * sizeof(size_t)));
-
-    size_t nt = 16;
-    double tsum = 0, tmax = 0;
-
-    size_t *gate_host = new size_t[batch_size * top_k];
-    for (size_t i=0; i<batch_size * top_k; ++i) {
-        gate_host[i] = rand() % num_expert;
-    }
-    checkCudaErrors(cudaMemcpy(gate, gate_host, batch_size * top_k * sizeof(size_t), cudaMemcpyHostToDevice));
-
-    moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
-
-    for (size_t i=0; i<nt; ++i) {
-        timestamp(start);
-        moe_first_linear_cuda_forward<data_t>(input, gate, weight, output, batch_size, top_k, in_feat, out_feat);
-        timestamp(end);
-        auto t = getDuration(start, end);
-        tsum += t;
-        if (t > tmax) tmax = t;
-    }
-    printf("Mean %.3lf us, max %.3lf us\n", tsum / nt * 1e6, tmax * 1e6);
-
-    double tflops = (double)batch_size * top_k * in_feat * out_feat * nt * 2e-12 / tsum;
-    printf("%.3lf TFLOPs\n", tflops);
-}
-*/
cuda/moe_cuda_kernel.h
@@ -41,9 +41,6 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
         torch::Tensor global_expert_count,
         long batch_size, long n_workers);
 
-#include <c10d/ProcessGroupNCCL.hpp>
-void moe_ensure_nccl(c10d::ProcessGroupNCCL&, torch::Tensor t);
-
 std::vector<torch::Tensor> moe_cuda_expert_exchange(
         torch::Tensor local_expert_count,
         long num_expert, long n_workers);