OpenDAS / FastMoE

Commit 94b68a3d, authored Jan 17, 2021 by Rick Ho (parent 7c3e5149)

    overlap allreduce with computation

Showing 4 changed files with 154 additions and 9 deletions:

    pytorch/cuda/moe.cpp               +16   -0
    pytorch/cuda/moe_cuda_kernel.h      +7   -0
    pytorch/cuda/moe_function.py        +4   -8
    pytorch/cuda/moe_fused_kernel.cu  +127   -1
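The commit title above is the only description given, so here is a rough, standalone sketch of the general pattern it refers to: keeping communication and computation on separate CUDA streams so the GPU can run them concurrently. This is not code from the commit; the real implementation (in pytorch/cuda/moe_fused_kernel.cu below) uses NCCL send/recv and a per-expert stream manager, while this sketch only assumes the CUDA runtime and cuBLAS, uses an asynchronous host-to-device copy as a stand-in for the NCCL exchange, and invents every name in it (comm_stream, comp_stream, h_next, and so on).

    // overlap_sketch.cu -- illustration only, not part of this commit.
    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    #define CHECK_CUDA(x) do { cudaError_t e_ = (x); if (e_ != cudaSuccess) { \
        std::printf("CUDA error: %s\n", cudaGetErrorString(e_)); return 1; } } while (0)

    int main() {
        const int n = 1024;
        const size_t bytes = sizeof(float) * n * n;

        // Pinned host memory so the async copy can truly overlap with the GEMM.
        float* h_next = nullptr;
        CHECK_CUDA(cudaMallocHost((void**)&h_next, bytes));
        std::fill(h_next, h_next + n * n, 0.5f);

        float *d_a, *d_b, *d_c, *d_next;
        CHECK_CUDA(cudaMalloc((void**)&d_a, bytes));
        CHECK_CUDA(cudaMalloc((void**)&d_b, bytes));
        CHECK_CUDA(cudaMalloc((void**)&d_c, bytes));
        CHECK_CUDA(cudaMalloc((void**)&d_next, bytes));
        std::vector<float> ones(n * n, 1.0f);
        CHECK_CUDA(cudaMemcpy(d_a, ones.data(), bytes, cudaMemcpyHostToDevice));
        CHECK_CUDA(cudaMemcpy(d_b, ones.data(), bytes, cudaMemcpyHostToDevice));

        // One stream for "communication", one for "computation".
        cudaStream_t comm_stream, comp_stream;
        CHECK_CUDA(cudaStreamCreate(&comm_stream));
        CHECK_CUDA(cudaStreamCreate(&comp_stream));

        cublasHandle_t handle;
        cublasCreate(&handle);
        cublasSetStream(handle, comp_stream);

        // "Computation" for the current chunk of work runs on comp_stream ...
        const float alpha = 1.0f, beta = 0.0f;
        cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
                    &alpha, d_a, n, d_b, n, &beta, d_c, n);

        // ... while the "communication" for the next chunk (here just an async
        // host-to-device copy) proceeds concurrently on comm_stream.
        CHECK_CUDA(cudaMemcpyAsync(d_next, h_next, bytes,
                                   cudaMemcpyHostToDevice, comm_stream));

        CHECK_CUDA(cudaStreamSynchronize(comp_stream));
        CHECK_CUDA(cudaStreamSynchronize(comm_stream));
        std::printf("GEMM and copy completed on separate streams\n");

        cublasDestroy(handle);
        cudaFree(d_a); cudaFree(d_b); cudaFree(d_c); cudaFree(d_next);
        cudaFreeHost(h_next);
        CHECK_CUDA(cudaStreamDestroy(comm_stream));
        CHECK_CUDA(cudaStreamDestroy(comp_stream));
        return 0;
    }

Built with something like nvcc overlap_sketch.cu -lcublas, the copy and the GEMM are issued back-to-back from the host but execute on different streams, which is the same reason the fused kernel below hands each expert its own stream.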
pytorch/cuda/moe.cpp
...
...
@@ -95,6 +95,20 @@ std::vector<torch::Tensor> moe_global_gather(
            batch_size, n_workers
    );
}

std::vector<torch::Tensor> moe_global_fused_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long global_batch_size, long local_batch_size, long n_workers) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_global_fused_forward(input_buf, weight,
            local_expert_count, global_expert_count,
            global_batch_size, local_batch_size, n_workers);
}
#endif

/*
...
...
@@ -116,6 +130,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
    m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
    m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
    m.def("global_fused_forward", &moe_global_fused_forward,
            "MoE global fused forward (CUDA)");
#endif
    m.def("forward", &moe_forward, "MoE forward (CUDA)");
    m.def("backward", &moe_backward, "MoE backward (CUDA)");
...
...
pytorch/cuda/moe_cuda_kernel.h
...
...
@@ -45,6 +45,13 @@ std::vector<torch::Tensor> moe_cuda_expert_exchange(
        torch::Tensor local_expert_count,
        long num_expert, long n_workers);

std::vector<torch::Tensor> moe_cuda_global_fused_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long global_batch_size, long local_batch_size, long n_workers);
#endif

#endif  // MOE_CUDA_KERNEL_H
pytorch/cuda/moe_function.py
...
...
@@ -40,16 +40,12 @@ class MOEGlobal(Function):
         fwd_batch_size = int(fwd_expert_count.sum().item())
         local_input_buf, = moe_cuda.local_scatter(inp, pos)
-        global_input_buf, = moe_cuda.global_scatter(local_input_buf,
-                local_expert_count, global_expert_count,
-                fwd_batch_size, world_size)
-        global_output_buf, = moe_cuda.forward(global_input_buf, weight,
-                fwd_expert_count)
-        local_output_buf, = moe_cuda.global_gather(global_output_buf,
-                local_expert_count, global_expert_count,
-                inp.shape[0], world_size)
+        local_output_buf, global_input_buf = moe_cuda.global_fused_forward(
+                local_input_buf, weight,
+                local_expert_count, global_expert_count,
+                fwd_batch_size, inp.shape[0], world_size)
         output, = moe_cuda.local_gather(local_output_buf, pos)

         variables = (global_input_buf, gate, weight,
...
...
pytorch/cuda/moe_fused_kernel.cu
...
...
@@ -11,12 +11,138 @@
#include <c10/cuda/CUDAGuard.h>

#include "cuda_stream_manager.h"
#include "cublas_wrapper.h"

#ifdef MOE_USE_NCCL
#include <mpi.h>
#include <nccl.h>

// TODO
template <typename scalar_t>
void moe_cuda_global_fused_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        scalar_t* global_input_buf,
        scalar_t* global_output_buf,
        scalar_t* output_buf,
        const int* local_expert_count,
        const int* global_expert_count,
        long in_feat, long out_feat,
        long num_expert, long world_size,
        CudaStreamManager* smgr) {
    int ptr = 0;
    int send_ptr = 0;
    int recv_ptr = 0;
    // expert_ptr is an exclusive prefix sum of local_expert_count: the row
    // offset of each (worker, expert) slot inside the local input buffer.
    int* expert_ptr = new int[num_expert * world_size];
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert * world_size; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
    }

    scalar_t alpha = 1, beta = 0;

    for (int i = 0; i < num_expert; ++i) {
        int expert_count = 0;
        // Exchange input rows for expert i with every worker (global scatter),
        // issued on the per-expert stream so it can overlap with other experts.
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        input_buf + expert_ptr[idx] * in_feat,
                        local_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(i)));
            }
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        global_input_buf + recv_ptr * in_feat,
                        global_expert_count[idx] * in_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(i)));
                recv_ptr += global_expert_count[idx];
                expert_count += global_expert_count[idx];
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());

        // Expert i's GEMM on its own cuBLAS handle / stream.
        checkCudaErrors(cublasXgemm(smgr->handle(i),
                CUBLAS_OP_T, CUBLAS_OP_N,
                out_feat, expert_count, in_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                global_input_buf + ptr * in_feat, in_feat,
                &beta,
                global_output_buf + out_feat * ptr, out_feat));
        ptr += expert_count;

        // Send expert i's outputs back to their source workers (global gather).
        NCCL_SAFE_CALL(ncclGroupStart());
        for (int j = 0; j < world_size; ++j) {
            int idx = i + j * num_expert;
            if (global_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclSend(
                        global_output_buf + send_ptr * out_feat,
                        global_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(i)));
                send_ptr += global_expert_count[idx];
            }
            if (local_expert_count[idx]) {
                NCCL_SAFE_CALL(ncclRecv(
                        output_buf + expert_ptr[idx] * out_feat,
                        local_expert_count[idx] * out_feat * sizeof(scalar_t),
                        ncclChar,
                        j,
                        smgr->ncclcomm,
                        smgr->stream(i)));
            }
        }
        NCCL_SAFE_CALL(ncclGroupEnd());
    }
    delete [] expert_ptr;
    smgr->sync(num_expert);
}

std::vector<torch::Tensor> moe_cuda_global_fused_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long global_batch_size, long local_batch_size, long n_workers) {
    const auto num_expert = local_expert_count.size(0) / n_workers;
    const auto out_feat = weight.size(1);
    const auto in_feat = weight.size(2);

    auto smgr = getCudaStreamManager(input_buf.device().index());

    auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
    auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
    auto output_buf = input_buf.new_empty({local_batch_size, out_feat});

    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
            "moe_cuda_global_fused_forward", ([&] {
        moe_cuda_global_fused_forward_impl(
            input_buf.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            global_input_buf.data_ptr<scalar_t>(),
            global_output_buf.data_ptr<scalar_t>(),
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<int>(),
            global_expert_count.data_ptr<int>(),
            in_feat, out_feat, num_expert, n_workers,
            smgr);
    }));
    return {output_buf, global_input_buf};
}
#endif
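A note on the implementation above: each expert i is given its own stream and cuBLAS handle by the CudaStreamManager (smgr->stream(i), smgr->handle(i)), so the NCCL exchange for one expert can proceed while the GEMM of another is still in flight, and smgr->sync(num_expert) joins them all at the end; that appears to be the overlap the commit title refers to. The send offsets come from expert_ptr, an exclusive prefix sum of local_expert_count over the flattened index idx = i + j * num_expert. Below is a tiny standalone sketch of just that offset bookkeeping; the counts are made-up example values, not anything taken from the repository.

    // offsets_sketch.cpp -- illustration only, not repository code.
    #include <cstdio>
    #include <vector>

    int main() {
        const int num_expert = 2, world_size = 2, in_feat = 4;

        // Hypothetical row counts per slot, flattened as idx = i + j * num_expert
        // (i = expert index, j = destination worker), mirroring the loops above.
        std::vector<int> local_expert_count = {3, 1, 2, 4};

        // Exclusive prefix sum: expert_ptr[idx] is the first row in the locally
        // scattered input buffer that belongs to slot idx.
        std::vector<int> expert_ptr(num_expert * world_size, 0);
        for (int k = 1; k < num_expert * world_size; ++k)
            expert_ptr[k] = expert_ptr[k - 1] + local_expert_count[k - 1];

        for (int i = 0; i < num_expert; ++i)
            for (int j = 0; j < world_size; ++j) {
                int idx = i + j * num_expert;
                std::printf("expert %d on worker %d: %d rows, send offset %d floats\n",
                            i, j, local_expert_count[idx], expert_ptr[idx] * in_feat);
            }
        return 0;
    }

With these example counts the printed offsets are 0, 16, 12 and 24 floats: the buffer is laid out worker-major, and the prefix sum is what lets each ncclSend start at the right row for its destination.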