OpenDAS / FastMoE · Commits

Unverified commit 3c42c892, authored May 20, 2021 by Rick Ho, committed by GitHub on May 20, 2021

    Merge pull request #21 from bias_improvement

    Bias improvement #15

Parents: 26824495, 41cfe06c
Showing 7 changed files with 443 additions and 365 deletions (+443 / -365):

    cuda/moe.cpp                +104  -97
    cuda/moe_compute_kernel.cu  +298  -213
    cuda/moe_cuda_kernel.h      +4    -2
    fmoe/functions.py           +11   -7
    fmoe/layers.py              +1    -31
    tests/test_ddp.py           +5    -2
    tests/test_numerical.py     +20   -13
cuda/moe.cpp

@@ -12,101 +12,108 @@
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

 std::vector<torch::Tensor> moe_expert_count(
         torch::Tensor gate,
         size_t num_expert) {
     CHECK_INPUT(gate);
     return moe_cuda_expert_count(gate, num_expert);
 }

 std::vector<torch::Tensor> moe_local_scatter(
         torch::Tensor input,
         torch::Tensor pos) {
     CHECK_INPUT(input);
     return moe_cuda_local_scatter(input, pos);
 }

 std::vector<torch::Tensor> moe_local_gather(
         torch::Tensor output_buf,
         torch::Tensor pos) {
     CHECK_INPUT(output_buf);
     return moe_cuda_local_gather(output_buf, pos);
 }

 std::vector<torch::Tensor> moe_forward(
         torch::Tensor input_buf,    // [batch_size x in_feat]
-        torch::Tensor weight,       // [num_expert x out_feat x in_feat]
-        torch::Tensor expert_count  // [num_expert]
+        torch::Tensor expert_count, // [batch_size]
+        torch::Tensor weight,       // [num_expert x out_feat x in_feat]
+        at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
         ) {
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    /*
-        The bias term should have been merged into weight. Note the following fact that
-        Wx+b = [W b] [x]
-                     [1]
-    */
-    return moe_cuda_forward(input_buf, weight, expert_count);
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+    return moe_cuda_forward(input_buf, expert_count, weight, bias_o);
 }

 std::vector<torch::Tensor> moe_backward(
         torch::Tensor grad_output_buf,  // [batch_size x out_feat]
-        torch::Tensor input_buf,        // [batch_size x out_feat]
-        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
-        torch::Tensor expert_count      // [num_expert]
+        torch::Tensor input_buf,        // [batch_size x in_feat]
+        torch::Tensor expert_count,
+        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
+        at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
         ) {
     CHECK_INPUT(grad_output_buf);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    /*
-        The bias term should have been merged into weight. Note the following fact that
-        Wx+b = [W b] [x]
-                     [1]
-    */
-    return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o);
 }

 #ifdef MOE_USE_NCCL

 std::vector<torch::Tensor> moe_expert_exchange(
         torch::Tensor local_expert_count,
         size_t num_expert, size_t n_workers) {
     return moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
 }

 std::vector<torch::Tensor> moe_global_scatter(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
         size_t batch_size, size_t n_workers) {
     CHECK_INPUT(input_buf);
     return moe_cuda_global_scatter(input_buf,
             local_expert_count, global_expert_count,
             batch_size, n_workers);
 }

 std::vector<torch::Tensor> moe_global_gather(
         torch::Tensor output_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
         size_t batch_size, size_t n_workers) {
     CHECK_INPUT(output_buf);
     return moe_cuda_global_gather(output_buf,
             local_expert_count, global_expert_count,
             batch_size, n_workers);
 }

 std::vector<torch::Tensor> moe_global_fused_forward(
         torch::Tensor input_buf,
         torch::Tensor weight,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
         long global_batch_size, long local_batch_size, long n_workers) {
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
     return moe_cuda_global_fused_forward(
             input_buf, weight, local_expert_count, global_expert_count,
             global_batch_size, local_batch_size, n_workers);
 }

 #include <c10d/ProcessGroupNCCL.hpp>

@@ -114,47 +121,47 @@ std::vector<torch::Tensor> moe_global_fused_forward(
 class HackNCCLGroup: public c10d::ProcessGroupNCCL {
 public:
     ncclComm_t getcomm(at::Device dev) {
         auto key = std::to_string(dev.index());
 #ifdef ENABLE_NCCL_P2P_SUPPORT
         ncclUniqueId ncclID;
         int rank = getRank();
         if (rank == 0) {
             ncclGetUniqueId(&ncclID);
         }
         broadcastUniqueNCCLID(&ncclID,
                 c10d::OpType::SEND,
                 "fastmoe_nccl_comm",
                 rank);
         ncclComm_t comm;
         ncclCommInitRank(&comm, getSize(), ncclID, rank);
         return comm;
 #else
         auto v = getNCCLComm(key, {dev});
         if (v.size() == 0) {
             std::cerr << "PyTorch has nothing\n";
             return 0;
         }
         int count;
         ncclCommCount(v[0]->getNcclComm(), &count);
         std::cerr << "PyTorch has " << v.size() << " comms, comm 0 size "
                 << count << "\n";
         return v[0]->getNcclComm();
 #endif
     }
 };

 void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
     auto smgr = getCudaStreamManager(t.device().index());
     if (smgr->ncclgood) {
         return;
     }
     HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
     smgr->ncclcomm = h->getcomm(t.device());
     if (smgr->ncclcomm != 0) {
         smgr->ncclgood = 1;
     } else {
         std::cerr << "Nccl initialization failed\n";
     }
 }

 #endif  // MOE_USE_NCCL

@@ -167,8 +174,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
     m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
     m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
     m.def("global_fused_forward", &moe_global_fused_forward,
             "MoE global gather (CUDA)");
     m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
 #endif
     m.def("forward", &moe_forward, "MoE forward (CUDA)");
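The comment removed from moe_forward and moe_backward refers to the standard bias-folding identity Wx + b = [W b][x; 1]: the old binding expected callers to merge the bias into an augmented weight, while the new binding accepts the bias as an explicit at::optional tensor. A minimal PyTorch sketch of that identity, for illustration only (not part of the commit):

    import torch

    # Bias folding: W x + b equals the augmented matrix [W | b] applied to [x; 1].
    W = torch.randn(4, 3)
    b = torch.randn(4)
    x = torch.randn(3)

    folded = torch.cat([W, b.unsqueeze(1)], dim=1)   # [W | b], shape (4, 4)
    x1 = torch.cat([x, torch.ones(1)])               # [x; 1], shape (4,)

    assert torch.allclose(W @ x + b, folded @ x1, atol=1e-6)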
cuda/moe_compute_kernel.cu

@@ -19,304 +19,386 @@
 template <typename scalar_t>
 __global__
 void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
         const long* offset, const scalar_t** ptrs) {
     size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
     if (idx < n) {
         ptrs[idx] = base + stride * offset[idx];
     }
 }

 template <typename scalar_t>
 __global__
 void batch_scatter_kernel(size_t wid, const long* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * pos[blockIdx.x];
     oubuf += wid * blockIdx.x;
     for (int i = threadIdx.x; i < wid; i += blockDim.x) {
         oubuf[i] = inbuf[i];
     }
 }

+/*
+    This function is to be called with one block per each column
+*/
+template <typename scalar_t>
+__global__
+void column_reduce(const scalar_t* matrix, scalar_t* result,
+        int m /* lines */, int n /* columns */) {
+    // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
+    extern __shared__ unsigned char my_smem[];
+    scalar_t* sdata = reinterpret_cast<scalar_t*>(my_smem);
+
+    // normal tid
+    int tid = threadIdx.x + threadIdx.y * blockDim.x;
+    // transposed tid for shared memory
+    int new_tid = threadIdx.y + threadIdx.x * blockDim.y;
+
+    // true x value in the matrix
+    int real_x = threadIdx.x + blockDim.x * blockIdx.x;
+    int i = real_x + n * threadIdx.y;
+
+    const int it = n * blockDim.y;
+    int offset = it;
+
+    float accumulator = 0;
+
+    if (threadIdx.y < m && real_x < n) {
+        // store all the values from this column in a warped way
+        accumulator = matrix[i];
+        while (i + offset < n * m) {
+            accumulator += matrix[i + offset];
+            offset += it;
+        }
+    }
+
+    // save column reduction data in a transposed way
+    sdata[new_tid] = accumulator;
+    __syncthreads();
+
+    for (size_t t = 16; t > 0; t >>= 1) {
+        if (tid < 32 * 32 - 16)
+            sdata[tid] += sdata[tid + t];
+        __syncthreads();
+    }
+
+    if (threadIdx.y == 0 && real_x < n)
+        result[real_x] = sdata[new_tid];
+}

 void moe_cuda_expert_count_impl(
         const int* d_gate,
         int* expert_count,
         int* d_pos,
         const size_t num_expert,
         const size_t batch_size) {
     int *gate = new int[batch_size];
     int *expert_ptr = new int[num_expert];
     memset(expert_count, 0, sizeof(int) * num_expert);

     checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
                 cudaMemcpyDeviceToHost));

     for (int i = 0; i < batch_size; ++i) {
         ++expert_count[gate[i]];
     }
     expert_ptr[0] = 0;
     for (int i = 1; i < num_expert; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
     }

     int *pos = new int[batch_size];

     for (int i = 0; i < batch_size; ++i) {
         pos[i] = expert_ptr[gate[i]]++;
     }
     for (int i = num_expert - 1; i > 0; --i) {
         expert_ptr[i] = expert_ptr[i - 1];
     }
     expert_ptr[0] = 0;
     checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
                 cudaMemcpyHostToDevice));
     delete [] gate;
     delete [] expert_ptr;
 }

 template <typename scalar_t>
 void moe_cuda_local_scatter_impl(
         const scalar_t* input,
         const long* d_pos,
         scalar_t* input_buf,
         const long batch_size,
         const long in_feat,
         CudaStreamManager* smgr) {
     batch_scatter_kernel<scalar_t>
         <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
             input_buf);
     smgr->sync(1);
 }

 template <typename scalar_t>
 __global__
 void batch_gather_kernel(size_t wid, const long* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * blockIdx.x;
     oubuf += wid * pos[blockIdx.x];
     for (int i = threadIdx.x; i < wid; i += blockDim.x) {
         oubuf[i] = inbuf[i];
     }
 }

 template <typename scalar_t>
 void moe_cuda_local_gather_impl(
         const scalar_t* output_buf,
         const long* d_pos,
         scalar_t* output,
         const size_t batch_size,
         const size_t out_feat,
         CudaStreamManager* smgr) {
     batch_gather_kernel<scalar_t>
         <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
             output);
     smgr->sync(1);
 }

 template <typename scalar_t>
 void moe_cuda_forward_impl(
         const scalar_t* input_buf,
         const scalar_t* weight,
         const long* expert_count,
         scalar_t* output_buf,
+        const bool has_bias,
         const size_t in_feat,
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
-    scalar_t alpha = 1, beta = 0;
+    scalar_t alpha = 1, beta = has_bias ? 1 : 0;

     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             continue;
         }
         // Use T(B) x T(A) = T(C) to produce row-major C
         checkCudaErrors(cublasXgemm(
                 smgr->handle(i),
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 out_feat, expert_count[i], in_feat,
                 &alpha,
                 weight + i * in_feat * out_feat, in_feat,
                 input_buf + ptr * in_feat, in_feat,
                 &beta,
                 output_buf + out_feat * ptr, out_feat
                 ));

         ptr += expert_count[i];
     }
     smgr->sync(num_expert);
 }

 template <typename scalar_t>
 void moe_cuda_backward_impl(
         const scalar_t* grad_output_buf,
         const scalar_t* input_buf,
         const scalar_t* weight,
         const long* expert_count,
         scalar_t* grad_input_buf,
         scalar_t* grad_weight,
+        scalar_t* grad_bias,
+        const bool has_bias,
         const size_t batch_size,
         const size_t in_feat,
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = 0;

+    // bias
+    dim3 block_threads(32, 32);
+    dim3 grid_threads(out_feat / 32 + (out_feat % 32 ? 1 : 0), 1);
+
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                     sizeof(scalar_t) * in_feat * out_feat);
+            cudaMemset(grad_bias + i * out_feat, 0,
+                    sizeof(scalar_t) * out_feat);
             continue;
         }
         // Use T(B) x T(A) = T(C) to produce row-major C

         // Backward input: g_i = w @ g_o
         checkCudaErrors(cublasXgemm(
                 smgr->handle(i),
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 in_feat, expert_count[i], out_feat,
                 &alpha,
                 weight + i * in_feat * out_feat, in_feat,
                 grad_output_buf + ptr * out_feat, out_feat,
                 &beta,
                 grad_input_buf + in_feat * ptr, in_feat
                 ));

         // Backward weight: g_w = i @ g_o
         checkCudaErrors(cublasXgemm(
                 smgr->handle(i),
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 in_feat, out_feat, expert_count[i],
                 &alpha,
                 input_buf + in_feat * ptr, in_feat,
                 grad_output_buf + ptr * out_feat, out_feat,
                 &beta,
                 grad_weight + i * in_feat * out_feat, in_feat
                 ));

+        if (has_bias) {
+            column_reduce
+                <<<grid_threads, block_threads, sizeof(scalar_t) * 1024,
+                    smgr->stream(0)>>>(
+                grad_output_buf + ptr * out_feat,
+                grad_bias + i * out_feat,
+                expert_count[i],
+                out_feat);
+        }
+
         ptr += expert_count[i];
     }
     smgr->sync(num_expert);
 }

 std::vector<torch::Tensor> moe_cuda_expert_count(
         torch::Tensor gate,
         size_t num_expert) {
     const auto batch_size = gate.size(0);

     auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
     auto expert_count = torch::empty(num_expert, ec_options);

     auto pos_options = torch::TensorOptions()
         .device(gate.device())
         .dtype(torch::kInt32);
     auto pos = torch::empty(batch_size, pos_options);
     moe_cuda_expert_count_impl(
             gate.data_ptr<int>(),
             expert_count.data_ptr<int>(),
             pos.data_ptr<int>(),
             num_expert,
             batch_size);

     return {expert_count, pos};
 }

 std::vector<torch::Tensor> moe_cuda_local_scatter(
         torch::Tensor input,
         torch::Tensor pos) {
     auto smgr = getCudaStreamManager(input.device().index());
     const auto batch_size = pos.size(0);
     const auto in_feat = input.size(1);

     auto opt = torch::TensorOptions()
         .dtype(input.dtype())
         .device(input.device());
     auto input_buf = torch::empty({batch_size, in_feat}, opt);

     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda",
             ([&] {
         moe_cuda_local_scatter_impl<scalar_t>(
             input.data_ptr<scalar_t>(),
             pos.data_ptr<long>(),
             input_buf.data_ptr<scalar_t>(),
             batch_size,
             in_feat,
             smgr);
     }));
     return {input_buf,};
 }

 std::vector<torch::Tensor> moe_cuda_local_gather(
         torch::Tensor output_buf,
         torch::Tensor pos) {
     auto smgr = getCudaStreamManager(output_buf.device().index());
     const auto batch_size = pos.size(0);
     const auto out_feat = output_buf.size(1);

     auto opt = torch::TensorOptions()
         .dtype(output_buf.dtype())
         .device(output_buf.device());
     auto output = torch::empty({batch_size, out_feat}, opt);

     AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda",
             ([&] {
         moe_cuda_local_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
             pos.data_ptr<long>(),
             output.data_ptr<scalar_t>(),
             batch_size,
             out_feat,
             smgr);
     }));
     return {output,};
 }

 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input_buf,
+        torch::Tensor expert_count,
         torch::Tensor weight,
-        torch::Tensor expert_count
+        at::optional<torch::Tensor> bias
         ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);

 #ifdef MOE_DEBUG
     printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
             num_expert, in_feat, out_feat);
 #endif

-    auto out_options = torch::TensorOptions()
-        .device(input_buf.device())
-        .dtype(input_buf.dtype());
-    auto output = torch::empty({batch_size, out_feat}, out_options);
+    torch::Tensor output;
+    if (bias.has_value()) {
+        output = bias.value().repeat_interleave(
+                expert_count.to(bias.value().device()), 0);
+    } else {
+        auto out_options = torch::TensorOptions()
+            .device(input_buf.device())
+            .dtype(input_buf.dtype());
+        output = torch::empty({batch_size, out_feat}, out_options);
+    }

     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
             ([&] {
         moe_cuda_forward_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
             expert_count.data_ptr<long>(),
             output.data_ptr<scalar_t>(),
+            bias.has_value(),
             in_feat,
             out_feat,
             num_expert,
             smgr
         );
     }));

     return {output, };
 }

 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output_buf,  // [batch_size x out_feat]
         torch::Tensor input_buf,        // [batch_size x out_feat]
-        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
-        torch::Tensor expert_count
+        torch::Tensor expert_count,
+        torch::Tensor weight,           // [num_expert x out_feat x in_feat]
+        at::optional<torch::Tensor> bias
         ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);

@@ -324,28 +406,31 @@ std::vector<torch::Tensor> moe_cuda_backward(
 #ifdef MOE_DEBUG
     printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
             "out_feat (d_ffn)=%ld\n",
             batch_size, num_expert, in_feat, out_feat);
 #endif

     auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
     auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
+    auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
         moe_cuda_backward_impl<scalar_t>(
             grad_output_buf.data_ptr<scalar_t>(),
             input_buf.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
             expert_count.data_ptr<long>(),
             grad_input_buf.data_ptr<scalar_t>(),
             grad_weight.data_ptr<scalar_t>(),
+            grad_bias.data_ptr<scalar_t>(),
+            bias.has_value(),
             batch_size,
             in_feat,
             out_feat,
             num_expert,
             smgr
         );
     }));

-    return {grad_input_buf, grad_weight};
+    return {grad_input_buf, grad_weight, grad_bias};
 }
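Taken together, the kernel changes add an optional per-expert bias to the fused expert GEMM: the forward wrapper pre-fills the output with bias.repeat_interleave(expert_count) and runs each expert's GEMM with beta = 1 (so C = A*B + bias), and the backward pass reduces each expert's slice of grad_output column-wise with column_reduce to produce grad_bias. A hedged CPU reference of that contract, for illustration only (these helper names are made up and are not part of FastMoE):

    import torch

    def moe_linear_reference(x, expert_count, weight, bias=None):
        """Reference forward: x is (batch, in_feat) with rows already grouped by
        expert, expert_count is (num_expert,), weight is (num_expert, out_feat,
        in_feat), bias is an optional (num_expert, out_feat) tensor."""
        out = x.new_empty(x.size(0), weight.size(1))
        ptr = 0
        for i, cnt in enumerate(expert_count.tolist()):
            if cnt == 0:
                continue
            seg = x[ptr:ptr + cnt] @ weight[i].t()   # per-expert GEMM
            if bias is not None:
                seg = seg + bias[i]                  # what beta=1 + prefilled output achieves
            out[ptr:ptr + cnt] = seg
            ptr += cnt
        return out

    def grad_bias_reference(grad_out, expert_count):
        """Reference backward for the bias: column-wise sum of each expert's
        rows of grad_output, i.e. what column_reduce computes on the GPU."""
        grads, ptr = [], 0
        for cnt in expert_count.tolist():
            grads.append(grad_out[ptr:ptr + cnt].sum(dim=0))
            ptr += cnt
        return torch.stack(grads)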
cuda/moe_cuda_kernel.h

@@ -19,14 +19,16 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input_buf,
+        torch::Tensor expert_count,
         torch::Tensor weight,
-        torch::Tensor expert_count);
+        at::optional<torch::Tensor> bias);

 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output_buf,
         torch::Tensor input_buf,
+        torch::Tensor expert_count,
         torch::Tensor weight,
-        torch::Tensor expert_count);
+        at::optional<torch::Tensor> bias);

 #ifdef MOE_USE_NCCL
fmoe/functions.py

@@ -110,21 +110,25 @@ class MOELinear(Function):
     """

     @staticmethod
-    def forward(ctx, global_input_buf, weight, fwd_expert_count):
+    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
         (global_output_buf,) = fmoe_cuda.forward(
-            global_input_buf, weight, fwd_expert_count
+            global_input_buf, fwd_expert_count, weight, bias
         )
-        variables = (global_input_buf, weight, fwd_expert_count)
+        variables = (global_input_buf, fwd_expert_count, weight, bias)
         ctx.save_for_backward(*variables)
         return global_output_buf

     @staticmethod
     def backward(ctx, grad_out):
-        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = fmoe_cuda.backward(
-            grad_out, input_buf, weight, fwd_expert_count
+        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
+        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.backward(
+            grad_out, input_buf, fwd_expert_count, weight, bias
         )
-        return grad_inp_buf, grad_weight, None
+
+        if not torch.is_tensor(bias):
+            grad_bias = None
+
+        return grad_inp_buf, None, grad_weight, grad_bias


 class MOEGather(Function):
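A minimal usage sketch of the updated autograd function, assuming the module layout in this commit and a built fmoe_cuda extension; the tensor shapes follow the comments in cuda/moe.cpp and the two-expert setup here is purely hypothetical:

    import torch
    from fmoe.functions import MOELinear   # assumes this commit's module layout

    # Hypothetical setup: 2 experts, rows already grouped by expert.
    inp = torch.randn(6, 16, device="cuda", requires_grad=True)
    fwd_expert_count = torch.tensor([4, 2])                      # rows per expert
    weight = torch.randn(2, 32, 16, device="cuda", requires_grad=True)
    bias = torch.randn(2, 32, device="cuda", requires_grad=True)

    # New argument order: (input, fwd_expert_count, weight, bias); bias may be None.
    out = MOELinear.apply(inp, fwd_expert_count, weight, bias)
    out.sum().backward()
    # backward yields (grad_input, None, grad_weight, grad_bias): the count
    # tensor is not differentiable, and grad_bias is None when bias is None.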
fmoe/layers.py

@@ -41,37 +41,7 @@ class FMoELinear(nn.Module):
         r"""
         Call MOE function
         """
-        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
-        if self.bias is not None:
-            # TODO: torch.repeat_interleave seems have numerical
-            # instability in backward, leading to incorrect
-            # gradient computation for solution 1 and 2.
-            # Solution 3 uses a for-loop to expand the bias,
-            # but is 50% slower.
-            # This part should finally goes to MOELinear.apply,
-            # like MOELinear.apply(x, weight, bias, count)
-
-            # Solution 1
-            bias = torch.repeat_interleave(
-                self.bias, fwd_expert_count.to(self.bias.device), dim=0
-            )
-
-            # Solution 2
-            # bias_idx = torch.arange(self.num_expert)\
-            #         .repeat_interleave(fwd_expert_count)
-            # bias = self.bias[bias_idx]
-
-            # Solution 3
-            # bias = []
-            # for i in range(self.num_expert):
-            #     if fwd_expert_count[i] > 0:
-            #         bias.append(
-            #             self.bias[i].unsqueeze(0).expand(
-            #                 fwd_expert_count[i], -1
-            #             )
-            #         )
-            # bias = torch.cat(bias, dim=0)
-
-            x = x + bias
+        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
         return x

     def extra_repr(self) -> str:
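The removed block is the Python-side workaround this merge retires: the per-expert bias had to be expanded to a per-token bias before being added, and the TODO notes that the repeat_interleave and indexing variants appeared numerically unstable in backward while the loop variant was about 50% slower. After the change, FMoELinear simply hands self.bias to MOELinear.apply and the expansion happens inside the CUDA op. For reference, a small sketch of the three expansion strategies from the removed comment, run on dummy data (illustrative only):

    import torch

    num_expert, d = 3, 4
    bias = torch.randn(num_expert, d)
    fwd_expert_count = torch.tensor([2, 0, 3])

    # Solution 1: repeat_interleave each expert's bias row by its token count
    b1 = torch.repeat_interleave(bias, fwd_expert_count, dim=0)

    # Solution 2: index by a repeated expert id
    idx = torch.arange(num_expert).repeat_interleave(fwd_expert_count)
    b2 = bias[idx]

    # Solution 3: explicit loop over experts
    b3 = torch.cat([bias[i].unsqueeze(0).expand(c, -1)
                    for i, c in enumerate(fwd_expert_count.tolist()) if c > 0],
                   dim=0)

    # All three agree in the forward pass.
    assert torch.equal(b1, b2) and torch.equal(b1, b3)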
tests/test_ddp.py

@@ -41,8 +41,9 @@ def _run_distributed(func, world_size, args: Dict):
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("d_hidden", [32])
 @pytest.mark.parametrize("mp_size", [1, 2])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
 def test_fmoe_linear_distributed(
-    num_expert, top_k, batch_size, d_model, d_hidden, mp_size
+    num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
 ):
     _run_distributed(
         "_test_fmoe_linear",

@@ -54,6 +55,7 @@ def test_fmoe_linear_distributed(
             "d_model": d_model,
             "d_hidden": d_hidden,
             "mp_size": mp_size,
+            "data_type": data_type
         },
     )

@@ -120,5 +122,6 @@ if __name__ == "__main__":
     else:
         test_fmoe_local_ddp(mp_size=2)
         test_fmoe_linear_distributed(
-            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
+            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2,
+            data_type="torch.HalfTensor"
         )
tests/test_numerical.py

@@ -17,15 +17,15 @@ from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
 def _perform_forward(
-    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
+    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group,
+    data_type='torch.FloatTensor'
 ):
     moe.zero_grad()
     moe_raw.zero_grad()
-    if not mp_group:
-        inp = torch.rand(batch_size, d_model).cuda()
-    else:
+    inp = torch.rand(batch_size, d_model).type(data_type).cuda()
+    if mp_group:
         group_sender = rank // mp_group.size() * mp_group.size()
-        inp = torch.rand(batch_size, d_model).cuda()
         torch.distributed.broadcast(inp, group_sender, group=mp_group)
         torch.distributed.broadcast(
             moe.gate.gate.weight.data, group_sender, group=mp_group

@@ -49,15 +49,17 @@ def _perform_forward(
     return moe_out, raw_out, inp.grad, inp_raw.grad


-def _assert_numercial(names, moe_out_list, raw_out_list, rank):
+def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
     for name, mo, ro in zip(names, moe_out_list, raw_out_list):
         err = (mo - ro).abs().sum()
         print("Rank {} {} abs err {}".format(rank, name, err))
-        if err > 1e-3:
+        if err > precision:
             sys.stderr.write(f"=========== {name} moe out ==============\n")
             sys.stderr.write("{}\n".format(mo))
             sys.stderr.write(f"=========== {name} raw out ==============\n")
             sys.stderr.write("{}\n".format(ro))
+            sys.stderr.write(f"=========== {name} diff ==============\n")
+            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
             assert False

@@ -90,6 +92,7 @@ class MyMoE(FMoE):
 @pytest.mark.parametrize("mp_group", [None])
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
 def test_fmoe_linear(
     num_expert,
     top_k,

@@ -101,6 +104,7 @@ def test_fmoe_linear(
     mp_group,
     dp_group,
     world_group,
+    data_type,
     activation=torch.nn.functional.gelu,
 ):
     torch.manual_seed(42 + rank)

@@ -108,7 +112,7 @@ def test_fmoe_linear(
     moe = MyMoE(
         num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
-    ).cuda()
+    ).type(data_type).cuda()
     moe_raw = BruteForceMoELinear(
         activation=activation,

@@ -117,7 +121,7 @@ def test_fmoe_linear(
         d_hidden=d_hidden,
         world_size=world_size,
         top_k=top_k,
-    ).cuda()
+    ).type(data_type).cuda()

     if world_size == 1:
         moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()

@@ -148,7 +152,7 @@ def test_fmoe_linear(
         moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

     moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
-        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
     )

     moe_out_list = (

@@ -198,7 +202,10 @@ def test_fmoe_linear(
         "h4toh bias grad",
     ]

-    _assert_numercial(names, moe_out_list, raw_out_list, rank)
+    precision = 5e-1 if data_type == 'torch.HalfTensor' else 1e-3
+    _assert_numerical(names, moe_out_list, raw_out_list, rank,
+                      precision=precision)


 @pytest.mark.parametrize("batch_size", [4])

@@ -299,7 +306,7 @@ def test_fmoe(
     raw_out_list = [raw_out, raw_grad, raw_grad_in]
     names = ["forward", "backward", "grad_in"]
-    _assert_numercial(names, moe_out_list, raw_out_list, rank)
+    _assert_numerical(names, moe_out_list, raw_out_list, rank)


 class MyModule(nn.Module):

@@ -375,7 +382,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
     names = ["mp grad", "dp grad", "wp grad"]
-    _assert_numercial(names, ddp_out_list, raw_out_list, rank)
+    _assert_numerical(names, ddp_out_list, raw_out_list, rank)

 if __name__ == "__main__":
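The test now relaxes the tolerance to 5e-1 for HalfTensor while keeping 1e-3 otherwise; the check sums absolute element-wise errors over a whole tensor, and fp16 carries only about three significant decimal digits, so rounding alone pushes the aggregate well past the float32 threshold. A rough illustration of that effect (a sketch, not part of the test suite):

    import torch

    # fp16 round-trip error summed over a modest tensor already exceeds 1e-3,
    # even though nothing is numerically wrong with the computation itself.
    x = torch.randn(4, 512)
    err = (x.half().float() - x).abs().sum()
    print(err)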