OpenDAS / FastMoE · Commits · 3397bc19

Unverified commit 3397bc19, authored Oct 26, 2021 by Rick Ho, committed by GitHub on Oct 26, 2021.

Merge pull request #84 from laekov/swipe

SWIPE balance strategy
Parents: 4a9ef7fd, 206f267e
Showing 13 changed files with 352 additions and 46 deletions (+352, -46):
cuda/balancing.cu          +149   -0
cuda/fmoe_cuda.cpp           +4   -0
cuda/global_exchange.cpp    +29   -2
cuda/global_exchange.h       +8  -27
fmoe/gates/__init__.py       +2   -0
fmoe/gates/base_gate.py      +4   -0
fmoe/gates/swipe_gate.py    +49   -0
fmoe/megatron/balance.py     +6   -3
fmoe/megatron/layers.py      +3   -0
fmoe/megatron/patch.py       +7   -3
tests/test_ddp.py           +11   -0
tests/test_gates.py          +1  -11
tests/test_swipe.py         +79   -0
cuda/balancing.cu
#include <cstdio>
#include "balancing.cuh"
#include "global_exchange.h"
#include <torch/extension.h>

/*
...
@@ -35,3 +37,150 @@ torch::Tensor _prune_gate_by_capacity(
            batch_size, n_expert, n_worker, smgr);
    return new_gate_idx;
}
template<class T>
T* _cudamalloc(size_t sz) {
    T* dptr;
    cudaMalloc(&dptr, sz * sizeof(T));
    return dptr;
}

template<class T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
    cudaMemcpy(dptr, hptr, sz * sizeof(T), cudaMemcpyHostToDevice);
    return dptr;
}

template<class T>
T* _h2d(T* hptr, size_t sz) {
    T* dptr = _cudamalloc<T>(sz);
    return _h2d(hptr, dptr, sz);
}

template<class T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
    cudaMemcpy(hptr, dptr, sz * sizeof(T), cudaMemcpyDeviceToHost);
    return hptr;
}

template<class T>
T* _d2h(const T* dptr, size_t sz) {
    T* hptr = new T[sz];
    return _d2h(dptr, hptr, sz);
}
#ifdef FMOE_USE_NCCL
#include <nccl.h>
#define UPDATE_COUNTERS(__count__) { \
if (i == rank) { \
lec[j] += (__count__); \
} \
if (j == rank) { \
gec[i] += (__count__); \
cap -= (__count__); \
} \
}
std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity,
        long n_expert, long n_worker, long bias) {
    auto device_idx = gate_idx.device().index();
    auto smgr = getCudaStreamManager(device_idx);
    int rank;
    ncclCommUserRank(smgr->ncclcomm, &rank);
    cudaSetDevice(device_idx);

    auto capacity_new = capacity.clone();
    auto cap = capacity_new.item<long>();

    long batch_size = gate_idx.size(0);
    auto gate_idx_cpu = gate_idx.cpu();
    long* gidx = gate_idx_cpu.data_ptr<long>();

    /* Local count and exchange */
    long* lec = new long[n_worker];
    memset(lec, 0, n_worker * sizeof(long));
    for (long i = 0; i < batch_size; ++i) {
        ++lec[gidx[i] / n_expert];
    }
    long* d_lec = _h2d(lec, n_worker),
        * d_gec = _cudamalloc<long>(n_worker);
    fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
    long* gec = _d2h(d_gec, n_worker);

    /* Limit number of incoming samples */
    long* drop_count = new long[n_worker];
    memset(drop_count, 0, n_worker * sizeof(long));
    for (long i = 0; i < n_worker; ++i) {
        if (cap >= gec[i]) {
            drop_count[i] = 0;
            cap -= gec[i];
        } else {
            drop_count[i] = gec[i] - cap;
            gec[i] = cap;
            cap = 0;
        }
    }

    /* Send limit information back */
    _h2d(gec, d_gec, n_worker);
    fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
    _d2h(d_lec, lec, n_worker);

    auto d_dropcount = _h2d(drop_count, n_worker);
    ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
            smgr->ncclcomm, smgr->stream());
    _d2h(d_dropcount, drop_count, n_worker);

    auto d_gcap = _cudamalloc<long>(n_worker);
    _h2d(&cap, d_gcap + rank, 1);
    ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
            smgr->ncclcomm, smgr->stream());
    auto gcap = _d2h(d_gcap, n_worker);

    /* Re-assign and update counters */
    for (long i = 0, j = 0; i < n_worker; ++i) {
        while (drop_count[i] > 0) {
            if (drop_count[i] > gcap[j]) {
                drop_count[i] -= gcap[j];
                UPDATE_COUNTERS(gcap[j]);
                ++j;
            } else {
                gcap[j] -= drop_count[i];
                UPDATE_COUNTERS(drop_count[i]);
                break;
            }
        }
    }
    for (long i = 0; i < batch_size; ++i) {
        auto widx = gidx[i] / n_expert;
        if (lec[widx] > 0) {
            --lec[widx];
        } else {
            gidx[i] = -1;
        }
    }
    for (long i = 0, k = 0; i < batch_size; ++i) {
        if (gidx[i] != -1) {
            continue;
        }
        for (; lec[k] == 0; ++k);
        --lec[k];
        gidx[i] = k * n_expert + bias;
    }
    *capacity_new.data_ptr<long>() = cap;

    delete [] drop_count;
    delete [] lec;
    delete [] gec;
    delete [] gcap;
    cudaFree(d_dropcount);
    cudaFree(d_lec);
    cudaFree(d_gec);
    cudaFree(d_gcap);
    return {gate_idx_cpu, capacity_new};
}
#undef UPDATE_COUNTERS
#endif
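For orientation, below is a minimal single-process Python sketch of the round that `_swipe_once` implements. It is illustrative only: the function name `swipe_round` is invented, all workers are simulated in one process, the NCCL count exchange, all-reduce and all-gather are replaced by plain list manipulation, and the total capacity is assumed to cover every sample.

```python
def swipe_round(gate_idx, capacity, n_expert, n_worker, bias=0):
    """Single-process sketch of one SWIPE round (cf. _swipe_once above).

    gate_idx[w] lists the expert indices chosen by worker w's samples,
    each in [0, n_expert * n_worker); capacity[w] is worker w's quota.
    """
    # 1. lec[s][r]: how many of sender s's samples target receiver r.
    #    In the CUDA code each worker only knows its own row and learns its
    #    incoming column (gec) through fmoe_cuda_expert_exchange_impl.
    lec = [[0] * n_worker for _ in range(n_worker)]
    for s in range(n_worker):
        for e in gate_idx[s]:
            lec[s][e // n_expert] += 1

    # 2. Each receiver admits incoming samples sender by sender until its
    #    capacity runs out; the overflow per sender is the drop count.
    cap = list(capacity)
    drop = [0] * n_worker                     # total dropped samples per sender
    for r in range(n_worker):
        for s in range(n_worker):
            take = min(cap[r], lec[s][r])
            drop[s] += lec[s][r] - take
            lec[s][r] = take                  # sender s learns how many r accepted
            cap[r] -= take

    # 3. Spare capacities play the role of gcap: dropped samples are re-routed
    #    greedily to the workers that still have room (cf. UPDATE_COUNTERS).
    j = 0
    for s in range(n_worker):
        while drop[s] > 0:
            take = min(drop[s], cap[j])
            lec[s][j] += take
            cap[j] -= take
            drop[s] -= take
            if cap[j] == 0:
                j += 1

    # 4. Each sender rewrites its own indices: a sample keeps its expert while
    #    its destination still has quota, otherwise it is pointed at expert
    #    `bias` of the next worker with spare quota.
    new_idx = []
    for s in range(n_worker):
        quota = list(lec[s])
        out = list(gate_idx[s])
        for i, e in enumerate(out):
            w = e // n_expert
            if quota[w] > 0:
                quota[w] -= 1
            else:
                out[i] = -1                   # must be re-routed
        k = 0
        for i, e in enumerate(out):
            if e != -1:
                continue
            while quota[k] == 0:
                k += 1
            quota[k] -= 1
            out[i] = k * n_expert + bias
        new_idx.append(out)
    return new_idx, cap


if __name__ == "__main__":
    # Two workers, two experts each, per-worker capacity 4.
    print(swipe_round([[0, 0, 0, 1], [0, 2, 2, 3]], [4, 4], n_expert=2, n_worker=2))
```

Because every worker sees the same reduced drop counts and the same gathered spare capacities, the greedy re-routing loop is deterministic and consistent across ranks, which is why `_swipe_once` can rewrite its gate indices locally without another round of communication.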
cuda/fmoe_cuda.cpp
...
@@ -52,6 +52,9 @@ torch::Tensor _limit_by_capacity(
torch::Tensor _prune_gate_by_capacity(
        torch::Tensor gate_idx, torch::Tensor expert_count,
        long n_expert, long n_worker);

std::vector<torch::Tensor> _swipe_once(
        torch::Tensor gate_idx, torch::Tensor capacity_tensor,
        long n_expert, long n_worker, long bias);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
...
@@ -59,6 +62,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("global_scatter", &_global_scatter, "FastMoE global scatter (CUDA)");
    m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
    m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
    m.def("swipe_once", &_swipe_once, "SWIPE balance strategy(CUDA)");
#endif
    m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
...
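A sketch of how the new `swipe_once` binding is called from Python, pieced together from `fmoe/gates/swipe_gate.py` and `tests/test_swipe.py` below; it assumes a NCCL process group has already been initialised with one GPU per rank.

```python
import torch
import torch.distributed as dist
import fmoe_cuda as fmoe_native
from fmoe.functions import ensure_comm

n_expert = 4
n_worker = dist.get_world_size()

# Per-sample expert indices in [0, n_expert * n_worker) and a scalar capacity.
idx = torch.randint(0, n_expert * n_worker, (16,)).cuda()
capacity = torch.scalar_tensor(32, dtype=torch.long)

ensure_comm(idx, None)            # make sure FastMoE's NCCL communicator exists
new_idx, new_cap = fmoe_native.swipe_once(idx, capacity, n_expert, n_worker, 0)
new_idx = new_idx.to(idx.device)  # the kernel returns the rewritten indices on CPU
```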
cuda/global_exchange.cpp
...
@@ -5,6 +5,33 @@
#ifdef FMOE_USE_NCCL
#include <nccl.h>

void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr) {
    NCCL_SAFE_CALL(ncclGroupStart());
    for (int i = 0; i < world_size; ++i) {
        NCCL_SAFE_CALL(ncclSend(
                local_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
        NCCL_SAFE_CALL(ncclRecv(
                global_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
    smgr->sync(1);
}
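The send/recv loop above is an all-to-all of per-worker count blocks. The snippet below is not the code path FastMoE uses, only a `torch.distributed` analogue (with an invented function name) that makes the communication pattern explicit.

```python
import torch
import torch.distributed as dist

def expert_exchange(local_expert_count, n_expert, world_size):
    """local_expert_count: int64 tensor of length world_size * n_expert,
    laid out as world_size chunks of n_expert counts, one chunk per peer."""
    global_expert_count = torch.empty_like(local_expert_count)
    # Rank r scatters its i-th chunk to rank i and gathers the chunks sent to it.
    # With the NCCL backend both tensors must live on the current GPU.
    dist.all_to_all_single(global_expert_count, local_expert_count)
    return global_expert_count
```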
torch::Tensor _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers) {
...
@@ -31,7 +58,7 @@ torch::Tensor _global_scatter(
    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
    auto smgr = getCudaStreamManager(input_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_cuda_global_scatter", ([&] {
        fmoe_cuda_global_scatter_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
...
@@ -57,7 +84,7 @@ torch::Tensor _global_gather(
    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
    auto smgr = getCudaStreamManager(output_buf.device().index());
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
            "fmoe_cuda_global_gather", ([&] {
        fmoe_cuda_global_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
...
cuda/global_exchange.h
...
@@ -2,30 +2,11 @@
#ifdef FMOE_USE_NCCL
void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr);
[The in-header definition of this function is removed; its body moves to cuda/global_exchange.cpp above.]

template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
...
@@ -50,9 +31,9 @@ void fmoe_cuda_global_scatter_impl(
        int idx = i + j * n_expert;
        if (local_expert_count[idx]) {
            NCCL_SAFE_CALL(ncclSend(
                    local_input_buf + expert_ptr[idx] * in_feat,
                    local_expert_count[idx] * in_feat * sizeof(scalar_t),
                    ncclChar,
                    j,
                    smgr->ncclcomm,
                    smgr->stream(0)));
...
@@ -106,9 +87,9 @@ void fmoe_cuda_global_gather_impl(
        }
        if (local_expert_count[idx]) {
            NCCL_SAFE_CALL(ncclRecv(
                    local_output_buf + expert_ptr[idx] * out_feat,
                    local_expert_count[idx] * out_feat * sizeof(scalar_t),
                    ncclChar,
                    j,
                    smgr->ncclcomm,
                    smgr->stream(0)));
...
fmoe/gates/__init__.py
...
@@ -7,3 +7,5 @@ from .noisy_gate import NoisyGate
from .gshard_gate import GShardGate
from .switch_gate import SwitchGate
from .swipe_gate import SwipeGate
fmoe/gates/base_gate.py
...
@@ -23,3 +23,7 @@ class BaseGate(nn.Module):
        if clear:
            self.loss = None
        return loss

    @property
    def has_loss(self):
        return self.loss is not None
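The new property lets callers skip gates that never set a loss; `SwipeGate`, added below, rebalances assignments without producing a differentiable loss term. A small sketch of the pattern the Megatron hooks in this merge rely on (the helper name is invented):

```python
def collect_gate_losses(model, clear=True):
    """Gather balance losses only from gates that actually recorded one."""
    return [layer.mlp.gate.get_loss(clear=clear)
            for layer in model.language_model.transformer.layers
            if layer.mlp.gate.has_loss]
```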
fmoe/gates/swipe_gate.py
(new file, mode 100644)
r"""
Balanced gate using SWIPE algorithm
"""
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from .naive_gate import NaiveGate
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native


class SwipeGate(NaiveGate):
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__(d_model, num_expert, world_size, top_k)

    def swipe_once(self, idx, capacity, bias):
        with torch.no_grad():
            idx_new, capacity = fmoe_native.swipe_once(idx, capacity,
                    self.num_expert, self.world_size, bias)
            idx_new = idx_new.to(idx.device)
        return idx_new, capacity

    def forward(self, inp):
        score = self.gate(inp)
        orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1)

        if not self.training:
            topk_val = F.softmax(orig_score, dim=-1)
            return orig_idx, topk_val

        capacity = torch.scalar_tensor(inp.shape[0] * self.top_k,
                dtype=torch.long)

        topk_idxs = []
        topk_vals = []
        idx_x = torch.arange(inp.shape[0], device=inp.device)
        for k in range(self.top_k):
            idx, capacity = self.swipe_once(orig_idx[:, k], capacity,
                    k % self.num_expert)
            topk_vals.append(score[idx_x, idx])
            topk_idxs.append(idx)
        topk_idx = torch.stack(topk_idxs).transpose(0, 1)
        topk_val = torch.stack(topk_vals).transpose(0, 1)
        topk_val = F.softmax(topk_val, dim=-1)
        return topk_idx, topk_val
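A minimal usage sketch of the new gate, adapted from `tests/test_swipe.py`; it assumes a NCCL process group with one GPU per rank. In the Megatron integration the same class is selected through `args.balance_strategy == "swipe"`, see `fmoe/megatron/layers.py` below.

```python
import torch
import torch.distributed as dist
from fmoe.functions import ensure_comm
from fmoe.gates import SwipeGate

d_model, n_expert, batch_size = 1024, 4, 16
gate = SwipeGate(d_model, n_expert, dist.get_world_size()).cuda()

x = torch.rand(batch_size, d_model).cuda()
ensure_comm(x, None)             # set up the NCCL communicator for this tensor
topk_idx, topk_val = gate(x)     # rebalanced expert indices and softmax weights
```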
fmoe/megatron/balance.py
...
@@ -51,9 +51,12 @@ def add_balance_log(model, writer, iteration):
    while hasattr(model, 'module'):
        model = model.module
    losses = [l.mlp.gate.get_loss(clear=True)
            for l in model.language_model.transformer.layers
            if l.mlp.gate.has_loss]
    if len(losses) == 0:
        return
    balance_dict_tensor = torch.vstack(losses).detach()
    world_group = get_torch_default_comm()
    world_size = torch.distributed.get_world_size(group=world_group)
    torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
...
[Previously balance_dict_tensor was stacked over every layer unconditionally; layers whose gate recorded no loss are now skipped, and the function returns early when no loss is available.]
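The hunk header shows the signature `add_balance_log(model, writer, iteration)`; a hypothetical call site in a training loop would look like the following (the guard is a guess, not part of this diff).

```python
# Log per-layer balance losses to TensorBoard after a training step.
if writer is not None:
    add_balance_log(model, writer, iteration)
```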
fmoe/megatron/layers.py
...
@@ -95,6 +95,9 @@ class MegatronMLP(FMoETransformerMLP):
        elif args.balance_strategy == "switch":
            from fmoe.gates import SwitchGate
            gate = SwitchGate
        elif args.balance_strategy == "swipe":
            from fmoe.gates import SwipeGate
            gate = SwipeGate
        elif gate is None:
            assert False, "Undefined balance strategy {}" % (args.balance_strategy)
...
fmoe/megatron/patch.py
...
@@ -20,15 +20,19 @@ def patch_forward_step(forward_step_func):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)
        if not is_pipeline_last_stage() or not args.balance_strategy:
            return output

        while hasattr(model, 'module'):
            model = model.module
        loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
                for l in model.language_model.transformer.layers
                if l.mlp.gate.has_loss]
        if len(loss_list) == 0:
            return output

        loss_name = args.balance_strategy + "_loss"
        (loss, state_dict), bal_loss = (
            output,
            torch.cat(loss_list).mean() * args.balance_loss_weight
...
[The previous early return also special-cased args.balance_strategy == 'naive'; that check is dropped now that gates without a recorded loss are filtered out via has_loss.]
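The rest of the hunk is truncated ("..." above), so here is only a generic sketch, with an invented helper name, of how a weighted auxiliary balance loss is typically folded into the objective and exposed for logging.

```python
import torch

def combine_balance_loss(output, loss_list, balance_loss_weight, loss_name):
    """Generic sketch: add the mean gate loss, scaled by its weight, to the task loss."""
    loss, state_dict = output
    bal_loss = torch.cat(loss_list).mean() * balance_loss_weight
    state_dict[loss_name] = bal_loss.detach()   # expose it in the reported metrics
    return loss + bal_loss, state_dict
```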
tests/test_ddp.py
...
@@ -5,12 +5,23 @@ from typing import Dict
import pytest
import torch
import torch.distributed as dist

from test_numerical import test_fmoe as _test_fmoe
from test_numerical import test_fmoe_linear as _test_fmoe_linear
from test_numerical import _test_fmoe_local_ddp


def _ensure_initialized():
    if not dist.is_initialized():
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
        dist.init_process_group(backend="nccl")


def _run_distributed(func, world_size, args: Dict, script=__file__):
    if torch.cuda.device_count() < world_size:
        pytest.skip("No enough GPU")
...
tests/test_gates.py
...
@@ -9,17 +9,7 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.gates import GShardGate, SwitchGate
from test_ddp import _ensure_initialized, _run_distributed
[The local _ensure_initialized helper previously defined here is removed; the test now imports the shared one from tests/test_ddp.py.]


@pytest.mark.parametrize("d_model", [1024])
...
tests/test_swipe.py
(new file, mode 100644)
import pytest

import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.functions import ensure_comm
from fmoe.gates.swipe_gate import SwipeGate

from test_ddp import _ensure_initialized, _run_distributed


@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("top_k", [2, 4])
@pytest.mark.parametrize("world_size", [2, 4, 8])
def test_swipe_gate(world_size, d_model, batch_size, n_expert, top_k):
    if world_size * n_expert < 2:
        pytest.skip("No enough experts")
    _run_distributed('_test_swipe_gate',
            world_size,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'top_k': top_k
            },
            script=__file__)


def _test_swipe_gate(d_model, batch_size, n_expert, top_k):
    _ensure_initialized()
    gate = SwipeGate(d_model, n_expert, dist.get_world_size()).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    ensure_comm(x, None)
    topk_idx, topk_val = gate(x)


@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("world_size", [2, 4, 8])
def test_swipe_once(world_size, batch_size, n_expert):
    if world_size * n_expert < 2:
        pytest.skip("No enough experts")
    _run_distributed('_test_swipe_once',
            world_size,
            {
                'batch_size': batch_size,
                'n_expert': n_expert
            },
            script=__file__)


def _test_swipe_once(batch_size, n_expert):
    _ensure_initialized()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    gate = SwipeGate(4, n_expert, dist.get_world_size()).cuda()
    idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
    capacity = torch.scalar_tensor(batch_size * 2, dtype=torch.long)
    ensure_comm(idx, None)
    new_idx, new_cap = gate.swipe_once(idx, capacity, 0)
    idx = torch.randint(0, n_expert * world_size, (batch_size,)).cuda()
    new_idx, new_cap = gate.swipe_once(idx, new_cap, 0)


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        test_swipe_gate(8, 4, 8, 4, 2)
        # test_swipe_once(8, 800, 4)