Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
a526f438
"src/array/cuda/array_cumsum.hip" did not exist on "0ff7127a0fff730f3c41a8ea3e967c1155993a2f"
Commit
a526f438
authored
Jan 28, 2021
by
Rick Ho
Browse files
single node use torch cuda expert count
parent
bc8e8181
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
14 deletions
+21
-14
cuda/moe_compute_kernel.cu
cuda/moe_compute_kernel.cu
+11
-11
fmoe/moe_function.py
fmoe/moe_function.py
+10
-3
No files found.
cuda/moe_compute_kernel.cu
View file @
a526f438
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
__global__
void
generate_ptr_offset_kernel
(
size_t
n
,
const
scalar_t
*
base
,
size_t
stride
,
void
generate_ptr_offset_kernel
(
size_t
n
,
const
scalar_t
*
base
,
size_t
stride
,
const
int
*
offset
,
const
scalar_t
**
ptrs
)
{
const
long
*
offset
,
const
scalar_t
**
ptrs
)
{
size_t
idx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
size_t
idx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
idx
<
n
)
{
if
(
idx
<
n
)
{
ptrs
[
idx
]
=
base
+
stride
*
offset
[
idx
];
ptrs
[
idx
]
=
base
+
stride
*
offset
[
idx
];
...
@@ -29,7 +29,7 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
...
@@ -29,7 +29,7 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
__global__
void
batch_scatter_kernel
(
size_t
wid
,
const
int
*
pos
,
void
batch_scatter_kernel
(
size_t
wid
,
const
long
*
pos
,
const
scalar_t
*
inbuf
,
scalar_t
*
oubuf
)
{
const
scalar_t
*
inbuf
,
scalar_t
*
oubuf
)
{
inbuf
+=
wid
*
blockIdx
.
x
;
inbuf
+=
wid
*
blockIdx
.
x
;
oubuf
+=
wid
*
pos
[
blockIdx
.
x
];
oubuf
+=
wid
*
pos
[
blockIdx
.
x
];
...
@@ -77,7 +77,7 @@ void moe_cuda_expert_count_impl(
...
@@ -77,7 +77,7 @@ void moe_cuda_expert_count_impl(
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
moe_cuda_local_scatter_impl
(
void
moe_cuda_local_scatter_impl
(
const
scalar_t
*
input
,
const
scalar_t
*
input
,
const
int
*
d_pos
,
const
long
*
d_pos
,
scalar_t
*
input_buf
,
scalar_t
*
input_buf
,
const
long
batch_size
,
const
long
batch_size
,
const
long
in_feat
,
const
long
in_feat
,
...
@@ -90,7 +90,7 @@ void moe_cuda_local_scatter_impl(
...
@@ -90,7 +90,7 @@ void moe_cuda_local_scatter_impl(
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
__global__
void
batch_gather_kernel
(
size_t
wid
,
const
int
*
pos
,
void
batch_gather_kernel
(
size_t
wid
,
const
long
*
pos
,
const
scalar_t
*
inbuf
,
scalar_t
*
oubuf
)
{
const
scalar_t
*
inbuf
,
scalar_t
*
oubuf
)
{
inbuf
+=
wid
*
pos
[
blockIdx
.
x
];
inbuf
+=
wid
*
pos
[
blockIdx
.
x
];
oubuf
+=
wid
*
blockIdx
.
x
;
oubuf
+=
wid
*
blockIdx
.
x
;
...
@@ -102,7 +102,7 @@ void batch_gather_kernel(size_t wid, const int* pos,
...
@@ -102,7 +102,7 @@ void batch_gather_kernel(size_t wid, const int* pos,
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
moe_cuda_local_gather_impl
(
void
moe_cuda_local_gather_impl
(
const
scalar_t
*
output_buf
,
const
scalar_t
*
output_buf
,
const
int
*
d_pos
,
const
long
*
d_pos
,
scalar_t
*
output
,
scalar_t
*
output
,
const
size_t
batch_size
,
const
size_t
batch_size
,
const
size_t
out_feat
,
const
size_t
out_feat
,
...
@@ -117,7 +117,7 @@ template <typename scalar_t>
...
@@ -117,7 +117,7 @@ template <typename scalar_t>
void
moe_cuda_forward_impl
(
void
moe_cuda_forward_impl
(
const
scalar_t
*
input_buf
,
const
scalar_t
*
input_buf
,
const
scalar_t
*
weight
,
const
scalar_t
*
weight
,
const
int
*
expert_count
,
const
long
*
expert_count
,
scalar_t
*
output_buf
,
scalar_t
*
output_buf
,
const
size_t
in_feat
,
const
size_t
in_feat
,
const
size_t
out_feat
,
const
size_t
out_feat
,
...
@@ -152,7 +152,7 @@ void moe_cuda_backward_impl(
...
@@ -152,7 +152,7 @@ void moe_cuda_backward_impl(
const
scalar_t
*
grad_output_buf
,
const
scalar_t
*
grad_output_buf
,
const
scalar_t
*
input_buf
,
const
scalar_t
*
input_buf
,
const
scalar_t
*
weight
,
const
scalar_t
*
weight
,
const
int
*
expert_count
,
const
long
*
expert_count
,
scalar_t
*
grad_input_buf
,
scalar_t
*
grad_input_buf
,
scalar_t
*
grad_weight
,
scalar_t
*
grad_weight
,
const
size_t
batch_size
,
const
size_t
batch_size
,
...
@@ -237,7 +237,7 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
...
@@ -237,7 +237,7 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
([
&
]
{
([
&
]
{
moe_cuda_local_scatter_impl
<
scalar_t
>
(
moe_cuda_local_scatter_impl
<
scalar_t
>
(
input
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
pos
.
data_ptr
<
int
>
(),
pos
.
data_ptr
<
long
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
batch_size
,
batch_size
,
in_feat
,
in_feat
,
...
@@ -259,7 +259,7 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
...
@@ -259,7 +259,7 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
([
&
]
{
([
&
]
{
moe_cuda_local_gather_impl
<
scalar_t
>
(
moe_cuda_local_gather_impl
<
scalar_t
>
(
output_buf
.
data_ptr
<
scalar_t
>
(),
output_buf
.
data_ptr
<
scalar_t
>
(),
pos
.
data_ptr
<
int
>
(),
pos
.
data_ptr
<
long
>
(),
output
.
data_ptr
<
scalar_t
>
(),
output
.
data_ptr
<
scalar_t
>
(),
batch_size
,
batch_size
,
out_feat
,
out_feat
,
...
@@ -293,7 +293,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
...
@@ -293,7 +293,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
moe_cuda_forward_impl
<
scalar_t
>
(
moe_cuda_forward_impl
<
scalar_t
>
(
input_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
expert_count
.
data_ptr
<
int
>
(),
expert_count
.
data_ptr
<
long
>
(),
output
.
data_ptr
<
scalar_t
>
(),
output
.
data_ptr
<
scalar_t
>
(),
in_feat
,
in_feat
,
out_feat
,
out_feat
,
...
@@ -331,7 +331,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
...
@@ -331,7 +331,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
grad_output_buf
.
data_ptr
<
scalar_t
>
(),
grad_output_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
input_buf
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
expert_count
.
data_ptr
<
int
>
(),
expert_count
.
data_ptr
<
long
>
(),
grad_input_buf
.
data_ptr
<
scalar_t
>
(),
grad_input_buf
.
data_ptr
<
scalar_t
>
(),
grad_weight
.
data_ptr
<
scalar_t
>
(),
grad_weight
.
data_ptr
<
scalar_t
>
(),
batch_size
,
batch_size
,
...
...
fmoe/moe_function.py
View file @
a526f438
...
@@ -6,12 +6,19 @@ import fmoe_cuda
...
@@ -6,12 +6,19 @@ import fmoe_cuda
class
MOELocal
(
Function
):
class
MOELocal
(
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
inp
,
gate
,
weight
):
def
forward
(
ctx
,
inp
,
gate
,
weight
):
expert_count
,
pos
=
fmoe_cuda
.
expert_count
(
gate
,
weight
.
shape
[
0
])
_
,
pos
=
torch
.
sort
(
gate
)
gate_idx
,
gate_count
=
torch
.
unique
(
gate
,
return_counts
=
True
)
expert_count
=
torch
.
zeros
(
weight
.
shape
[
0
],
device
=
weight
.
device
,
dtype
=
torch
.
long
)
expert_count
.
index_put_
((
gate_idx
.
long
(),
),
gate_count
)
# expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
ecc
=
expert_count
.
cpu
()
input_buf
,
=
fmoe_cuda
.
local_scatter
(
inp
,
pos
)
input_buf
,
=
fmoe_cuda
.
local_scatter
(
inp
,
pos
)
output_buf
,
=
fmoe_cuda
.
forward
(
input_buf
,
weight
,
e
xpert_count
)
output_buf
,
=
fmoe_cuda
.
forward
(
input_buf
,
weight
,
e
cc
)
output
=
fmoe_cuda
.
local_gather
(
output_buf
,
pos
)
output
=
fmoe_cuda
.
local_gather
(
output_buf
,
pos
)
variables
=
[
input_buf
,
gate
,
weight
,
e
xpert_count
,
pos
]
variables
=
[
input_buf
,
gate
,
weight
,
e
cc
,
pos
]
ctx
.
save_for_backward
(
*
variables
)
ctx
.
save_for_backward
(
*
variables
)
return
output
[
0
]
return
output
[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment