Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
05728524
Commit
05728524
authored
Feb 22, 2021
by
Sengxian
Browse files
Fix conflict
parents
f39f411a
5cb4e63c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
33 deletions
+21
-33
tests/benchmark_mlp.py
tests/benchmark_mlp.py
+1
-0
tests/test_numerical.py
tests/test_numerical.py
+20
-33
No files found.
tests/benchmark_mlp.py
View file @
05728524
...
...
@@ -95,6 +95,7 @@ def benchmark_mlp(MOELayer, batch_size, in_feat, hidden_feat, num_expert, top_k)
if
__name__
==
'__main__'
:
os
.
environ
[
'RANK'
]
=
os
.
environ
.
get
(
'OMPI_COMM_WORLD_RANK'
,
'0'
)
os
.
environ
[
'WORLD_SIZE'
]
=
os
.
environ
.
get
(
'OMPI_COMM_WORLD_SIZE'
,
'1'
)
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
os
.
environ
.
get
(
'OMPI_COMM_WORLD_LOCAL_RANK'
,
'0'
)
if
int
(
os
.
environ
[
'WORLD_SIZE'
])
>
1
:
torch
.
distributed
.
init_process_group
(
backend
=
'nccl'
)
rank
=
torch
.
distributed
.
get_rank
()
...
...
tests/test_numerical.py
View file @
05728524
...
...
@@ -32,15 +32,19 @@ def _perform_forward(
moe
.
gate
.
gate
.
bias
.
data
,
group_sender
,
group
=
mp_group
)
gate_idx
,
gate_score
=
moe
.
gate
(
inp
)
inp_repeated
=
inp
.
repeat_interleave
(
repeats
=
top_k
,
dim
=
0
)
moe_out
=
moe
(
inp
).
mean
()
raw_out
=
moe_raw
(
inp_repeated
,
gate_idx
,
gate_score
).
mean
()
inp_raw
=
inp
.
clone
()
inp
.
requires_grad
=
True
moe_out
.
backward
()
raw_out
.
backward
()
inp_raw
.
requires_grad
=
True
gate_idx
,
gate_score
=
moe
.
gate
(
inp_raw
)
inp_repeated
=
inp_raw
.
repeat_interleave
(
repeats
=
top_k
,
dim
=
0
)
moe_out
=
moe
(
inp
)
raw_out
=
moe_raw
(
inp_repeated
,
gate_idx
,
gate_score
)
raw_out
.
mean
().
backward
()
moe_out
.
mean
().
backward
()
return
moe_out
,
raw_out
return
moe_out
,
raw_out
,
inp
.
grad
,
inp_raw
.
grad
def
_assert_numercial
(
names
,
moe_out_list
,
raw_out_list
,
rank
):
...
...
@@ -137,24 +141,12 @@ def test_fmoe_linear(
moe_raw
.
weight_h4toh
.
data
=
torch
.
cat
(
weight_h4toh_array
,
dim
=
0
)
moe_raw
.
bias_h4toh
.
data
=
torch
.
cat
(
bias_h4toh_array
,
dim
=
0
)
moe_out
,
raw_out
=
_perform_forward
(
moe_out
,
raw_out
,
moe_grad_in
,
raw_grad_in
=
_perform_forward
(
moe
,
moe_raw
,
batch_size
,
d_model
,
top_k
,
rank
,
mp_group
)
moe_out_list
=
(
moe_out
,
moe
.
experts
.
htoh4
.
weight
.
grad
,
moe
.
experts
.
h4toh
.
weight
.
grad
,
moe
.
experts
.
htoh4
.
bias
.
grad
,
moe
.
experts
.
h4toh
.
bias
.
grad
,
)
raw_out_list
=
(
raw_out
,
moe_raw
.
weight_htoh4
.
grad
,
moe_raw
.
weight_h4toh
.
grad
,
moe_raw
.
bias_htoh4
.
grad
,
moe_raw
.
bias_h4toh
.
grad
,
)
moe_out_list
=
moe_out
,
moe_grad_in
,
moe
.
experts
.
htoh4
.
weight
.
grad
,
moe
.
experts
.
h4toh
.
weight
.
grad
,
moe
.
experts
.
htoh4
.
bias
.
grad
,
moe
.
experts
.
h4toh
.
bias
.
grad
raw_out_list
=
raw_out
,
raw_grad_in
,
moe_raw
.
weight_htoh4
.
grad
,
moe_raw
.
weight_h4toh
.
grad
,
moe_raw
.
bias_htoh4
.
grad
,
moe_raw
.
bias_h4toh
.
grad
if
world_size
>
1
:
_
,
htoh4_w_grad
,
h4toh_w_grad
,
htoh4_b_grad
,
h4toh_b_grad
=
raw_out_list
...
...
@@ -177,13 +169,8 @@ def test_fmoe_linear(
)
raw_out_list
=
_
,
htoh4_w_grad
,
h4toh_w_grad
,
htoh4_b_grad
,
h4toh_b_grad
names
=
[
"output"
,
"htoh4 weight grad"
,
"h4toh weight grad"
,
"htoh4 bias grad"
,
"h4toh bias grad"
,
]
names
=
[
"output"
,
"input grad"
,
"htoh4 weight grad"
,
"h4toh weight grad"
,
"htoh4 bias grad"
,
"h4toh bias grad"
]
_assert_numercial
(
names
,
moe_out_list
,
raw_out_list
,
rank
)
...
...
@@ -254,7 +241,7 @@ def test_fmoe(
idx
].
data
=
para_tensor_gathered
[
expertID
]
moe_out
,
raw_out
=
_perform_forward
(
moe_out
,
raw_out
,
moe_grad_in
,
raw_grad_in
=
_perform_forward
(
moe
,
moe_raw
,
batch_size
,
d_model
,
top_k
,
rank
,
mp_group
)
...
...
@@ -281,9 +268,9 @@ def test_fmoe(
mp_size
=
mp_group
.
size
()
if
mp_group
else
1
raw_grad
=
raw_grad
[
rank
*
num_expert
:
(
rank
+
1
)
*
num_expert
]
/
mp_size
moe_out_list
=
[
moe_out
,
moe_grad
]
raw_out_list
=
[
raw_out
,
raw_grad
]
names
=
[
"forward"
,
"backward"
]
moe_out_list
=
[
moe_out
,
moe_grad
,
moe_grad_in
]
raw_out_list
=
[
raw_out
,
raw_grad
,
raw_grad_in
]
names
=
[
"forward"
,
"backward"
,
"grad_in"
]
_assert_numercial
(
names
,
moe_out_list
,
raw_out_list
,
rank
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment