OpenDAS / FastMoE · Commits

Commit 980cf4b6
Authored May 18, 2021 by Rick Ho
use customized pos assignment kernel to support -1
Parent: a468db2b
Showing 6 changed files with 63 additions and 79 deletions (+63, -79)
cuda/fmoe_cuda.cpp        +5  -4
cuda/local_exchange.cu    +17 -22
cuda/local_exchange.cuh   +25 -46
fmoe/functions.py         +14 -5
fmoe/gates/naive_gate.py  +0  -1
fmoe/gates/utils.py       +2  -1
cuda/fmoe_cuda.cpp
@@ -22,15 +22,16 @@ void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
#endif  // FMOE_USE_NCCL

// local_exchange
std::vector<torch::Tensor> _expert_count(torch::Tensor gate, size_t num_expert);
std::vector<torch::Tensor> _local_scatter(torch::Tensor input, torch::Tensor pos);
std::vector<torch::Tensor> _local_gather(torch::Tensor output_buf, torch::Tensor pos);
void _assign_pos(torch::Tensor cum_count, torch::Tensor gate, torch::Tensor pos);

// parallel_linear
std::vector<torch::Tensor> _linear_forward(
...
@@ -59,9 +60,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
#endif
    m.def("expert_count", &_expert_count, "FastMoE expert count (CUDA)");
    m.def("local_scatter", &_local_scatter, "FastMoE local scatter (CUDA)");
    m.def("local_gather", &_local_gather, "FastMoE local gather (CUDA)");
    m.def("assign_pos_", &_assign_pos, "FastMoE assign pos by gate(CUDA)");
    m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
    m.def("linear_backward", &_linear_backward, "FastMoE backward (CUDA)");
...
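For reference, the new `assign_pos_` binding is exercised from Python in `fmoe/functions.py` below. A minimal sketch of the expected call contract, with illustrative tensor values (the cumulative counts are int32 and appear to be consumed in place by the kernel, while `gate` and `pos` are int64):

```python
import torch
import fmoe_cuda  # extension module built from this repository

# Illustrative gate: one expert id per (token, k) slot, -1 meaning "drop".
gate = torch.tensor([0, -1, 1, 0], dtype=torch.long, device="cuda")
num_expert = 2

# Per-expert counts over the non-dropped entries, then inclusive prefix sums.
eff_gate = gate[gate != -1]
local_expert_count = torch.zeros(num_expert, dtype=torch.long, device="cuda")
local_expert_count.index_add_(0, eff_gate, torch.ones_like(eff_gate))
lec_cum = torch.cumsum(local_expert_count, dim=0).int()

# One pos slot per kept entry; assign_pos_ fills it on the GPU.
pos = torch.empty((lec_cum[-1].item(),), dtype=torch.long, device="cuda")
fmoe_cuda.assign_pos_(lec_cum, gate, pos)
```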
cuda/local_exchange.cu
@@ -2,28 +2,6 @@
#include "utils/fmoe_utils.h"
#include <torch/extension.h>

std::vector<torch::Tensor> _expert_count(
        torch::Tensor gate, size_t num_expert) {
    const auto batch_size = gate.size(0);
    auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
    auto expert_count = torch::empty(num_expert, ec_options);

    auto pos_options = torch::TensorOptions()
        .device(gate.device())
        .dtype(torch::kInt32);
    auto pos = torch::empty(batch_size, pos_options);
    fmoe_cuda_expert_count_impl(
            gate.data_ptr<int>(),
            expert_count.data_ptr<int>(),
            pos.data_ptr<int>(),
            num_expert,
            batch_size);

    return {expert_count, pos};
}

std::vector<torch::Tensor> _local_scatter(
        torch::Tensor input,
        torch::Tensor pos) {
...
@@ -73,3 +51,20 @@ std::vector<torch::Tensor> _local_gather(
    }));
    return {output,};
}

void _assign_pos(
        torch::Tensor cum_count,
        torch::Tensor gate,
        torch::Tensor pos) {
    auto smgr = getCudaStreamManager(cum_count.device().index());
    auto gate_shp = gate.sizes();
    size_t batch_size = gate_shp[0], topk = 1;
    if (gate_shp.size() == 2) {
        topk = gate_shp[1];
    }
    fmoe_cuda_assign_pos_impl(
            cum_count.data_ptr<int>(),
            gate.data_ptr<long>(),
            pos.data_ptr<long>(),
            batch_size, topk, smgr);
}
cuda/local_exchange.cuh
#include "stream_manager.h"
#include "utils/helper_cuda.h"
template
<
typename
scalar_t
>
__global__
void
generate_ptr_offset_kernel
(
size_t
n
,
const
scalar_t
*
base
,
size_t
stride
,
const
long
*
offset
,
const
scalar_t
**
ptrs
)
{
size_t
idx
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
if
(
idx
<
n
)
{
ptrs
[
idx
]
=
base
+
stride
*
offset
[
idx
];
}
}
#include "utils/fmoe_utils.h"
template
<
typename
scalar_t
>
__global__
...
...
@@ -22,42 +13,6 @@ void batch_scatter_kernel(size_t wid, const long* pos,
    }
}

void fmoe_cuda_expert_count_impl(
        const int* d_gate,
        int* expert_count,
        int* d_pos,
        const size_t num_expert,
        const size_t batch_size) {
    int *gate = new int[batch_size];
    int *expert_ptr = new int[num_expert];
    memset(expert_count, 0, sizeof(int) * num_expert);

    checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
                cudaMemcpyDeviceToHost));

    for (int i = 0; i < batch_size; ++i) {
        ++expert_count[gate[i]];
    }
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
    }

    int *pos = new int[batch_size];

    for (int i = 0; i < batch_size; ++i) {
        pos[i] = expert_ptr[gate[i]]++;
    }
    for (int i = num_expert - 1; i > 0; --i) {
        expert_ptr[i] = expert_ptr[i - 1];
    }
    expert_ptr[0] = 0;
    checkCudaErrors(cudaMemcpy(d_pos, pos, sizeof(int) * batch_size,
                cudaMemcpyHostToDevice));
    delete [] gate;
    delete [] expert_ptr;
}

template <typename scalar_t>
void fmoe_cuda_local_scatter_impl(
        const scalar_t* input,
...
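The removed `fmoe_cuda_expert_count_impl` above did this work on the host: it copied the gate to CPU memory, counted tokens per expert, built prefix sums, and copied per-token offsets back to the device. Its output convention also differs from the new kernel: it filled `pos[i]` with token `i`'s destination slot, whereas `assign_pos_kernel` below writes token indices into `pos` at the claimed slots. A rough Python sketch of the removed host-side logic, for contrast (illustrative only, with no -1 handling, which is what this commit adds):

```python
def expert_count_host(gate, num_expert):
    """Sequential sketch of the removed host-side path (no -1 support)."""
    expert_count = [0] * num_expert
    for g in gate:                      # count tokens per expert
        expert_count[g] += 1

    expert_ptr = [0] * num_expert       # exclusive prefix sums: segment starts
    for e in range(1, num_expert):
        expert_ptr[e] = expert_ptr[e - 1] + expert_count[e - 1]

    pos = [0] * len(gate)
    for i, g in enumerate(gate):        # pos[token] = slot in expert-major order
        pos[i] = expert_ptr[g]
        expert_ptr[g] += 1
    return expert_count, pos

# gate = [0, 1, 0] -> expert_count [2, 1], pos [0, 2, 1]
```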
@@ -96,3 +51,27 @@ void fmoe_cuda_local_gather_impl(
        output);
    smgr->sync(1);
}

__global__
void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
        size_t numel, size_t topk) {
    size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < numel) {
        long gate_idx = gate[idx];
        if (gate_idx > -1) {
            int p = atomicSub(cum_count + gate_idx, 1);
            pos[p] = (long)idx;
        }
    }
}

void fmoe_cuda_assign_pos_impl(
        int* cum_count, const long* gate, long* pos,
        const size_t batch_size, const size_t topk,
        CudaStreamManager* smgr) {
    size_t numel = batch_size * topk;
    assign_pos_kernel
        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
        (cum_count, gate, pos, numel, topk);
    smgr->sync(1);
}
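The replacement is a single GPU pass: each thread handles one flattened (token, k) slot, skips it if its gate value is -1, and claims a slot in `pos` by atomically decrementing its expert's running cumulative count. A sequential Python sketch of the intended result, assuming `cum_count` holds inclusive prefix sums over the non-dropped entries (this sketches the outcome, not a line-by-line translation of the kernel):

```python
def assign_pos_reference(cum_count, gate):
    """pos[slot] = flattened index occupying that slot, expert-major order.

    cum_count is consumed destructively, mirroring the atomicSub in the
    kernel; entries with gate == -1 are simply skipped.
    """
    cum_count = list(cum_count)
    pos = [None] * cum_count[-1]
    for idx, gate_idx in enumerate(gate):
        if gate_idx > -1:
            cum_count[gate_idx] -= 1          # claim one slot for this expert
            pos[cum_count[gate_idx]] = idx
    return pos

# gate = [0, -1, 1, 0], counts [2, 1], cum_count [2, 3]
# -> expert 0 owns slots 0-1, expert 1 owns slot 2; index 1 is dropped.
print(assign_pos_reference([2, 3], [0, -1, 1, 0]))  # [3, 0, 2]
```

On the GPU, the order of indices within one expert's segment depends on which threads win the atomic decrements, so only the grouping by expert is guaranteed.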
fmoe/functions.py
@@ -16,15 +16,17 @@ def _ensure_nccl(t, comm=None):
    fmoe_cuda.ensure_nccl(comm, t)


def count_by_gate(gate, num_expert, world_size):
# TODO: support -1 in gate, which means ignore this input
def count_by_gate(gate, num_expert, world_size, require_pos=True):
    with torch.no_grad():
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        flatten_gate = gate.view(-1)
        eff_gate = flatten_gate[flatten_gate != -1]
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        local_expert_count.index_put_((gate_idx.long(),), gate_count)
        ones = torch.ones(eff_gate.numel(), device=gate.device, dtype=torch.long)
        local_expert_count.index_add_(0, eff_gate, ones)

        if world_size > 1:
            _ensure_nccl(gate)
...
@@ -33,6 +35,13 @@ def count_by_gate(gate, num_expert, world_size):
            )
        else:
            global_expert_count = local_expert_count

        if not require_pos:
            pos = None
        else:
            lec_cum = torch.cumsum(local_expert_count, dim=0).int()
            pos_size = lec_cum[-1].item()
            pos = torch.empty((pos_size,), device=gate.device, dtype=torch.long)
            fmoe_cuda.assign_pos_(lec_cum, gate, pos)
    return pos, local_expert_count, global_expert_count
...
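With this change, dropped slots can be flagged with -1 directly in `gate`, and callers that only need the counts can skip position assignment via the new `require_pos` flag. A small, hypothetical single-process example of the updated interface:

```python
import torch
from fmoe.functions import count_by_gate

# Four (token, k) slots routed across 2 experts; the second slot is dropped.
gate = torch.tensor([0, -1, 1, 0], dtype=torch.long, device="cuda")

pos, lec, gec = count_by_gate(gate, num_expert=2, world_size=1)
# lec == gec == tensor([2, 1]): the -1 entry is excluded from the counts,
# and pos has 3 entries (one per kept slot), grouped by expert.

# Counts only, without launching assign_pos_:
none_pos, lec_only, _ = count_by_gate(gate, num_expert=2, world_size=1,
                                      require_pos=False)
assert none_pos is None
```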
fmoe/gates/naive_gate.py
@@ -36,7 +36,6 @@ class NaiveGate(BaseGate):
        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)

        if return_all_scores:
            return gate_top_k_idx, gate_top_k_val, gate
...
fmoe/gates/utils.py
@@ -10,7 +10,8 @@ def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
    capacity = torch.ones(num_expert, dtype=torch.int32,
            device=topk_idx.device) * capacity
    pos, lec, gec = count_by_gate(topk_idx.reshape(-1), num_expert, world_size)
    pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
            require_pos=False)
    new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
            num_expert, world_size)
    if world_size > 1:
...
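`limit_by_capacity` only needs the expert counts at this point, so calling `count_by_gate` with the new `require_pos=False` avoids launching the position-assignment kernel for a result that would be discarded. The actual capping is done by `fmoe_native.limit_by_capacity`, which is not part of this diff; conceptually, for a single worker, it amounts to clamping each expert's count to its capacity, roughly:

```python
import torch

def cap_counts(expert_count, capacity):
    """Illustrative sketch only: clamp per-expert counts to capacity.
    The real fmoe_native.limit_by_capacity also handles the multi-worker
    layout, which is not reproduced here."""
    return torch.minimum(expert_count, capacity.to(expert_count.dtype))

# counts [5, 1, 3] with a capacity of 2 per expert -> [2, 1, 2]
print(cap_counts(torch.tensor([5, 1, 3]),
                 torch.ones(3, dtype=torch.int32) * 2))
```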