OpenDAS / FastMoE · Commits

Commit 8052b3c0
Authored Oct 26, 2021 by Rick Ho

    cpu swipe

Parent: 4a9ef7fd

Showing 5 changed files with 223 additions and 29 deletions (+223 −29):
cuda/balancing.cu          +135  −0
cuda/fmoe_cuda.cpp           +4  −0
cuda/global_exchange.cpp    +29  −2
cuda/global_exchange.h       +8  −27
fmoe/gates/swipe_gate.py    +47  −0
cuda/balancing.cu
#include "balancing.cuh"
#include "global_exchange.h"
#include <torch/extension.h>
/*
...
...
@@ -35,3 +36,137 @@ torch::Tensor _prune_gate_by_capacity(
batch_size
,
n_expert
,
n_worker
,
smgr
);
return
new_gate_idx
;
}
/* Helpers for staging small bookkeeping arrays between host and device. */

template<class T>
T* _cudamalloc(size_t sz) {
    T* dptr;
    cudaMalloc(&dptr, sz * sizeof(T));
    return dptr;
}

/* Copy sz elements host -> device into an existing buffer. */
template<class T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
    cudaMemcpy(dptr, hptr, sz * sizeof(T), cudaMemcpyHostToDevice);
    return dptr;
}

/* Allocate a device buffer and copy sz elements into it. */
template<class T>
T* _h2d(T* hptr, size_t sz) {
    T* dptr = _cudamalloc<T>(sz);
    return _h2d(hptr, dptr, sz);
}

/* Copy sz elements device -> host into an existing buffer. */
template<class T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
    cudaMemcpy(hptr, dptr, sz * sizeof(T), cudaMemcpyDeviceToHost);
    return hptr;
}

/* Allocate a host buffer (caller frees with delete[]) and copy into it. */
template<class T>
T* _d2h(const T* dptr, size_t sz) {
    T* hptr = new T[sz];
    return _d2h(dptr, hptr, sz);
}
#ifdef FMOE_USE_NCCL
#include <nccl.h>
#define UPDATE_COUNTERS(__count__) { \
    if (i == rank) { \
        lec[j] += (__count__); \
    } \
    if (j == rank) { \
        gec[i] += (__count__); \
        cap -= (__count__); \
    } \
}
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity,
        long n_expert, long n_worker, long bias) {
    auto device_idx = gate_idx.device().index();
    auto smgr = getCudaStreamManager(device_idx);
    int rank;
    ncclCommUserRank(smgr->ncclcomm, &rank);
    cudaSetDevice(device_idx);
    auto cap = capacity.item<long>();
    long batch_size = gate_idx.size(0);
    auto gate_idx_cpu = gate_idx.cpu();
    long* gidx = gate_idx_cpu.data_ptr<long>();

    /* Local count and exchange: lec[j] counts this rank's samples routed
     * to worker j. gidx holds global expert indices, so the destination
     * worker is gidx[i] / n_expert. */
    long* lec = new long[n_worker];
    memset(lec, 0, n_worker * sizeof(long));
    for (long i = 0; i < batch_size; ++i) {
        ++lec[gidx[i] / n_expert];
    }
    long* d_lec = _h2d(lec, n_worker),
        * d_gec = _cudamalloc<long>(n_worker);
    fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
    long* gec = _d2h(d_gec, n_worker);

    /* Limit number of incoming samples by the local capacity; anything
     * beyond cap is recorded in drop_count. */
    long* drop_count = new long[n_worker];
    memset(drop_count, 0, n_worker * sizeof(long));
    for (long i = 0; i < n_worker; ++i) {
        if (cap >= gec[i]) {
            drop_count[i] = 0;
            cap -= gec[i];
        } else {
            drop_count[i] = gec[i] - cap;
            gec[i] = cap;
            cap = 0;
        }
    }

    /* Send limit information back */
    _h2d(gec, d_gec, n_worker);
    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
    _d2h(d_lec, lec, n_worker);

    /* Aggregate drop counts and gather every rank's leftover capacity */
    auto d_dropcount = _h2d(drop_count, n_worker);
    ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
            smgr->ncclcomm, smgr->stream());
    _d2h(d_dropcount, drop_count, n_worker);
    auto d_gcap = _cudamalloc<long>(n_worker);
    _h2d(&cap, d_gcap + rank, 1);
    ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
            smgr->ncclcomm, smgr->stream());
    auto gcap = _d2h(d_gcap, n_worker);

    /* Re-assign counts: walk the dropped samples of each worker i across
     * the spare capacity of workers j in a fixed order. Every rank runs
     * the identical walk, so lec and gec stay consistent without any
     * further communication. */
    for (long i = 0, j = 0; i < n_worker; ++i) {
        while (drop_count[i] > 0) {
            if (drop_count[i] > gcap[j]) {
                drop_count[i] -= gcap[j];
                UPDATE_COUNTERS(gcap[j]);
                ++j;
            } else {
                gcap[j] -= drop_count[i];
                UPDATE_COUNTERS(drop_count[i]);
                break;
            }
        }
    }

    /* Mark local samples beyond each destination's acceptance as dropped */
    for (long i = 0; i < batch_size; ++i) {
        auto widx = gidx[i] / n_expert;
        if (lec[widx] > 0) {
            --lec[widx];
        } else {
            gidx[i] = -1;
        }
    }
    /* Re-route dropped samples to the next worker with room, always onto
     * local expert `bias` of that worker */
    for (long i = 0, k = 0; i < batch_size; ++i) {
        if (gidx[i] != -1) {
            continue;
        }
        for (; lec[k] == 0; ++k);
        gidx[i] = k * n_expert + bias;
        --lec[k];
    }

    capacity.fill_(cap);  /* hand the leftover capacity back to the caller */
    delete [] lec;
    delete [] gec;
    delete [] drop_count;
    delete [] gcap;
    cudaFree(d_lec);
    cudaFree(d_gec);
    cudaFree(d_dropcount);
    cudaFree(d_gcap);
    return {gate_idx_cpu, capacity};
}
#undef UPDATE_COUNTERS
#endif
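The re-assignment walk above is the heart of SWIPE. As a sanity aid, here is a minimal single-process Python sketch of the same greedy walk that UPDATE_COUNTERS tracks (swipe_rebalance is a hypothetical helper for illustration, not part of FastMoE): dropped samples of worker i are pushed onto the spare capacity of workers j in a fixed order, so every rank that runs the walk computes identical routing updates without extra communication.

def swipe_rebalance(drop_count, spare_cap):
    """drop_count[i]: samples worker i could not place locally.
    spare_cap[j]:  leftover capacity on worker j.
    Returns extra[i][j]: samples re-routed from worker i to worker j.
    Assumes sum(spare_cap) >= sum(drop_count), which holds because the
    total capacity was sized for the whole batch."""
    n = len(drop_count)
    extra = [[0] * n for _ in range(n)]
    j = 0
    for i in range(n):
        while drop_count[i] > 0:
            if drop_count[i] > spare_cap[j]:
                # Worker j absorbs everything it can; advance to j + 1.
                extra[i][j] += spare_cap[j]
                drop_count[i] -= spare_cap[j]
                spare_cap[j] = 0
                j += 1
            else:
                # Worker j absorbs the remainder of worker i's drops.
                extra[i][j] += drop_count[i]
                spare_cap[j] -= drop_count[i]
                drop_count[i] = 0
    return extra

# Worker 0 drops 5 samples; workers 1 and 2 have 3 and 4 spare slots:
# worker 1 takes 3, worker 2 takes the remaining 2.
assert swipe_rebalance([5, 0, 0], [0, 3, 4])[0] == [0, 3, 2]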
cuda/fmoe_cuda.cpp
...

@@ -52,6 +52,9 @@ torch::Tensor _limit_by_capacity(
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker);
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity_tensor,
        long n_expert, long n_worker, long bias);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
...

@@ -59,6 +62,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("global_scatter", &_global_scatter, "FastMoE global scatter (CUDA)");
    m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
    m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
    m.def("swipe_once", &_swipe_once, "SWIPE balance strategy (CUDA)");
#endif
    m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
...
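Once the extension is built with FMOE_USE_NCCL, the swipe_once binding registered above is callable from Python. A hedged sketch of the call shape (sizes are illustrative, and the process must already hold an NCCL communicator, e.g. via ensure_nccl):

import torch
import fmoe_cuda as fmoe_native

n_expert, n_worker, batch_size = 4, 2, 1024
# One top-k column of routing decisions: a global expert index per token.
gate_idx = torch.randint(0, n_expert * n_worker, (batch_size,), device="cuda")
capacity = torch.scalar_tensor(batch_size, dtype=torch.long)
# Rebalance toward local expert 0 on each worker (bias = 0).
new_idx, capacity = fmoe_native.swipe_once(
    gate_idx, capacity, n_expert, n_worker, 0)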
cuda/global_exchange.cpp
...

@@ -5,6 +5,33 @@
#ifdef FMOE_USE_NCCL
#include <nccl.h>

/* All-to-all exchange of expert counts: rank i sends each rank j its
 * n_expert counters for j's experts and receives the matching block. */
void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr) {
    NCCL_SAFE_CALL(ncclGroupStart());
    for (int i = 0; i < world_size; ++i) {
        NCCL_SAFE_CALL(ncclSend(
                local_expert_count + n_expert * i,
                n_expert, ncclInt64, i,
                smgr->ncclcomm, smgr->stream(0)));
        NCCL_SAFE_CALL(ncclRecv(
                global_expert_count + n_expert * i,
                n_expert, ncclInt64, i,
                smgr->ncclcomm, smgr->stream(0)));
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
    smgr->sync(1);
}
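The paired ncclSend/ncclRecv loop inside one NCCL group is effectively an all-to-all over the count matrix. For intuition, a conceptually equivalent torch.distributed formulation (an illustrative sketch, not the code path FastMoE uses):

import torch
import torch.distributed as dist

def expert_exchange(local_expert_count: torch.Tensor) -> torch.Tensor:
    # local_expert_count has world_size * n_expert entries; chunk i goes
    # to rank i, and chunk i of the result arrives from rank i, exactly
    # as in fmoe_cuda_expert_exchange_impl.
    global_expert_count = torch.empty_like(local_expert_count)
    dist.all_to_all_single(global_expert_count, local_expert_count)
    return global_expert_count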
torch::Tensor _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers) {
...
@@ -31,7 +58,7 @@ torch::Tensor _global_scatter(
    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
    auto smgr = getCudaStreamManager(input_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_cuda_global_scatter", ([&] {
        fmoe_cuda_global_scatter_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
...
@@ -57,7 +84,7 @@ torch::Tensor _global_gather(
    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
    auto smgr = getCudaStreamManager(output_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
            "fmoe_cuda_global_gather", ([&] {
        fmoe_cuda_global_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
...
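Where the count exchange moves integers, global scatter and gather move the feature rows themselves, with send/recv sizes taken from the previously exchanged counts. A rough torch.distributed sketch of the scatter direction (illustrative only; it glosses over the per-expert row ordering that the CUDA implementation maintains):

import torch
import torch.distributed as dist

def global_scatter_rows(local_input_buf, local_expert_count,
                        global_expert_count, n_expert, world_size):
    # Rows in local_input_buf are grouped by destination; rank j receives
    # the rows routed to its experts. Split sizes come from the counts,
    # laid out worker-major as idx = i + j * n_expert.
    send_sizes = local_expert_count.view(world_size, n_expert).sum(1)
    recv_sizes = global_expert_count.view(world_size, n_expert).sum(1)
    out = local_input_buf.new_empty(
        (int(recv_sizes.sum()), local_input_buf.shape[1]))
    dist.all_to_all_single(out, local_input_buf,
                           output_split_sizes=recv_sizes.tolist(),
                           input_split_sizes=send_sizes.tolist())
    return out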
cuda/global_exchange.h
...

@@ -2,30 +2,11 @@
#ifdef FMOE_USE_NCCL
/* The inline definition that used to live here was moved to
 * global_exchange.cpp in this commit; only the declaration remains. */
void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr);

template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
...
@@ -50,9 +31,9 @@ void fmoe_cuda_global_scatter_impl(
        int idx = i + j * n_expert;
        if (local_expert_count[idx]) {
            NCCL_SAFE_CALL(ncclSend(
                    local_input_buf + expert_ptr[idx] * in_feat,
                    local_expert_count[idx] * in_feat * sizeof(scalar_t),
                    ncclChar, j,
                    smgr->ncclcomm, smgr->stream(0)));
...
@@ -106,9 +87,9 @@ void fmoe_cuda_global_gather_impl(
        }
        if (local_expert_count[idx]) {
            NCCL_SAFE_CALL(ncclRecv(
                    local_output_buf + expert_ptr[idx] * out_feat,
                    local_expert_count[idx] * out_feat * sizeof(scalar_t),
                    ncclChar, j,
                    smgr->ncclcomm, smgr->stream(0)));
...
fmoe/gates/swipe_gate.py (new file, 0 → 100644)
r"""
Balanced gate using the SWIPE algorithm
"""
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from .naive_gate import NaiveGate
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native


class SwipeGate(NaiveGate):
    requires_moe_group = True

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__(d_model, num_expert, world_size, top_k)

    def swipe_once(self, idx, capacity, bias):
        # bias selects which local expert on a worker absorbs re-routed
        # samples; it is forwarded to the native _swipe_once kernel.
        with torch.no_grad():
            idx_new, capacity = fmoe_native.swipe_once(
                idx, capacity, self.num_expert, self.world_size, bias)
            idx_new = idx_new.to(idx.device)
        return idx_new, capacity

    def forward(self, inp):
        score = self.gate(inp)
        topk_val, orig_idx = torch.topk(score, k=self.top_k, dim=-1)
        if not self.training:
            # No rebalancing at inference time: keep the raw top-k routing.
            topk_val = F.softmax(topk_val, dim=-1)
            return orig_idx, topk_val

        capacity = torch.scalar_tensor(inp.shape[0] * self.top_k,
                                       dtype=torch.long)
        topk_idxs = []
        for k in range(self.top_k):
            # Rebalance one top-k column at a time, threading the leftover
            # capacity through successive swipe_once calls.
            idx, capacity = self.swipe_once(orig_idx[:, k], capacity,
                                            k % self.num_expert)
            topk_idxs.append(idx)
        topk_idx = torch.stack(topk_idxs).transpose(0, 1)
        # Re-read the gate scores at the (possibly re-routed) indices.
        idx_x = torch.arange(inp.shape[0],
                             device=inp.device).repeat_interleave(self.top_k)
        topk_val = score[idx_x, topk_idx.reshape(-1)].view(-1, self.top_k)
        topk_val = F.softmax(topk_val, dim=-1)
        return topk_idx, topk_val
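A hedged usage sketch, assuming FMoE's usual constructor (which accepts a gate class) and a distributed launch, since SwipeGate's native swipe_once needs an initialized NCCL communicator:

import torch
import torch.distributed as dist
from fmoe import FMoE
from fmoe.gates.swipe_gate import SwipeGate

dist.init_process_group("nccl")        # e.g. under torchrun
torch.cuda.set_device(dist.get_rank())

moe = FMoE(num_expert=4, d_model=512,
           world_size=dist.get_world_size(),
           top_k=2, gate=SwipeGate).cuda()
x = torch.randn(16, 512, device="cuda")
y = moe(x)  # routing is rebalanced once per top-k column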