OpenDAS / FastMoE / Commits

Commit 52e57316
authored May 19, 2021 by Rick Ho

remove local scatter and gather in cuda

parent 414a2f86

Showing 3 changed files with 3 additions and 111 deletions:

    cuda/fmoe_cuda.cpp        +0  -8
    cuda/local_exchange.cu    +0  -50
    cuda/local_exchange.cuh   +3  -53
cuda/fmoe_cuda.cpp

@@ -22,12 +22,6 @@ void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
 #endif  // FMOE_USE_NCCL
 
 // local_exchange
-std::vector<torch::Tensor> _local_scatter(
-        torch::Tensor input,
-        torch::Tensor pos);
-std::vector<torch::Tensor> _local_gather(
-        torch::Tensor output_buf,
-        torch::Tensor pos);
 void _assign_pos(
         torch::Tensor cum_count,
         torch::Tensor gate,

@@ -60,8 +54,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
 #endif
-    m.def("local_scatter", &_local_scatter, "FastMoE local scatter (CUDA)");
-    m.def("local_gather", &_local_gather, "FastMoE local gather (CUDA)");
     m.def("assign_pos_", &_assign_pos, "FastMoE assign pos by gate(CUDA)");
     m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
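For context on the m.def lines kept above: they live inside the PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) block named in the second hunk header. A minimal, self-contained sketch of that registration pattern, where the function name _example_op / "example_op" is hypothetical and only illustrates the shape of such a binding:

// Minimal torch C++ extension sketch; names are illustrative, not FastMoE's.
#include <torch/extension.h>
#include <vector>

std::vector<torch::Tensor> _example_op(torch::Tensor input) {
    // Trivial body: hand the input back, wrapped in a vector.
    return {input};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("example_op", &_example_op, "example registration (illustration only)");
}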
cuda/local_exchange.cu

@@ -2,56 +2,6 @@
 #include "utils/fmoe_utils.h"
 #include <torch/extension.h>
 
-std::vector<torch::Tensor> _local_scatter(
-        torch::Tensor input,
-        torch::Tensor pos) {
-    auto smgr = getCudaStreamManager(input.device().index());
-    const auto batch_size = pos.size(0);
-    const auto in_feat = input.size(1);
-
-    auto opt = torch::TensorOptions()
-        .dtype(input.dtype())
-        .device(input.device());
-    auto input_buf = torch::empty({batch_size, in_feat}, opt);
-
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
-            "fmoe_local_scatter", ([&] {
-        fmoe_cuda_local_scatter_impl<scalar_t>(
-            input.data_ptr<scalar_t>(),
-            pos.data_ptr<long>(),
-            input_buf.data_ptr<scalar_t>(),
-            batch_size,
-            in_feat,
-            smgr);
-    }));
-    return {input_buf,};
-}
-
-std::vector<torch::Tensor> _local_gather(
-        torch::Tensor output_buf,
-        torch::Tensor pos) {
-    auto smgr = getCudaStreamManager(output_buf.device().index());
-    const auto batch_size = pos.size(0);
-    const auto out_feat = output_buf.size(1);
-
-    auto opt = torch::TensorOptions()
-        .dtype(output_buf.dtype())
-        .device(output_buf.device());
-    auto output = torch::empty({batch_size, out_feat}, opt);
-
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
-            "fmoe_local_gather", ([&] {
-        fmoe_cuda_local_gather_impl<scalar_t>(
-            output_buf.data_ptr<scalar_t>(),
-            pos.data_ptr<long>(),
-            output.data_ptr<scalar_t>(),
-            batch_size,
-            out_feat,
-            smgr);
-    }));
-    return {output,};
-}
-
 void _assign_pos(
         torch::Tensor cum_count,
         torch::Tensor gate,
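The removed wrappers reorder rows according to pos (see the kernels removed from cuda/local_exchange.cuh below): local scatter reads row pos[i] of the input into row i of a buffer, and local gather writes row i of the buffer back to row pos[i]. The commit does not show a replacement, but as a reference for those semantics, here is a minimal sketch using plain ATen indexing ops; the names local_scatter_ref and local_gather_ref are hypothetical:

// Reference sketch only; not code from FastMoE.
#include <torch/extension.h>

// Row i of the result is row pos[i] of `input` (matches batch_scatter_kernel).
torch::Tensor local_scatter_ref(torch::Tensor input, torch::Tensor pos) {
    return input.index_select(0, pos);
}

// Row i of `output_buf` is written back to row pos[i] (matches batch_gather_kernel).
torch::Tensor local_gather_ref(torch::Tensor output_buf, torch::Tensor pos) {
    auto output = torch::empty_like(output_buf);
    output.index_copy_(0, pos, output_buf);
    return output;
}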
cuda/local_exchange.cuh

@@ -2,56 +2,6 @@
 #include "utils/helper_cuda.h"
 #include "utils/fmoe_utils.h"
 
-template <typename scalar_t>
-__global__
-void batch_scatter_kernel(size_t wid, const long* pos,
-        const scalar_t* inbuf, scalar_t* oubuf) {
-    inbuf += wid * pos[blockIdx.x];
-    oubuf += wid * blockIdx.x;
-    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
-        oubuf[i] = inbuf[i];
-    }
-}
-
-template <typename scalar_t>
-void fmoe_cuda_local_scatter_impl(
-        const scalar_t* input,
-        const long* d_pos,
-        scalar_t* input_buf,
-        const long batch_size,
-        const long in_feat,
-        CudaStreamManager* smgr) {
-    batch_scatter_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr->stream(0)>>>(in_feat, d_pos, input,
-            input_buf);
-    smgr->sync(1);
-}
-
-template <typename scalar_t>
-__global__
-void batch_gather_kernel(size_t wid, const long* pos,
-        const scalar_t* inbuf, scalar_t* oubuf) {
-    inbuf += wid * blockIdx.x;
-    oubuf += wid * pos[blockIdx.x];
-    for (int i = threadIdx.x; i < wid; i += blockDim.x) {
-        oubuf[i] = inbuf[i];
-    }
-}
-
-template <typename scalar_t>
-void fmoe_cuda_local_gather_impl(
-        const scalar_t* output_buf,
-        const long* d_pos,
-        scalar_t* output,
-        const size_t batch_size,
-        const size_t out_feat,
-        CudaStreamManager* smgr) {
-    batch_gather_kernel<scalar_t>
-        <<<batch_size, 256, 0, smgr->stream(0)>>>(out_feat, d_pos, output_buf,
-            output);
-    smgr->sync(1);
-}
-
 __global__
 void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
         size_t numel, size_t topk) {
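The removed batch_scatter_kernel and batch_gather_kernel use a one-block-per-row copy pattern: block x handles one row of width wid, and its threads stride over the feature dimension. A standalone, runnable sketch of that pattern outside the torch extension; all names and sizes here are illustrative, not from the repo:

// Standalone sketch of the row-copy pattern (not FastMoE code).
#include <cstdio>
#include <cuda_runtime.h>

// Block x copies row pos[x] of `in` into row x of `out`; threads stride over
// the row of length `wid`, as in the removed batch_scatter_kernel.
__global__ void row_scatter(size_t wid, const long* pos,
                            const float* in, float* out) {
    in  += wid * pos[blockIdx.x];
    out += wid * blockIdx.x;
    for (int i = threadIdx.x; i < wid; i += blockDim.x)
        out[i] = in[i];
}

int main() {
    const int rows = 4, wid = 8;
    float h_in[rows * wid], h_out[rows * wid];
    long h_pos[rows] = {2, 0, 3, 1};   // source row for each destination row
    for (int i = 0; i < rows * wid; ++i) h_in[i] = (float)i;

    float *d_in, *d_out;
    long *d_pos;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMalloc(&d_pos, sizeof(h_pos));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    cudaMemcpy(d_pos, h_pos, sizeof(h_pos), cudaMemcpyHostToDevice);

    row_scatter<<<rows, 256>>>(wid, d_pos, d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

    for (int r = 0; r < rows; ++r)     // first element of each output row
        printf("out[%d][0] = %g\n", r, h_out[r * wid]);

    cudaFree(d_in); cudaFree(d_out); cudaFree(d_pos);
    return 0;
}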
@@ -60,7 +10,7 @@ void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
         long gate_idx = gate[idx];
         if (gate_idx > -1) {
             int p = atomicSub(cum_count + gate_idx, 1);
-            pos[p] = (long)idx;
+            pos[p - 1] = (long)idx;
         }
     }
 }
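The only behavioural change here is pos[p] becoming pos[p - 1]. atomicSub returns the value the counter held before the decrement, so with the new indexing each expert's slots are filled downward from cum_count[e] - 1, which matches a cum_count that starts as an inclusive prefix count; the counter setup is outside this diff, so treat that reading as an assumption. A small host-side sketch of that bookkeeping with made-up data:

// Host-side sketch with hypothetical data (not FastMoE code).
#include <cstdio>
#include <vector>

int main() {
    // 4 tokens, 2 experts, top-1 gating.
    std::vector<long> gate = {1, 0, 1, 0};   // expert chosen by each token
    std::vector<int> cum_count = {2, 4};     // assumed inclusive prefix counts: expert 0 ends at 2, expert 1 at 4
    std::vector<long> pos(gate.size());

    for (long idx = 0; idx < (long)gate.size(); ++idx) {
        long e = gate[idx];
        int p = cum_count[e]--;              // like atomicSub: yields the pre-decrement value
        pos[p - 1] = idx;                    // the new indexing: fill slot p - 1
    }
    // pos now lists token indices grouped by expert: 3 1 2 0
    for (long v : pos) printf("%ld ", v);
    printf("\n");
    return 0;
}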
@@ -71,7 +21,7 @@ void fmoe_cuda_assign_pos_impl(
         CudaStreamManager* smgr) {
     size_t numel = batch_size * topk;
     assign_pos_kernel
-        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>(cum_count, gate, pos,
-            numel, topk);
+        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
+        (cum_count, gate, pos, numel, topk);
     smgr->sync(1);
 }
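CEIL comes from utils/fmoe_utils.h, which this diff does not show; the usual round-up-division definition is assumed here:

// Assumed definition (not shown in this commit): enough blocks of `b` threads
// to cover `a` items, i.e. round-up integer division.
#define CEIL(a, b) (((a) + (b) - 1) / (b))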