OpenDAS / FastMoE

Commit 87dad9d5 (unverified)
Authored Feb 21, 2021 by Rick Ho; committed by GitHub on Feb 21, 2021

    Merge pull request #5 from laekov/bias

    merge Bias

Parents: ed9277f9, 01464726

Showing 5 changed files with 121 additions and 75 deletions (+121, -75)
fmoe/layers.py            +43 -25
fmoe/transformer.py        +5  -6
tests/moe.py               +8  -0
tests/test_dp.py          +18 -15
tests/test_numerical.py   +47 -29
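In summary, this merge gives FMoELinear an optional per-expert bias (enabled by default), turns expert_fn into a method of FMoE instead of a constructor argument, and extends the brute-force reference and tests to check the new bias parameters. Below is a minimal construction sketch of the updated layer; the shapes and values are illustrative only, and the forward pass needs a CUDA build of FastMoE because it calls the custom MOELinear kernel:

    import torch
    from fmoe.layers import FMoELinear

    # After this commit, weight has shape (num_expert, out_feat, in_feat) and,
    # when bias=True, bias has shape (num_expert, out_feat).
    layer = FMoELinear(num_expert=4, in_feat=16, out_feat=32, bias=True, rank=0).cuda()

    # inp concatenates the tokens routed to each expert; fwd_expert_count[i]
    # is the number of rows belonging to expert i (3 + 2 + 4 + 1 = 10 here).
    inp = torch.randn(10, 16, device='cuda')
    fwd_expert_count = torch.tensor([3, 2, 4, 1])
    out = layer(inp, fwd_expert_count)  # expected shape: (10, 32)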
fmoe/layers.py

@@ -19,13 +19,18 @@ class FMoELinear(nn.Module):
     performed in parallel to increase the performance.
     The FMoELinear module provides such function.
     '''
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, rank=0):
+    def __init__(self, num_expert: int, in_feat: int, out_feat: int,
+            bias: bool = True, rank: int = 0):
         super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.rank = rank
         self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
+        else:
+            self.register_parameter('bias', None)
         self.reset_parameters()

     def reset_parameters(self):
@@ -41,17 +46,32 @@ class FMoELinear(nn.Module):
         bound = math.sqrt(3.0) * std
         device = self.weight.device
         dtype = self.weight.dtype
-        weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
-        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+        for i in range(self.num_expert):
+            weight = rng.uniform(-bound, bound,
+                    size=tuple(self.weight[i].size()))
+            self.weight.data[i] = torch.tensor(weight,
+                    dtype=dtype, device=device)
+        if self.bias is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
+            bound = 1 / math.sqrt(fan_in)
+            bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
+            self.bias.data = torch.tensor(bias, dtype=dtype, device=device)

     def forward(self, inp, fwd_expert_count):
         r'''
         Call MOE function
         '''
-        return MOELinear.apply(inp, self.weight, fwd_expert_count)
+        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
+        if self.bias is not None:
+            bias = torch.repeat_interleave(self.bias,
+                    fwd_expert_count.to(self.bias.device), dim=0)
+            x = x + bias
+        return x
+
+    def extra_repr(self) -> str:
+        return 'num_expert={}, in_features={}, \
+out_features={}, bias={}, rank={}'.format(
+                self.num_expert, self.in_feat, self.out_feat,
+                self.bias is not None, self.rank)


 def mark_module_parallel_comm(module, comm):
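For clarity, the forward hunk above is numerically equivalent to applying each expert's own linear layer, with that expert's bias row repeated once per routed token. A plain-PyTorch sketch of this reference computation (no custom kernel; shapes are illustrative):

    import torch

    num_expert, in_feat, out_feat = 4, 16, 32
    weight = torch.randn(num_expert, out_feat, in_feat)
    bias = torch.randn(num_expert, out_feat)
    fwd_expert_count = torch.tensor([3, 2, 4, 1])
    inp = torch.randn(int(fwd_expert_count.sum()), in_feat)

    # Reference: slice the rows belonging to each expert and apply W_i, b_i.
    out_ref = torch.cat([
        torch.nn.functional.linear(chunk, weight[i], bias[i])
        for i, chunk in enumerate(inp.split(fwd_expert_count.tolist()))
    ], dim=0)

    # The same bias expansion the new forward() uses: one bias row per token.
    per_token_bias = torch.repeat_interleave(bias, fwd_expert_count, dim=0)
    assert per_token_bias.shape == (int(fwd_expert_count.sum()), out_feat)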
@@ -92,8 +112,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
 class FMoE(nn.Module):
     r'''
     A general moe implementation that supports an arbitrary module as the
-    expert. Either `expert` or `expert_fn` is required.
+    expert.
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contains
     different experts.
@@ -106,12 +126,9 @@ class FMoE(nn.Module):
     * `gate` is a gate class which can found in `fmoe.gates`.
     * `expert` can be specified as a module class, it is used to generate
     `num_expert` expert modules.
-    * `expert_fn` is specified as a callable object or a function, it will be
-    called during forward, giving the input tensor (contiguous) and the array of
-    the number of input feature to each expert as input.
     '''
     def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
-            top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
+            top_k=2, gate=NaiveGate, expert=None):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
@@ -125,10 +142,12 @@ class FMoE(nn.Module):
             self.mp_rank = mp_group.rank()
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
-        if expert_fn is None:
-            assert expert is not None, 'Either expert or expert_fn should be set'
-            self.experts = [expert(d_model) for _ in range(num_expert)]
-            def expert_fn(inp, fwd_expert_count):
-                outputs = []
-                base_idx = 0
-                for i in range(self.num_expert):
+        if expert is not None:
+            self.experts = [expert(d_model) for _ in range(num_expert)]
+
+    def expert_fn(self, inp, fwd_expert_count):
+        if isinstance(self.experts, nn.Module):
+            return self.experts(inp, fwd_expert_count)
+        outputs = []
+        base_idx = 0
+        for i in range(self.num_expert):
@@ -137,7 +156,6 @@ class FMoE(nn.Module):
-                    outputs.append(self.experts[i](inp_slice))
-                    base_idx += batch_size
-                return torch.cat(outputs, dim=0)
-            self.expert_fn = expert_fn
+            outputs.append(self.experts[i](inp_slice))
+            base_idx += batch_size
+        return torch.cat(outputs, dim=0)

     def mark_parallel_comm(self):
         r'''
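With this change, expert_fn is an ordinary method: if self.experts is a single nn.Module (as FMoETransformerMLP sets it), it is called once with the packed input and the per-expert counts; if it is a list of per-expert modules, the input is sliced and each expert processes its own rows. A sketch of the list path with a made-up expert class (MyExpert and the sizes are purely for illustration):

    import torch
    from torch import nn

    class MyExpert(nn.Module):
        # Hypothetical expert: any module mapping (n, d_model) -> (n, d_model).
        def __init__(self, d_model):
            super().__init__()
            self.fc = nn.Linear(d_model, d_model)

        def forward(self, x):
            return self.fc(x)

    # FMoE(num_expert=2, d_model=8, expert=MyExpert) would build a list like:
    experts = [MyExpert(8) for _ in range(2)]

    # What expert_fn does for the list case: slice, run each expert, concatenate.
    inp = torch.randn(5, 8)
    fwd_expert_count = [3, 2]
    out = torch.cat(
        [e(chunk) for e, chunk in zip(experts, inp.split(fwd_expert_count))],
        dim=0,
    )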
fmoe/transformer.py

@@ -14,8 +14,10 @@ class _Expert(nn.Module):
     '''
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden,
+                bias=True, rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model,
+                bias=True, rank=rank)
         self.activation = activation

     def forward(self, inp, fwd_expert_count):
@@ -47,11 +49,8 @@ class FMoETransformerMLP(FMoE):
             top_k=2,
             pre_lnorm=False
     ):
-        def expert_fn(inp, gate):
-            return self.experts(inp, gate)
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                top_k=top_k, world_size=world_size, mp_group=mp_group,
-                expert_fn=expert_fn)
+                top_k=top_k, world_size=world_size, mp_group=mp_group)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
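As a usage note, the two FMoELinear layers above form the per-expert feed-forward block of FMoETransformerMLP, and both now carry biases, matching a standard Linear -> activation -> Linear MLP. A construction sketch follows; only num_expert, d_model, d_hidden, activation and top_k are visible in this diff, so the concrete values and any other defaults are assumptions, and running the module still requires a CUDA build of FastMoE:

    import torch
    from fmoe.transformer import FMoETransformerMLP

    # Sketch only: the keyword values here are illustrative, not from the commit.
    mlp = FMoETransformerMLP(
        num_expert=4,
        d_model=1024,
        d_hidden=4096,
        activation=torch.nn.GELU(),
        top_k=2,
    ).cuda()
    # The module is then used in place of the dense FFN inside a transformer block.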
tests/moe.py

@@ -20,9 +20,15 @@ class BruteForceMoELinear(nn.Module):
         self.weight_htoh4 = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_hidden, d_model)
         )
+        self.bias_htoh4 = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_hidden)
+        )
         self.weight_h4toh = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_model, d_hidden)
         )
+        self.bias_h4toh = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_model)
+        )
         self.top_k = top_k

     def forward(self, inp, gate_idx, gate_score):
@@ -34,8 +40,10 @@ class BruteForceMoELinear(nn.Module):
             idx = (gate_idx == i)
             x = inp[idx]
             x = x @ self.weight_htoh4[i].t()
+            x = x + self.bias_htoh4[i]
             x = self.activation(x)
             x = x @ self.weight_h4toh[i].t()
+            x = x + self.bias_h4toh[i]
             o[idx] = x
         x = torch.bmm(gate_score, o.view(-1, self.top_k,
             self.d_model)).reshape(-1, self.d_model)
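The two added lines make the brute-force reference apply each expert's bias explicitly; `x @ W[i].t() + b[i]` is the same computation as `torch.nn.functional.linear(x, W[i], b[i])`, which is what the fused FMoELinear path must now reproduce. A quick self-contained check, with shapes chosen only for illustration:

    import torch

    d_model, d_hidden, n = 8, 32, 5
    W = torch.randn(d_hidden, d_model)   # one expert's htoh4 weight
    b = torch.randn(d_hidden)            # one expert's htoh4 bias
    x = torch.randn(n, d_model)

    manual = x @ W.t() + b
    fused = torch.nn.functional.linear(x, W, b)
    assert torch.allclose(manual, fused)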
tests/test_dp.py

@@ -10,6 +10,19 @@ from fmoe.transformer import _Expert
 n_devices = int(os.environ.get("N_GPUS", "2"))


+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
+        super().__init__(
+            num_expert=num_expert,
+            d_model=d_model,
+            gate=NaiveGate,
+            world_size=1,
+            mp_group=None,
+            top_k=top_k,
+        )
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -26,22 +39,12 @@ def test_fmoe_dp(
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=1,
-        mp_group=None,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, top_k, activation).cuda()
     moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices)))

     for i in range(5):
         output = moe_dp(torch.rand(batch_size, d_model).cuda())
+
+
+if __name__ == '__main__':
+    test_fmoe_dp(4, 2, 4, 16, 32)
tests/test_numerical.py

@@ -52,6 +52,20 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
             assert False


+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
+            top_k, activation):
+        super().__init__(
+            num_expert=num_expert,
+            d_model=d_model,
+            gate=NaiveGate,
+            world_size=world_size,
+            mp_group=mp_group,
+            top_k=top_k,
+        )
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -74,20 +88,8 @@ def test_fmoe_linear(
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=world_size,
-        mp_group=mp_group,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group, top_k,
+            activation).cuda()
     moe_raw = BruteForceMoELinear(
         activation=activation,
@@ -99,38 +101,54 @@ def test_fmoe_linear(
     ).cuda()

     if world_size == 1:
-        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
-        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
+        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
+        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
+        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
+        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
     else:
         weight_htoh4_array = [
-            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
+            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
         ]
-        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
+        bias_htoh4_array = [
+            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
+        ]
+        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
+        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
         moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
+        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

         weight_h4toh_array = [
-            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
+            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
         ]
-        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
+        bias_h4toh_array = [
+            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
+        ]
+        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
+        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
         moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
+        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

     moe_out, raw_out = _perform_forward(
         moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
     )

-    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
-    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad
+    moe_out_list = (moe_out, moe.experts.htoh4.weight.grad,
+            moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad,
+            moe.experts.h4toh.bias.grad)
+    raw_out_list = (raw_out, moe_raw.weight_htoh4.grad,
+            moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad,
+            moe_raw.bias_h4toh.grad)

     if world_size > 1:
-        _, htoh4_grad, h4toh_grad = raw_out_list
-        torch.distributed.all_reduce(htoh4_grad)
-        torch.distributed.all_reduce(h4toh_grad)
+        _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
+        torch.distributed.all_reduce(htoh4_w_grad)
+        torch.distributed.all_reduce(h4toh_w_grad)
+        torch.distributed.all_reduce(htoh4_b_grad)
+        torch.distributed.all_reduce(h4toh_b_grad)
         mp_size = mp_group.size() if mp_group else 1
-        htoh4_grad = htoh4_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
-        h4toh_grad = h4toh_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
-        raw_out_list = _, htoh4_grad, h4toh_grad
+        htoh4_w_grad = htoh4_w_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
+        h4toh_w_grad = h4toh_w_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
+        htoh4_b_grad = htoh4_b_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
+        h4toh_b_grad = h4toh_b_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
+        raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

-    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
+    names = ["output", "htoh4 weight grad", "h4toh weight grad",
+            "htoh4 bias grad", "h4toh bias grad"]
     _assert_numercial(names, moe_out_list, raw_out_list, rank)