Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
d2039fc7
Commit
d2039fc7
authored
Feb 01, 2021
by
Rick Ho
Browse files
swap local scatter and gather kernel functions
parent
14c0eab4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
8 deletions
+8
-8
cuda/moe_compute_kernel.cu
cuda/moe_compute_kernel.cu
+4
-4
fmoe/fmoe_functions.py
fmoe/fmoe_functions.py
+4
-4
No files found.
cuda/moe_compute_kernel.cu
View file @
d2039fc7
...
@@ -31,8 +31,8 @@ template <typename scalar_t>
...
@@ -31,8 +31,8 @@ template <typename scalar_t>
__global__
__global__
void
batch_scatter_kernel
(
size_t
wid
,
const
long
*
pos
,
void
batch_scatter_kernel
(
size_t
wid
,
const
long
*
pos
,
const
scalar_t
*
inbuf
,
scalar_t
*
oubuf
)
{
const
scalar_t
*
inbuf
,
scalar_t
*
oubuf
)
{
inbuf
+=
wid
*
blockIdx
.
x
;
inbuf
+=
wid
*
pos
[
blockIdx
.
x
]
;
oubuf
+=
wid
*
pos
[
blockIdx
.
x
]
;
oubuf
+=
wid
*
blockIdx
.
x
;
for
(
int
i
=
threadIdx
.
x
;
i
<
wid
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
wid
;
i
+=
blockDim
.
x
)
{
oubuf
[
i
]
=
inbuf
[
i
];
oubuf
[
i
]
=
inbuf
[
i
];
}
}
...
@@ -92,8 +92,8 @@ template <typename scalar_t>
...
@@ -92,8 +92,8 @@ template <typename scalar_t>
__global__
__global__
void
batch_gather_kernel
(
size_t
wid
,
const
long
*
pos
,
void
batch_gather_kernel
(
size_t
wid
,
const
long
*
pos
,
const
scalar_t
*
inbuf
,
scalar_t
*
oubuf
)
{
const
scalar_t
*
inbuf
,
scalar_t
*
oubuf
)
{
inbuf
+=
wid
*
pos
[
blockIdx
.
x
]
;
inbuf
+=
wid
*
blockIdx
.
x
;
oubuf
+=
wid
*
blockIdx
.
x
;
oubuf
+=
wid
*
pos
[
blockIdx
.
x
]
;
for
(
int
i
=
threadIdx
.
x
;
i
<
wid
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
wid
;
i
+=
blockDim
.
x
)
{
oubuf
[
i
]
=
inbuf
[
i
];
oubuf
[
i
]
=
inbuf
[
i
];
}
}
...
...
fmoe/fmoe_functions.py
View file @
d2039fc7
...
@@ -29,7 +29,7 @@ class MOEScatter(Function):
...
@@ -29,7 +29,7 @@ class MOEScatter(Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
inp
,
pos
,
local_expert_count
,
global_expert_count
,
def
forward
(
ctx
,
inp
,
pos
,
local_expert_count
,
global_expert_count
,
fwd_batch_size
,
world_size
):
fwd_batch_size
,
world_size
):
local_input_buf
,
=
fmoe_cuda
.
local_
g
at
h
er
(
inp
,
pos
)
local_input_buf
,
=
fmoe_cuda
.
local_
sc
at
t
er
(
inp
,
pos
)
if
world_size
>
1
:
if
world_size
>
1
:
global_input_buf
,
=
fmoe_cuda
.
global_scatter
(
local_input_buf
,
global_input_buf
,
=
fmoe_cuda
.
global_scatter
(
local_input_buf
,
local_expert_count
,
global_expert_count
,
local_expert_count
,
global_expert_count
,
...
@@ -52,7 +52,7 @@ class MOEScatter(Function):
...
@@ -52,7 +52,7 @@ class MOEScatter(Function):
local_batch_size
,
world_size
)
local_batch_size
,
world_size
)
else
:
else
:
local_grad_in
=
global_grad_in
local_grad_in
=
global_grad_in
grad_in
,
=
fmoe_cuda
.
local_
sc
at
t
er
(
local_grad_in
,
pos
)
grad_in
,
=
fmoe_cuda
.
local_
g
at
h
er
(
local_grad_in
,
pos
)
return
grad_in
,
None
,
None
,
None
,
None
,
None
return
grad_in
,
None
,
None
,
None
,
None
,
None
...
@@ -83,7 +83,7 @@ class MOEGather(Function):
...
@@ -83,7 +83,7 @@ class MOEGather(Function):
local_batch_size
,
world_size
)
local_batch_size
,
world_size
)
else
:
else
:
local_output_buf
=
global_output_buf
local_output_buf
=
global_output_buf
output
,
=
fmoe_cuda
.
local_
sc
at
t
er
(
local_output_buf
,
pos
)
output
,
=
fmoe_cuda
.
local_
g
at
h
er
(
local_output_buf
,
pos
)
ctx
.
moe_args
=
local_batch_size
,
global_output_buf
.
shape
[
0
],
world_size
ctx
.
moe_args
=
local_batch_size
,
global_output_buf
.
shape
[
0
],
world_size
variables
=
(
pos
,
local_expert_count
,
global_expert_count
)
variables
=
(
pos
,
local_expert_count
,
global_expert_count
)
...
@@ -94,7 +94,7 @@ class MOEGather(Function):
...
@@ -94,7 +94,7 @@ class MOEGather(Function):
def
backward
(
ctx
,
grad_out
):
def
backward
(
ctx
,
grad_out
):
pos
,
local_expert_count
,
global_expert_count
=
ctx
.
saved_tensors
pos
,
local_expert_count
,
global_expert_count
=
ctx
.
saved_tensors
local_batch_size
,
fwd_batch_size
,
world_size
=
ctx
.
moe_args
local_batch_size
,
fwd_batch_size
,
world_size
=
ctx
.
moe_args
grad_out_buf
,
=
fmoe_cuda
.
local_
g
at
h
er
(
grad_out
.
contiguous
(),
pos
)
grad_out_buf
,
=
fmoe_cuda
.
local_
sc
at
t
er
(
grad_out
.
contiguous
(),
pos
)
if
world_size
>
1
:
if
world_size
>
1
:
global_grad_out_buf
,
=
fmoe_cuda
.
global_scatter
(
grad_out_buf
,
global_grad_out_buf
,
=
fmoe_cuda
.
global_scatter
(
grad_out_buf
,
local_expert_count
,
global_expert_count
,
local_expert_count
,
global_expert_count
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment