OpenDAS / FastMoE · Commits
Commit a526f438, authored Jan 28, 2021 by Rick Ho

single node use torch cuda expert count

Parent: bc8e8181
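In short, this commit drops the call to the custom fmoe_cuda.expert_count kernel in MOELocal.forward and computes the per-expert token counts with plain PyTorch ops (torch.sort, torch.unique, index_put_), handing a CPU copy (ecc) to fmoe_cuda.forward; the CUDA kernels are widened accordingly to read int64 (long) position and count tensors instead of int. A minimal standalone sketch of the new counting logic follows; the names torch_expert_count and num_expert are illustrative, not from the repo.

# Minimal sketch of the new counting logic (illustrative names, not from the repo).
import torch

def torch_expert_count(gate, num_expert):
    # pos maps sorted slots back to original token indices; torch.sort
    # returns int64 indices, which is why the kernels now take long*.
    _, pos = torch.sort(gate)
    # Count how many tokens were routed to each expert.
    gate_idx, gate_count = torch.unique(gate, return_counts=True)
    expert_count = torch.zeros(num_expert, device=gate.device, dtype=torch.long)
    expert_count.index_put_((gate_idx.long(),), gate_count)
    return expert_count, pos

gate = torch.tensor([2, 0, 2, 1, 0, 2])      # toy gate output: expert id per token
count, pos = torch_expert_count(gate, num_expert=4)
print(count)                                 # tensor([2, 1, 3, 0])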
Showing 2 changed files with 21 additions and 14 deletions
cuda/moe_compute_kernel.cu    +11 -11
fmoe/moe_function.py          +10 -3
cuda/moe_compute_kernel.cu (view file @ a526f438)
@@ -20,7 +20,7 @@
 template <typename scalar_t>
 __global__
 void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
-        const int* offset, const scalar_t** ptrs) {
+        const long* offset, const scalar_t** ptrs) {
     size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
     if (idx < n) {
         ptrs[idx] = base + stride * offset[idx];
@@ -29,7 +29,7 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
 template <typename scalar_t>
 __global__
-void batch_scatter_kernel(size_t wid, const int* pos,
+void batch_scatter_kernel(size_t wid, const long* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * blockIdx.x;
     oubuf += wid * pos[blockIdx.x];
@@ -77,7 +77,7 @@ void moe_cuda_expert_count_impl(
 template <typename scalar_t>
 void moe_cuda_local_scatter_impl(
         const scalar_t* input,
-        const int* d_pos,
+        const long* d_pos,
         scalar_t* input_buf,
         const long batch_size,
         const long in_feat,
@@ -90,7 +90,7 @@ void moe_cuda_local_scatter_impl(
 template <typename scalar_t>
 __global__
-void batch_gather_kernel(size_t wid, const int* pos,
+void batch_gather_kernel(size_t wid, const long* pos,
         const scalar_t* inbuf, scalar_t* oubuf) {
     inbuf += wid * pos[blockIdx.x];
     oubuf += wid * blockIdx.x;
@@ -102,7 +102,7 @@ void batch_gather_kernel(size_t wid, const int* pos,
 template <typename scalar_t>
 void moe_cuda_local_gather_impl(
         const scalar_t* output_buf,
-        const int* d_pos,
+        const long* d_pos,
         scalar_t* output,
         const size_t batch_size,
         const size_t out_feat,
@@ -117,7 +117,7 @@ template <typename scalar_t>
 void moe_cuda_forward_impl(
         const scalar_t* input_buf,
         const scalar_t* weight,
-        const int* expert_count,
+        const long* expert_count,
         scalar_t* output_buf,
         const size_t in_feat,
         const size_t out_feat,
@@ -152,7 +152,7 @@ void moe_cuda_backward_impl(
         const scalar_t* grad_output_buf,
         const scalar_t* input_buf,
         const scalar_t* weight,
-        const int* expert_count,
+        const long* expert_count,
         scalar_t* grad_input_buf,
         scalar_t* grad_weight,
         const size_t batch_size,
@@ -237,7 +237,7 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
             ([&] {
         moe_cuda_local_scatter_impl<scalar_t>(
             input.data_ptr<scalar_t>(),
-            pos.data_ptr<int>(),
+            pos.data_ptr<long>(),
             input_buf.data_ptr<scalar_t>(),
             batch_size,
             in_feat,
@@ -259,7 +259,7 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
             ([&] {
         moe_cuda_local_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
-            pos.data_ptr<int>(),
+            pos.data_ptr<long>(),
             output.data_ptr<scalar_t>(),
             batch_size,
             out_feat,
@@ -293,7 +293,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
         moe_cuda_forward_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
-            expert_count.data_ptr<int>(),
+            expert_count.data_ptr<long>(),
             output.data_ptr<scalar_t>(),
             in_feat,
             out_feat,
@@ -331,7 +331,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
             grad_output_buf.data_ptr<scalar_t>(),
             input_buf.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
-            expert_count.data_ptr<int>(),
+            expert_count.data_ptr<long>(),
             grad_input_buf.data_ptr<scalar_t>(),
             grad_weight.data_ptr<scalar_t>(),
             batch_size,
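Every change in moe_compute_kernel.cu is the same one-token edit: the position (offset, pos, d_pos) and expert_count pointers go from const int* to const long*, and the host wrappers switch from data_ptr<int>() to data_ptr<long>(). A plausible reading, consistent with the Python diff below, is that these tensors now come straight from torch.sort and torch.unique, whose outputs are int64; a quick check (mine, not from the repo):

# Index and count tensors produced by torch.sort / torch.unique are int64,
# so the C++ side must read them as long* rather than int*.
import torch

gate = torch.tensor([1, 0, 1, 2])
_, pos = torch.sort(gate)
_, counts = torch.unique(gate, return_counts=True)
print(pos.dtype, counts.dtype)               # torch.int64 torch.int64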
fmoe/moe_function.py (view file @ a526f438)
@@ -6,12 +6,19 @@ import fmoe_cuda
 class MOELocal(Function):
     @staticmethod
     def forward(ctx, inp, gate, weight):
-        expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
+        _, pos = torch.sort(gate)
+        gate_idx, gate_count = torch.unique(gate, return_counts=True)
+        expert_count = torch.zeros(weight.shape[0], device=weight.device,
+                dtype=torch.long)
+        expert_count.index_put_((gate_idx.long(), ), gate_count)
+        # expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
+        ecc = expert_count.cpu()
         input_buf, = fmoe_cuda.local_scatter(inp, pos)
-        output_buf, = fmoe_cuda.forward(input_buf, weight, expert_count)
+        output_buf, = fmoe_cuda.forward(input_buf, weight, ecc)
         output = fmoe_cuda.local_gather(output_buf, pos)
-        variables = [input_buf, gate, weight, expert_count, pos]
+        variables = [input_buf, gate, weight, ecc, pos]
         ctx.save_for_backward(*variables)
         return output[0]
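The ecc tensor saved here is a CPU copy of expert_count, and it is what fmoe_cuda.forward and the saved-for-backward tensors now carry instead of the old kernel's output. As a sanity check of the counting logic (illustrative only, not part of the commit), the unique/index_put_ construction agrees with torch.bincount:

# The unique/index_put_ construction matches torch.bincount with minlength
# set to the number of experts.
import torch

num_expert = 4                                # hypothetical expert count
gate = torch.randint(0, num_expert, (16,))    # hypothetical gate assignments
gate_idx, gate_count = torch.unique(gate, return_counts=True)
expert_count = torch.zeros(num_expert, dtype=torch.long)
expert_count.index_put_((gate_idx.long(),), gate_count)
assert torch.equal(expert_count, torch.bincount(gate, minlength=num_expert))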