OpenDAS / FastMoE · Commits

Commit 6cdb3cda
Authored Mar 24, 2021 by TiagoMAntunes
Fixed indentation (4 spaces now)
Parent 303d0e93

Showing 1 changed file with 98 additions and 98 deletions.
cuda/moe.cpp (view file @ 6cdb3cda)

...
@@ -12,119 +12,119 @@
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe_expert_count(
    torch::Tensor gate,
    size_t num_expert) {
    CHECK_INPUT(gate);
    return moe_cuda_expert_count(gate, num_expert);
}
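As a reading aid: `moe_cuda_expert_count` lives in the CUDA sources, not in this diff. Its counting half can be sketched with stock libtorch calls; `expert_count_ref` is a hypothetical name and the histogram semantics are an assumption, not the kernel itself.

#include <torch/torch.h>

// Assumed reference semantics: count how many rows route to each expert.
torch::Tensor expert_count_ref(torch::Tensor gate, int64_t num_expert) {
    // gate: [batch_size] integer tensor of expert indices in [0, num_expert)
    return torch::bincount(gate, /*weights=*/{}, /*minlength=*/num_expert);
}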
std::vector<torch::Tensor> moe_local_scatter(
    torch::Tensor input,
    torch::Tensor pos) {
    CHECK_INPUT(input);
    return moe_cuda_local_scatter(input, pos);
}
std::vector<torch::Tensor> moe_local_gather(
    torch::Tensor output_buf,
    torch::Tensor pos) {
    CHECK_INPUT(output_buf);
    return moe_cuda_local_gather(output_buf, pos);
}
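The scatter/gather pair above is easiest to read as inverse row permutations driven by `pos`. A minimal libtorch sketch, assuming `pos[i]` names the source row of output slot `i` (the actual kernels may use the inverse convention; the `*_ref` names are hypothetical):

#include <torch/torch.h>

torch::Tensor local_scatter_ref(torch::Tensor input, torch::Tensor pos) {
    // Row i of the result is row pos[i] of `input`.
    return input.index_select(0, pos);
}

torch::Tensor local_gather_ref(torch::Tensor output_buf, torch::Tensor pos) {
    // Inverse permutation: row pos[i] of the result is row i of `output_buf`.
    auto out = torch::empty_like(output_buf);
    out.index_copy_(0, pos, output_buf);
    return out;
}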
void merge_bias(torch::Tensor &input_buf, torch::Tensor &weight,
        at::optional<torch::Tensor> bias_o) {
    torch::Tensor bias = bias_o.value();
    weight = at::cat({weight, bias.unsqueeze(2)}, 2); // [W b]
    auto options = torch::TensorOptions()
        .device(input_buf.device())
        .dtype(input_buf.dtype());
    auto ones = at::ones(input_buf.size(0), options).unsqueeze(1);
    input_buf = at::cat({input_buf, ones}, 1); // [X 1]
}
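merge_bias implements the standard affine-to-linear fold that the `[W b]` and `[X 1]` comments allude to: Wx + b = [W b][x; 1], so the kernel only ever performs a plain matrix multiply. A standalone numerical check of that identity (illustration only, not part of the extension):

#include <torch/torch.h>
#include <iostream>

int main() {
    auto W = torch::randn({4, 3});
    auto b = torch::randn({4});
    auto x = torch::randn({3});
    auto lhs = torch::mv(W, x) + b;                    // Wx + b
    auto W_aug = torch::cat({W, b.unsqueeze(1)}, 1);   // [W b]
    auto x_aug = torch::cat({x, torch::ones({1})}, 0); // [x; 1]
    std::cout << lhs.allclose(torch::mv(W_aug, x_aug)) << "\n"; // prints 1
}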
std::vector<torch::Tensor> moe_forward(
    torch::Tensor input_buf,            // [batch_size x in_feat]
    torch::Tensor expert_count,         // [batch_size]
    torch::Tensor weight,               // [num_expert x out_feat x in_feat]
    at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
    ) {
    // Wx+b = [W b] [x]
    //              [1]
    if (bias_o.has_value())
        merge_bias(input_buf, weight, bias_o);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_forward(input_buf, expert_count, weight);
}
std::vector<torch::Tensor> moe_backward(
    torch::Tensor grad_output_buf,      // [batch_size x out_feat]
    torch::Tensor input_buf,            // [batch_size x in_feat]
    torch::Tensor expert_count,
    torch::Tensor weight,               // [num_expert x out_feat x in_feat]
    at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
    ) {
    // Wx+b = [W b] [x]
    //              [1]
    if (bias_o.has_value())
        merge_bias(input_buf, weight, bias_o);
    CHECK_INPUT(grad_output_buf);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_backward(grad_output_buf, input_buf, expert_count,
            weight, bias_o.has_value());
}
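Note that moe_cuda_backward receives only a flag, `bias_o.has_value()`, rather than the bias itself: once merged, the bias gradient lives in the last input-dimension slice of the augmented weight gradient. A hypothetical caller-side split (an assumption about the wrapper code, which is not part of this file):

#include <torch/torch.h>
#include <utility>

// Recover (grad_W, grad_b) from a gradient w.r.t. the augmented weight,
// shaped [num_expert x out_feat x (in_feat + 1)].
std::pair<torch::Tensor, torch::Tensor>
split_merged_grad(torch::Tensor grad_weight_aug, int64_t in_feat) {
    auto grad_w = grad_weight_aug.narrow(2, 0, in_feat); // [.. x in_feat]
    auto grad_b = grad_weight_aug.select(2, in_feat);    // [.. x out_feat]
    return {grad_w, grad_b};
}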
#ifdef MOE_USE_NCCL
std::vector<torch::Tensor> moe_expert_exchange(
    torch::Tensor local_expert_count,
    size_t num_expert,
    size_t n_workers) {
    return moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
}
std::vector<torch::Tensor> moe_global_scatter(
    torch::Tensor input_buf,
    torch::Tensor local_expert_count,
    torch::Tensor global_expert_count,
    size_t batch_size,
    size_t n_workers) {
    CHECK_INPUT(input_buf);
    return moe_cuda_global_scatter(input_buf, local_expert_count,
            global_expert_count, batch_size, n_workers);
}
std::vector<torch::Tensor> moe_global_gather(
    torch::Tensor output_buf,
    torch::Tensor local_expert_count,
    torch::Tensor global_expert_count,
    size_t batch_size,
    size_t n_workers) {
    CHECK_INPUT(output_buf);
    return moe_cuda_global_gather(output_buf, local_expert_count,
            global_expert_count, batch_size, n_workers);
}
std::vector<torch::Tensor> moe_global_fused_forward(
    torch::Tensor input_buf,
    torch::Tensor weight,
    torch::Tensor local_expert_count,
    torch::Tensor global_expert_count,
    long global_batch_size,
    long local_batch_size,
    long n_workers) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_global_fused_forward(input_buf, weight,
            local_expert_count, global_expert_count,
            global_batch_size, local_batch_size, n_workers);
}
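The fused entry point can be read as the composition of the unfused wrappers defined above, with the fused kernel additionally overlapping communication and computation. A rough sketch under the assumption that each wrapper returns its primary buffer in slot 0 (the exact buffer layout is not visible in this diff):

// Hypothetical unfused equivalent of moe_global_fused_forward.
std::vector<torch::Tensor> global_fused_forward_ref(
        torch::Tensor input_buf, torch::Tensor weight,
        torch::Tensor local_expert_count, torch::Tensor global_expert_count,
        long global_batch_size, long local_batch_size, long n_workers) {
    // 1. Ship each row to the worker hosting its expert.
    auto scattered = moe_global_scatter(input_buf, local_expert_count,
            global_expert_count, global_batch_size, n_workers)[0];
    // 2. Run the local experts on the received rows (bias handled elsewhere).
    auto out = moe_forward(scattered, global_expert_count, weight,
            at::nullopt)[0];
    // 3. Ship the results back to each row's home worker.
    return moe_global_gather(out, local_expert_count, global_expert_count,
            local_batch_size, n_workers);
}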
#include <c10d/ProcessGroupNCCL.hpp>
...
@@ -132,47 +132,47 @@ std::vector<torch::Tensor> moe_global_fused_forward(
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    ncclComm_t getcomm(at::Device dev) {
        auto key = std::to_string(dev.index());
#ifdef ENABLE_NCCL_P2P_SUPPORT
        ncclUniqueId ncclID;
        int rank = getRank();
        if (rank == 0) {
            ncclGetUniqueId(&ncclID);
        }
        broadcastUniqueNCCLID(&ncclID,
                c10d::OpType::SEND,
                "fastmoe_nccl_comm",
                rank);
        ncclComm_t comm;
        ncclCommInitRank(&comm, getSize(), ncclID, rank);
        return comm;
#else
        auto v = getNCCLComm(key, {dev});
        if (v.size() == 0) {
            std::cerr << "PyTorch has nothing\n";
            return 0;
        }
        int count;
        ncclCommCount(v[0]->getNcclComm(), &count);
        std::cerr << "PyTorch has " << v.size() << " comms, comm 0 size "
                << count << "\n";
        return v[0]->getNcclComm();
#endif
    }
};
void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
    auto smgr = getCudaStreamManager(t.device().index());
    if (smgr->ncclgood) {
        return;
    }
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
    smgr->ncclcomm = h->getcomm(t.device());
    if (smgr->ncclcomm != 0) {
        smgr->ncclgood = 1;
    } else {
        std::cerr << "NCCL initialization failed\n";
    }
}
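The `(HackNCCLGroup*)(void*)&p` cast in moe_ensure_nccl is the whole point of the subclass: it re-exposes ProcessGroupNCCL's protected internals (getNCCLComm, broadcastUniqueNCCLID) without patching PyTorch. The minimal shape of the trick, with hypothetical Base/Hack names; it is only tolerable because the subclass adds no data members:

#include <iostream>

class Base {
protected:
    int secret() { return 42; }
};

class Hack : public Base {      // adds no state, only widens access
public:
    int steal() { return secret(); }
};

int main() {
    Base b;
    std::cout << ((Hack*)(void*)&b)->steal() << "\n"; // same cast pattern
}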
#endif // MOE_USE_NCCL
...
@@ -186,7 +186,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
    m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
    m.def("global_fused_forward", &moe_global_fused_forward, "MoE global fused forward (CUDA)");
    m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
#endif
    m.def("forward", &moe_forward, "MoE forward (CUDA)");
...