OpenDAS / FastMoE · Commits

Commit 3c42c892 (unverified), authored May 20, 2021 by Rick Ho, committed by GitHub on May 20, 2021.

    Merge pull request #21 from bias_improvement

    Bias improvement #15

Parents: 26824495, 41cfe06c

Showing 7 changed files with 443 additions and 365 deletions (+443 / -365).
Files changed:

- cuda/moe.cpp (+104 / -97)
- cuda/moe_compute_kernel.cu (+298 / -213)
- cuda/moe_cuda_kernel.h (+4 / -2)
- fmoe/functions.py (+11 / -7)
- fmoe/layers.py (+1 / -31)
- tests/test_ddp.py (+5 / -2)
- tests/test_numerical.py (+20 / -13)
cuda/moe.cpp (view file @ 3c42c892)

```diff
@@ -12,101 +12,108 @@
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

 std::vector<torch::Tensor> moe_expert_count(
         torch::Tensor gate,
         size_t num_expert) {
     CHECK_INPUT(gate);
     return moe_cuda_expert_count(gate, num_expert);
 }

 std::vector<torch::Tensor> moe_local_scatter(
         torch::Tensor input,
         torch::Tensor pos) {
     CHECK_INPUT(input);
     return moe_cuda_local_scatter(input, pos);
 }

 std::vector<torch::Tensor> moe_local_gather(
         torch::Tensor output_buf,
         torch::Tensor pos) {
     CHECK_INPUT(output_buf);
     return moe_cuda_local_gather(output_buf, pos);
 }

 std::vector<torch::Tensor> moe_forward(
         torch::Tensor input_buf,    // [batch_size x in_feat]
-        torch::Tensor weight,       // [num_expert x out_feat x in_feat]
-        torch::Tensor expert_count  // [batch_size]
+        torch::Tensor expert_count, // [num_expert]
+        torch::Tensor weight,       // [num_expert x out_feat x in_feat]
+        at::optional<torch::Tensor> bias_o // [num_expert x out_feat] or None
         ) {
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    /*
-        The bias term should have been merged into weight. Note the
-        following fact that
-            Wx + b = [W b] [x]
-                           [1]
-    */
-    return moe_cuda_forward(input_buf, weight, expert_count);
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+    return moe_cuda_forward(input_buf, expert_count, weight, bias_o);
 }

 std::vector<torch::Tensor> moe_backward(
         torch::Tensor grad_output_buf, // [batch_size x out_feat]
-        torch::Tensor input_buf,       // [batch_size x out_feat]
-        torch::Tensor weight,          // [num_expert x out_feat x in_feat]
-        torch::Tensor expert_count
+        torch::Tensor input_buf,       // [batch_size x in_feat]
+        torch::Tensor expert_count,    // [num_expert]
+        torch::Tensor weight,          // [num_expert x out_feat x in_feat]
+        at::optional<torch::Tensor> bias_o // [num_expert x out_feat] or None
         ) {
     CHECK_INPUT(grad_output_buf);
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
-    /*
-        The bias term should have been merged into weight. Note the
-        following fact that
-            Wx + b = [W b] [x]
-                           [1]
-    */
-    return moe_cuda_backward(grad_output_buf, input_buf, weight, expert_count);
+    // check if bias is valid in case it exists
+    if (bias_o.has_value()) {
+        auto bias = bias_o.value();
+        CHECK_INPUT(bias);
+    }
+    return moe_cuda_backward(grad_output_buf, input_buf, expert_count,
+            weight, bias_o);
 }

 #ifdef MOE_USE_NCCL
 std::vector<torch::Tensor> moe_expert_exchange(
         torch::Tensor local_expert_count,
         size_t num_expert,
         size_t n_workers) {
     return moe_cuda_expert_exchange(local_expert_count,
             num_expert, n_workers);
 }

 std::vector<torch::Tensor> moe_global_scatter(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
         size_t batch_size,
         size_t n_workers) {
     CHECK_INPUT(input_buf);
     return moe_cuda_global_scatter(input_buf,
             local_expert_count, global_expert_count,
             batch_size, n_workers);
 }

 std::vector<torch::Tensor> moe_global_gather(
         torch::Tensor output_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
         size_t batch_size,
         size_t n_workers) {
     CHECK_INPUT(output_buf);
     return moe_cuda_global_gather(output_buf,
             local_expert_count, global_expert_count,
             batch_size, n_workers);
 }

 std::vector<torch::Tensor> moe_global_fused_forward(
         torch::Tensor input_buf,
         torch::Tensor weight,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
         long global_batch_size,
         long local_batch_size,
         long n_workers) {
     CHECK_INPUT(input_buf);
     CHECK_INPUT(weight);
     return moe_cuda_global_fused_forward(input_buf, weight,
             local_expert_count, global_expert_count,
             global_batch_size, local_batch_size, n_workers);
 }

 #include <c10d/ProcessGroupNCCL.hpp>
@@ -114,47 +121,47 @@ std::vector<torch::Tensor> moe_global_fused_forward(
 class HackNCCLGroup: public c10d::ProcessGroupNCCL {
 public:
     ncclComm_t getcomm(at::Device dev) {
         auto key = std::to_string(dev.index());
 #ifdef ENABLE_NCCL_P2P_SUPPORT
         ncclUniqueId ncclID;
         int rank = getRank();
         if (rank == 0) {
             ncclGetUniqueId(&ncclID);
         }
         broadcastUniqueNCCLID(&ncclID,
                 c10d::OpType::SEND,
                 "fastmoe_nccl_comm",
                 rank);
         ncclComm_t comm;
         ncclCommInitRank(&comm, getSize(), ncclID, rank);
         return comm;
 #else
         auto v = getNCCLComm(key, {dev});
         if (v.size() == 0) {
             std::cerr << "PyTorch has nothing\n";
             return 0;
         }
         int count;
         ncclCommCount(v[0]->getNcclComm(), &count);
         std::cerr << "PyTorch has " << v.size()
                 << " comms, comm 0 size " << count << "\n";
         return v[0]->getNcclComm();
 #endif
     }
 };

 void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
     auto smgr = getCudaStreamManager(t.device().index());
     if (smgr->ncclgood) {
         return;
     }
     HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
     smgr->ncclcomm = h->getcomm(t.device());
     if (smgr->ncclcomm != 0) {
         smgr->ncclgood = 1;
     } else {
         std::cerr << "Nccl initialization failed\n";
     }
 }
 #endif  // MOE_USE_NCCL
@@ -167,8 +174,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("expert_exchange", &moe_expert_exchange, "MoE expert exchange (CUDA)");
     m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
     m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
     m.def("global_fused_forward", &moe_global_fused_forward, "MoE global gather (CUDA)");
     m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
 #endif
     m.def("forward", &moe_forward, "MoE forward (CUDA)");
```
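The comment removed above justified the old, bias-free kernel API by folding the bias into the weight matrix: Wx + b = [W b][x; 1]. A minimal PyTorch sketch of that identity (illustrative, not part of the commit):

```python
import torch

in_feat, out_feat = 4, 3
W = torch.randn(out_feat, in_feat)
b = torch.randn(out_feat)
x = torch.randn(in_feat)

merged = torch.cat([W, b.unsqueeze(1)], dim=1)  # [W b]
x1 = torch.cat([x, torch.ones(1)])              # [x; 1]
# W @ x + b equals the merged matrix applied to the augmented input
assert torch.allclose(W @ x + b, merged @ x1, atol=1e-5)
```

The commit abandons that folding in favor of an explicit optional bias, which avoids materializing augmented inputs and lets the bias be skipped entirely when absent.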
cuda/moe_compute_kernel.cu (view file @ 3c42c892)

```diff
@@ -19,304 +19,386 @@
 template <typename scalar_t>
 __global__
 void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
         const long* offset, const scalar_t** ptrs) {
     size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
     if (idx < n) {
         ptrs[idx] = base + stride * offset[idx];
     }
 }

 template <typename scalar_t>
 __global__
 void batch_scatter_kernel(size_t wid, const long* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * pos[blockIdx.x];
     oubuf += wid * blockIdx.x;
     for (int i = threadIdx.x; i < wid; i += blockDim.x) {
         oubuf[i] = inbuf[i];
     }
 }
```
```diff
+/*
+    This function is to be called with one block per each column
+*/
+template <typename scalar_t>
+__global__
+void column_reduce(const scalar_t* matrix, scalar_t* result,
+        int m /* lines */, int n /* columns */) {
+    // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
+    extern __shared__ unsigned char my_smem[];
+    scalar_t* sdata = reinterpret_cast<scalar_t*>(my_smem);
+
+    // normal tid
+    int tid = threadIdx.x + threadIdx.y * blockDim.x;
+    // transposed tid for shared memory
+    int new_tid = threadIdx.y + threadIdx.x * blockDim.y;
+    // true x value in the matrix
+    int real_x = threadIdx.x + blockDim.x * blockIdx.x;
+
+    int i = real_x + n * threadIdx.y;
+    const int it = n * blockDim.y;
+    int offset = it;
+
+    float accumulator = 0;
+    if (threadIdx.y < m && real_x < n) {
+        // store all the values from this column in a warped way
+        accumulator = matrix[i];
+        while (i + offset < n * m) {
+            accumulator += matrix[i + offset];
+            offset += it;
+        }
+    }
+    // save column reduction data in a transposed way
+    sdata[new_tid] = accumulator;
+    __syncthreads();
+
+    for (size_t t = 16; t > 0; t >>= 1) {
+        if (tid < 32 * 32 - 16)
+            sdata[tid] += sdata[tid + t];
+        __syncthreads();
+    }
+
+    if (threadIdx.y == 0 && real_x < n)
+        result[real_x] = sdata[new_tid];
+}
```
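The kernel above is launched with 32x32 thread blocks, one block per 32 columns, and 1024 shared-memory elements (see the backward pass further down). Functionally it is a column-wise sum; a one-line PyTorch reference of what it computes (my illustration, not code from the commit):

```python
import torch

matrix = torch.randn(100, 70)   # m rows, n columns
reference = matrix.sum(dim=0)   # column sums, shape [70]: what column_reduce produces
```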
```diff
 void moe_cuda_expert_count_impl(
         const int* d_gate,
         int* expert_count,
         int* d_pos,
         const size_t num_expert,
         const size_t batch_size) {
     int* gate = new int[batch_size];
     int* expert_ptr = new int[num_expert];
     memset(expert_count, 0, sizeof(int) * num_expert);

     checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
                 cudaMemcpyDeviceToHost));

     for (int i = 0; i < batch_size; ++i) {
         ++expert_count[gate[i]];
     }
     expert_ptr[0] = 0;
     for (int i = 1; i < num_expert; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
     }

     int* pos = new int[batch_size];
     for (int i = 0; i < batch_size; ++i) {
         pos[i] = expert_ptr[gate[i]]++;
     }
     for (int i = num_expert - 1; i > 0; --i) {
         expert_ptr[i] = expert_ptr[i - 1];
     }
     expert_ptr[0] = 0;
     checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
                 cudaMemcpyHostToDevice));
     delete [] gate;
     delete [] expert_ptr;
 }
```
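moe_cuda_expert_count_impl is untouched by this commit; it is a host-side counting sort that assigns each input row a slot in the expert-sorted buffer. A pure-Python rendering of the same logic (a sketch of mine, not code from the repository):

```python
def expert_count_positions(gate, num_expert):
    # histogram of tokens per expert
    expert_count = [0] * num_expert
    for g in gate:
        expert_count[g] += 1
    # exclusive prefix sum: first slot owned by each expert
    expert_ptr = [0] * num_expert
    for i in range(1, num_expert):
        expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1]
    # stable assignment of each row to the next free slot of its expert
    pos = [0] * len(gate)
    for i, g in enumerate(gate):
        pos[i] = expert_ptr[g]
        expert_ptr[g] += 1
    return expert_count, pos

# e.g. gate [1, 0, 1, 2] with 3 experts -> counts [1, 2, 1], pos [1, 0, 2, 3]
```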
```diff
 template <typename scalar_t>
 void moe_cuda_local_scatter_impl(
         const scalar_t* input,
         const long* d_pos,
         scalar_t* input_buf,
         const long batch_size,
         const long in_feat,
         CudaStreamManager* smgr) {
     batch_scatter_kernel<scalar_t>
         <<<batch_size, 256, 0, smgr->stream(0)>>>(
             in_feat, d_pos, input, input_buf);
     smgr->sync(1);
 }

 template <typename scalar_t>
 __global__
 void batch_gather_kernel(size_t wid, const long* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * blockIdx.x;
     oubuf += wid * pos[blockIdx.x];
     for (int i = threadIdx.x; i < wid; i += blockDim.x) {
         oubuf[i] = inbuf[i];
     }
 }

 template <typename scalar_t>
 void moe_cuda_local_gather_impl(
         const scalar_t* output_buf,
         const long* d_pos,
         scalar_t* output,
         const size_t batch_size,
         const size_t out_feat,
         CudaStreamManager* smgr) {
     batch_gather_kernel<scalar_t>
         <<<batch_size, 256, 0, smgr->stream(0)>>>(
             out_feat, d_pos, output_buf, output);
     smgr->sync(1);
 }
```
```diff
 template <typename scalar_t>
 void moe_cuda_forward_impl(
         const scalar_t* input_buf,
         const scalar_t* weight,
         const long* expert_count,
         scalar_t* output_buf,
+        const bool has_bias,
         const size_t in_feat,
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
-    scalar_t alpha = 1, beta = 0;
+    scalar_t alpha = 1, beta = has_bias ? 1 : 0;

     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             continue;
         }
         // Use T(B) x T(A) = T(C) to produce row-major C
         checkCudaErrors(cublasXgemm(
                 smgr->handle(i),
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 out_feat, expert_count[i], in_feat,
                 &alpha,
                 weight + i * in_feat * out_feat, in_feat,
                 input_buf + ptr * in_feat, in_feat,
                 &beta,
                 output_buf + out_feat * ptr, out_feat));
         ptr += expert_count[i];
     }
     smgr->sync(num_expert);
 }
```
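The only functional change here is beta = has_bias ? 1 : 0: when the output buffer has been pre-filled with each row's bias (see the moe_cuda_forward wrapper below), the GEMM's C = alpha*A*B + beta*C accumulation performs the bias add for free. A PyTorch sketch of the trick (illustrative, with assumed shapes):

```python
import torch

rows, in_feat, out_feat = 5, 4, 3
x = torch.randn(rows, in_feat)
W = torch.randn(out_feat, in_feat)
b = torch.randn(out_feat)

out = b.expand(rows, out_feat).clone()  # pre-filled output, as with beta = 1
out += x @ W.t()                        # the GEMM accumulates on top of it
assert torch.allclose(out, torch.nn.functional.linear(x, W, b), atol=1e-5)
```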
```diff
 template <typename scalar_t>
 void moe_cuda_backward_impl(
         const scalar_t* grad_output_buf,
         const scalar_t* input_buf,
         const scalar_t* weight,
         const long* expert_count,
         scalar_t* grad_input_buf,
         scalar_t* grad_weight,
+        scalar_t* grad_bias,
+        const bool has_bias,
         const size_t batch_size,
         const size_t in_feat,
         const size_t out_feat,
         const size_t num_expert,
         CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = 0;

+    // bias
+    dim3 block_threads(32, 32);
+    dim3 grid_threads(out_feat / 32 + (out_feat % 32 ? 1 : 0), 1);
+
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                     sizeof(scalar_t) * in_feat * out_feat);
+            cudaMemset(grad_bias + i * out_feat, 0,
+                    sizeof(scalar_t) * out_feat);
             continue;
         }
         // Use T(B) x T(A) = T(C) to produce row-major C
         // Backward input: g_i = w @ g_o
         checkCudaErrors(cublasXgemm(
                 smgr->handle(i),
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 in_feat, expert_count[i], out_feat,
                 &alpha,
                 weight + i * in_feat * out_feat, in_feat,
                 grad_output_buf + ptr * out_feat, out_feat,
                 &beta,
                 grad_input_buf + in_feat * ptr, in_feat));
         // Backward weight: g_w = i @ g_o
         checkCudaErrors(cublasXgemm(
                 smgr->handle(i),
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 in_feat, out_feat, expert_count[i],
                 &alpha,
                 input_buf + in_feat * ptr, in_feat,
                 grad_output_buf + ptr * out_feat, out_feat,
                 &beta,
                 grad_weight + i * in_feat * out_feat, in_feat));
+        if (has_bias) {
+            column_reduce<<<grid_threads, block_threads,
+                    sizeof(scalar_t) * 1024,
+                    smgr->stream(0)>>>(
+                    grad_output_buf + ptr * out_feat,
+                    grad_bias + i * out_feat,
+                    expert_count[i],
+                    out_feat);
+        }
         ptr += expert_count[i];
     }
     smgr->sync(num_expert);
 }
```
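Since the forward pass adds the same bias row to every token routed to an expert, that expert's bias gradient is the column-wise sum of its slice of grad_output, which is exactly what the column_reduce launch computes. A reference computation (my sketch, not the commit's code):

```python
import torch

grad_output_buf = torch.randn(6, 3)      # expert-sorted rows, out_feat = 3
expert_count = [2, 0, 4]

ptr, grad_bias = 0, []
for c in expert_count:
    # an empty slice sums to zeros, matching the cudaMemset for idle experts
    grad_bias.append(grad_output_buf[ptr:ptr + c].sum(dim=0))
    ptr += c
grad_bias = torch.stack(grad_bias)       # [num_expert x out_feat]
```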
```diff
 std::vector<torch::Tensor> moe_cuda_expert_count(
         torch::Tensor gate,
         size_t num_expert) {
     const auto batch_size = gate.size(0);

     auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
     auto expert_count = torch::empty(num_expert, ec_options);

     auto pos_options = torch::TensorOptions()
         .device(gate.device())
         .dtype(torch::kInt32);
     auto pos = torch::empty(batch_size, pos_options);
     moe_cuda_expert_count_impl(
             gate.data_ptr<int>(),
             expert_count.data_ptr<int>(),
             pos.data_ptr<int>(),
             num_expert,
             batch_size);

     return {expert_count, pos};
 }

 std::vector<torch::Tensor> moe_cuda_local_scatter(
         torch::Tensor input,
         torch::Tensor pos) {
     auto smgr = getCudaStreamManager(input.device().index());
     const auto batch_size = pos.size(0);
     const auto in_feat = input.size(1);

     auto opt = torch::TensorOptions()
         .dtype(input.dtype())
         .device(input.device());
     auto input_buf = torch::empty({batch_size, in_feat}, opt);

     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
             "moe_local_scatter_cuda", ([&] {
         moe_cuda_local_scatter_impl<scalar_t>(
                 input.data_ptr<scalar_t>(),
                 pos.data_ptr<long>(),
                 input_buf.data_ptr<scalar_t>(),
                 batch_size,
                 in_feat,
                 smgr);
     }));
     return {input_buf,};
 }

 std::vector<torch::Tensor> moe_cuda_local_gather(
         torch::Tensor output_buf,
         torch::Tensor pos) {
     auto smgr = getCudaStreamManager(output_buf.device().index());
     const auto batch_size = pos.size(0);
     const auto out_feat = output_buf.size(1);

     auto opt = torch::TensorOptions()
         .dtype(output_buf.dtype())
         .device(output_buf.device());
     auto output = torch::empty({batch_size, out_feat}, opt);

     AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
             "moe_local_gather_cuda", ([&] {
         moe_cuda_local_gather_impl<scalar_t>(
                 output_buf.data_ptr<scalar_t>(),
                 pos.data_ptr<long>(),
                 output.data_ptr<scalar_t>(),
                 batch_size,
                 out_feat,
                 smgr);
     }));
     return {output,};
 }
```
```diff
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input_buf,
-        torch::Tensor weight,
-        torch::Tensor expert_count
+        torch::Tensor expert_count,
+        torch::Tensor weight,
+        at::optional<torch::Tensor> bias
         ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);

 #ifdef MOE_DEBUG
     printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
             num_expert, in_feat, out_feat);
 #endif
-    auto out_options = torch::TensorOptions()
-        .device(input_buf.device())
-        .dtype(input_buf.dtype());
-    auto output = torch::empty({batch_size, out_feat}, out_options);
+    torch::Tensor output;
+    if (bias.has_value()) {
+        output = bias.value().repeat_interleave(
+                expert_count.to(bias.value().device()), 0);
+    } else {
+        auto out_options = torch::TensorOptions()
+            .device(input_buf.device())
+            .dtype(input_buf.dtype());
+        output = torch::empty({batch_size, out_feat}, out_options);
+    }

     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
             "moe_forward_cuda", ([&] {
         moe_cuda_forward_impl<scalar_t>(
                 input_buf.data_ptr<scalar_t>(),
                 weight.data_ptr<scalar_t>(),
                 expert_count.data_ptr<long>(),
                 output.data_ptr<scalar_t>(),
+                bias.has_value(),
                 in_feat,
                 out_feat,
                 num_expert,
                 smgr);
     }));

     return {output, };
 }
```
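The new output-seeding path is the counterpart of the beta = 1 trick above: repeat_interleave repeats expert i's bias row expert_count[i] times, matching the expert-sorted row order of the buffer. A small illustration (assumed values):

```python
import torch

bias = torch.tensor([[10., 11.], [20., 21.], [30., 31.]])  # [num_expert x out_feat]
expert_count = torch.tensor([2, 0, 1])

out = bias.repeat_interleave(expert_count, 0)
# tensor([[10., 11.],
#         [10., 11.],
#         [30., 31.]])  -- ready for the beta = 1 GEMMs to accumulate into
```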
```diff
 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output_buf, // [batch_size x out_feat]
         torch::Tensor input_buf,       // [batch_size x out_feat]
-        torch::Tensor weight,          // [num_expert x out_feat x in_feat]
-        torch::Tensor expert_count
+        torch::Tensor expert_count,
+        torch::Tensor weight,          // [num_expert x out_feat x in_feat]
+        at::optional<torch::Tensor> bias
         ) {
     auto smgr = getCudaStreamManager(input_buf.device().index());
     const auto batch_size = input_buf.size(0);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
@@ -324,28 +406,31 @@ std::vector<torch::Tensor> moe_cuda_backward(
 #ifdef MOE_DEBUG
     printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
             "out_feat (d_ffn)=%ld\n",
             batch_size, num_expert, in_feat, out_feat);
 #endif

     auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
     auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
+    auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});

     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
             "moe_cuda_backward", ([&] {
         moe_cuda_backward_impl<scalar_t>(
                 grad_output_buf.data_ptr<scalar_t>(),
                 input_buf.data_ptr<scalar_t>(),
                 weight.data_ptr<scalar_t>(),
                 expert_count.data_ptr<long>(),
                 grad_input_buf.data_ptr<scalar_t>(),
                 grad_weight.data_ptr<scalar_t>(),
+                grad_bias.data_ptr<scalar_t>(),
+                bias.has_value(),
                 batch_size,
                 in_feat,
                 out_feat,
                 num_expert,
                 smgr);
     }));

-    return {grad_input_buf, grad_weight};
+    return {grad_input_buf, grad_weight, grad_bias};
 }
```
cuda/moe_cuda_kernel.h (view file @ 3c42c892)

```diff
@@ -19,14 +19,16 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input_buf,
+        torch::Tensor expert_count,
         torch::Tensor weight,
-        torch::Tensor expert_count);
+        at::optional<torch::Tensor> bias);

 std::vector<torch::Tensor> moe_cuda_backward(
         torch::Tensor grad_output_buf,
         torch::Tensor input_buf,
+        torch::Tensor expert_count,
         torch::Tensor weight,
-        torch::Tensor expert_count);
+        at::optional<torch::Tensor> bias);

 #ifdef MOE_USE_NCCL
```
fmoe/functions.py (view file @ 3c42c892)

```diff
@@ -110,21 +110,25 @@ class MOELinear(Function):
     """

     @staticmethod
-    def forward(ctx, global_input_buf, weight, fwd_expert_count):
+    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
         (global_output_buf,) = fmoe_cuda.forward(
-            global_input_buf, weight, fwd_expert_count
+            global_input_buf, fwd_expert_count, weight, bias
         )
-        variables = (global_input_buf, weight, fwd_expert_count)
+        variables = (global_input_buf, fwd_expert_count, weight, bias)
         ctx.save_for_backward(*variables)
         return global_output_buf

     @staticmethod
     def backward(ctx, grad_out):
-        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
-        grad_inp_buf, grad_weight = fmoe_cuda.backward(
-            grad_out, input_buf, weight, fwd_expert_count
+        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
+        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.backward(
+            grad_out, input_buf, fwd_expert_count, weight, bias
         )
-        return grad_inp_buf, grad_weight, None
+
+        if not torch.is_tensor(bias):
+            grad_bias = None
+
+        return grad_inp_buf, None, grad_weight, grad_bias


 class MOEGather(Function):
```
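A sketch of a call site under the new autograd signature (assumes the compiled fmoe_cuda extension and a CUDA device; shapes and values are illustrative): expert counts now precede the parameters, and bias may be omitted.

```python
import torch
from fmoe.functions import MOELinear

num_expert, in_feat, out_feat = 2, 4, 3
inp = torch.randn(5, in_feat, device="cuda")    # rows already sorted by expert
fwd_expert_count = torch.tensor([3, 2])          # rows per expert (host tensor)
weight = torch.randn(num_expert, out_feat, in_feat,
                     device="cuda", requires_grad=True)
bias = torch.randn(num_expert, out_feat, device="cuda", requires_grad=True)

out = MOELinear.apply(inp, fwd_expert_count, weight, bias)  # [5 x out_feat]
```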
fmoe/layers.py (view file @ 3c42c892)

```diff
@@ -41,37 +41,7 @@ class FMoELinear(nn.Module):
         r"""
         Call MOE function
         """
-        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
-        if self.bias is not None:
-            # TODO: torch.repeat_interleave seems have numerical
-            # instability in backward, leading to incorrect
-            # gradient computation for solution 1 and 2.
-            # Solution 3 uses a for-loop to expand the bias,
-            # but is 50% slower.
-            # This part should finally goes to MOELinear.apply,
-            # like MOELinear.apply(x, weight, bias, count)
-            # Solution 1
-            bias = torch.repeat_interleave(
-                self.bias, fwd_expert_count.to(self.bias.device), dim=0
-            )
-            # Solution 2
-            # bias_idx = torch.arange(self.num_expert)\
-            #         .repeat_interleave(fwd_expert_count)
-            # bias = self.bias[bias_idx]
-            # Solution 3
-            # bias = []
-            # for i in range(self.num_expert):
-            #     if fwd_expert_count[i] > 0:
-            #         bias.append(
-            #             self.bias[i].unsqueeze(0).expand(
-            #                 fwd_expert_count[i], -1
-            #             )
-            #         )
-            # bias = torch.cat(bias, dim=0)
-            x = x + bias
+        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
         return x

     def extra_repr(self) -> str:
```
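The removed TODO weighed three ways of expanding the per-expert bias to a per-row bias; solutions 1 and 2 are equivalent on the forward path, as the sketch below checks (my illustration, not repository code). The commit sidesteps the whole question by fusing the bias into the CUDA op itself.

```python
import torch

bias = torch.randn(3, 4)                       # [num_expert x d]
fwd_expert_count = torch.tensor([2, 0, 1])

# Solution 1: direct repeat_interleave of the bias rows
b1 = torch.repeat_interleave(bias, fwd_expert_count, dim=0)
# Solution 2: repeat expert indices, then gather
b2 = bias[torch.arange(3).repeat_interleave(fwd_expert_count)]
assert torch.equal(b1, b2)
```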
tests/test_ddp.py (view file @ 3c42c892)

```diff
@@ -41,8 +41,9 @@ def _run_distributed(func, world_size, args: Dict):
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("d_hidden", [32])
 @pytest.mark.parametrize("mp_size", [1, 2])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
 def test_fmoe_linear_distributed(
-    num_expert, top_k, batch_size, d_model, d_hidden, mp_size
+    num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
 ):
     _run_distributed(
         "_test_fmoe_linear",
@@ -54,6 +55,7 @@ def test_fmoe_linear_distributed(
             "d_model": d_model,
             "d_hidden": d_hidden,
             "mp_size": mp_size,
+            "data_type": data_type
         },
     )
@@ -120,5 +122,6 @@ if __name__ == "__main__":
     else:
         test_fmoe_local_ddp(mp_size=2)
         test_fmoe_linear_distributed(
-            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
+            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8,
+            mp_size=2, data_type="torch.HalfTensor"
         )
```
tests/test_numerical.py (view file @ 3c42c892)

```diff
@@ -17,15 +17,15 @@ from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
 def _perform_forward(
-    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
+    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank,
+    mp_group, data_type='torch.FloatTensor'
 ):
     moe.zero_grad()
     moe_raw.zero_grad()
     if not mp_group:
-        inp = torch.rand(batch_size, d_model).cuda()
+        inp = torch.rand(batch_size, d_model).type(data_type).cuda()
     if mp_group:
         group_sender = rank // mp_group.size() * mp_group.size()
         inp = torch.rand(batch_size, d_model).cuda()
         torch.distributed.broadcast(inp, group_sender, group=mp_group)
         torch.distributed.broadcast(
             moe.gate.gate.weight.data, group_sender, group=mp_group
@@ -49,15 +49,17 @@ def _perform_forward(
     return moe_out, raw_out, inp.grad, inp_raw.grad


-def _assert_numercial(names, moe_out_list, raw_out_list, rank):
+def _assert_numerical(names, moe_out_list, raw_out_list, rank,
+                      precision=1e-3):
     for name, mo, ro in zip(names, moe_out_list, raw_out_list):
         err = (mo - ro).abs().sum()
         print("Rank {} {} abs err {}".format(rank, name, err))
-        if err > 1e-3:
+        if err > precision:
             sys.stderr.write(f"=========== {name} moe out ==============\n")
             sys.stderr.write("{}\n".format(mo))
             sys.stderr.write(f"=========== {name} raw out ==============\n")
             sys.stderr.write("{}\n".format(ro))
             sys.stderr.write(f"=========== {name} diff ==============\n")
             sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
             assert False
@@ -90,6 +92,7 @@ class MyMoE(FMoE):
 @pytest.mark.parametrize("mp_group", [None])
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
+@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
 def test_fmoe_linear(
     num_expert,
     top_k,
@@ -101,6 +104,7 @@ def test_fmoe_linear(
     mp_group,
     dp_group,
     world_group,
+    data_type,
     activation=torch.nn.functional.gelu,
 ):
     torch.manual_seed(42 + rank)
@@ -108,7 +112,7 @@ def test_fmoe_linear(
     moe = MyMoE(
         num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
-    ).cuda()
+    ).type(data_type).cuda()
     moe_raw = BruteForceMoELinear(
         activation=activation,
@@ -117,7 +121,7 @@ def test_fmoe_linear(
         d_hidden=d_hidden,
         world_size=world_size,
         top_k=top_k,
-    ).cuda()
+    ).type(data_type).cuda()

     if world_size == 1:
         moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
@@ -148,7 +152,7 @@ def test_fmoe_linear(
     moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

     moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
-        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group,
+        data_type=data_type
     )
     moe_out_list = (
@@ -198,7 +202,10 @@ def test_fmoe_linear(
         "h4toh bias grad",
     ]
-    _assert_numercial(names, moe_out_list, raw_out_list, rank)
+    precision = 5e-1 if data_type == 'torch.HalfTensor' else 1e-3
+    _assert_numerical(names, moe_out_list, raw_out_list, rank,
+                      precision=precision)

 @pytest.mark.parametrize("batch_size", [4])
@@ -299,7 +306,7 @@ def test_fmoe(
     raw_out_list = [raw_out, raw_grad, raw_grad_in]
     names = ["forward", "backward", "grad_in"]
-    _assert_numercial(names, moe_out_list, raw_out_list, rank)
+    _assert_numerical(names, moe_out_list, raw_out_list, rank)


 class MyModule(nn.Module):
@@ -375,7 +382,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
     names = ["mp grad", "dp grad", "wp grad"]
-    _assert_numercial(names, ddp_out_list, raw_out_list, rank)
+    _assert_numerical(names, ddp_out_list, raw_out_list, rank)

 if __name__ == "__main__":
```