OpenDAS / FastMoE / Commits / 87dad9d5

Unverified commit 87dad9d5, authored Feb 21, 2021 by Rick Ho, committed by GitHub on Feb 21, 2021

Merge pull request #5 from laekov/bias

merge Bias

Parents: ed9277f9, 01464726

Showing 5 changed files with 121 additions and 75 deletions (+121 -75)
fmoe/layers.py           +43 -25
fmoe/transformer.py       +5  -6
tests/moe.py              +8  -0
tests/test_dp.py         +18 -15
tests/test_numerical.py  +47 -29
fmoe/layers.py

@@ -19,13 +19,18 @@ class FMoELinear(nn.Module):
     performed in parallel to increase the performance.
     The FMoELinear module provides such function.
     '''
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, rank=0):
+    def __init__(self, num_expert: int, in_feat: int, out_feat: int,
+            bias: bool = True, rank: int = 0):
         super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.rank = rank
         self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
+        else:
+            self.register_parameter('bias', None)
         self.reset_parameters()

     def reset_parameters(self):
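Note: with `bias=True` (the new default), FMoELinear keeps one bias vector per expert alongside the existing per-expert weight matrices. A minimal shape sketch of the resulting parameters (illustrative only; it assumes FMoELinear can be constructed standalone, which this diff does not demonstrate):

    from fmoe.layers import FMoELinear

    fc = FMoELinear(num_expert=4, in_feat=16, out_feat=32, bias=True, rank=0)
    print(tuple(fc.weight.shape))  # (4, 32, 16): one 32x16 weight matrix per expert
    print(tuple(fc.bias.shape))    # (4, 32): one bias vector per expert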
@@ -41,17 +46,32 @@ class FMoELinear(nn.Module):
         bound = math.sqrt(3.0) * std
         device = self.weight.device
         dtype = self.weight.dtype
-        weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
-        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+        for i in range(self.num_expert):
+            weight = rng.uniform(-bound, bound,
+                    size=tuple(self.weight[i].size()))
+            self.weight.data[i] = torch.tensor(weight,
+                    dtype=dtype, device=device)
+        if self.bias is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
+            bound = 1 / math.sqrt(fan_in)
+            bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
+            self.bias.data = torch.tensor(bias, dtype=dtype, device=device)

     def forward(self, inp, fwd_expert_count):
         r'''
         Call MOE function
         '''
-        return MOELinear.apply(inp, self.weight, fwd_expert_count)
+        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
+        if self.bias is not None:
+            bias = torch.repeat_interleave(self.bias,
+                    fwd_expert_count.to(self.bias.device), dim=0)
+            x = x + bias
+        return x
+
+    def extra_repr(self) -> str:
+        return 'num_expert={}, in_features={}, \
+out_features={}, bias={}, rank={}'.format(
+                self.num_expert, self.in_feat, self.out_feat,
+                self.bias is not None, self.rank)


 def mark_module_parallel_comm(module, comm):
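A note on the bias broadcast added to forward: the rows MOELinear.apply operates on are grouped by expert (the same layout expert_fn relies on when it slices the input sequentially), and fwd_expert_count holds the group sizes, so torch.repeat_interleave expands the (num_expert, out_feat) bias into one row per token. A toy illustration of that semantics (not part of the diff):

    import torch

    bias = torch.tensor([[0., 0.], [1., 1.], [2., 2.]])  # 3 experts, out_feat=2
    fwd_expert_count = torch.tensor([2, 0, 1])           # rows routed to each expert
    print(torch.repeat_interleave(bias, fwd_expert_count, dim=0))
    # tensor([[0., 0.],
    #         [0., 0.],
    #         [2., 2.]])  -- bias of expert 0 twice, bias of expert 2 once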
@@ -92,8 +112,8 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):

 class FMoE(nn.Module):
     r'''
     A general moe implementation that supports an arbitrary module as the
-    expert. Either `expert` or `expert_fn` is required.
+    expert.
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contains
     different experts.
@@ -106,12 +126,9 @@ class FMoE(nn.Module):
     * `gate` is a gate class which can found in `fmoe.gates`.
     * `expert` can be specified as a module class, it is used to generate
     `num_expert` expert modules.
-    * `expert_fn` is specified as a callable object or a function, it will be
-    called during forward, giving the input tensor (contiguous) and the array of
-    the number of input feature to each expert as input.
     '''
-    def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
-            top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
+    def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
+            top_k=2, gate=NaiveGate, expert=None):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
@@ -125,19 +142,20 @@ class FMoE(nn.Module):
             self.mp_rank = mp_group.rank()
         self.top_k = top_k
         self.gate = gate(d_model, num_expert, world_size, top_k)
-        if expert_fn is None:
-            assert expert is not None, 'Either expert or expert_fn should be set'
-            self.experts = [expert(d_model) for _ in range(num_expert)]
-            def expert_fn(inp, fwd_expert_count):
-                outputs = []
-                base_idx = 0
-                for i in range(self.num_expert):
-                    batch_size = fwd_expert_count[i].item()
-                    inp_slice = inp[base_idx:base_idx + batch_size]
-                    outputs.append(self.experts[i](inp_slice))
-                    base_idx += batch_size
-                return torch.cat(outputs, dim=0)
-        self.expert_fn = expert_fn
+        if expert is not None:
+            self.experts = [expert(d_model) for _ in range(num_expert)]
+
+    def expert_fn(self, inp, fwd_expert_count):
+        if isinstance(self.experts, nn.Module):
+            return self.experts(inp, fwd_expert_count)
+        outputs = []
+        base_idx = 0
+        for i in range(self.num_expert):
+            batch_size = fwd_expert_count[i].item()
+            inp_slice = inp[base_idx:base_idx + batch_size]
+            outputs.append(self.experts[i](inp_slice))
+            base_idx += batch_size
+        return torch.cat(outputs, dim=0)

     def mark_parallel_comm(self):
         r'''
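The practical effect of this hunk is that `expert_fn` is no longer injected through the constructor; it is now a method that either delegates to a single `self.experts` module (one that accepts `(inp, fwd_expert_count)`) or loops over a list of per-expert modules built from the `expert` class. A minimal sketch of the list-of-experts pattern this enables; `SimpleExpert` is a hypothetical module, not part of FastMoE:

    import torch.nn as nn
    from fmoe.layers import FMoE

    class SimpleExpert(nn.Module):
        # One expert; FMoE's default expert_fn feeds it only the rows routed to it.
        def __init__(self, d_model):
            super().__init__()
            self.fc = nn.Linear(d_model, d_model)

        def forward(self, x):
            return self.fc(x).relu()

    # FMoE builds `num_expert` independent SimpleExpert instances;
    # expert_fn then slices the expert-grouped input and calls each one in turn.
    moe = FMoE(num_expert=4, d_model=16, expert=SimpleExpert)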
fmoe/transformer.py

@@ -14,8 +14,10 @@ class _Expert(nn.Module):
     '''
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden,
+                bias=True, rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model,
+                bias=True, rank=rank)
         self.activation = activation

     def forward(self, inp, fwd_expert_count):
@@ -47,11 +49,8 @@ class FMoETransformerMLP(FMoE):
             top_k=2,
             pre_lnorm=False
     ):
-        def expert_fn(inp, gate):
-            return self.experts(inp, gate)
-
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
-                top_k=top_k, world_size=world_size, mp_group=mp_group,
-                expert_fn=expert_fn)
+                top_k=top_k, world_size=world_size, mp_group=mp_group)
         self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                 rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
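With `bias=True`, both FMoELinear layers inside `_Expert` now register bias parameters, so an FMoETransformerMLP carries `experts.htoh4.bias` of shape (num_expert, d_hidden) and `experts.h4toh.bias` of shape (num_expert, d_model) in addition to the weights; checkpoints saved before this change will not contain those keys. A hedged construction sketch; it assumes `num_expert`, `d_model`, `d_hidden` and `activation` are constructor arguments of FMoETransformerMLP, which the surrounding code suggests but this hunk does not show in full:

    import torch
    from fmoe.transformer import FMoETransformerMLP

    mlp = FMoETransformerMLP(num_expert=4, d_model=16, d_hidden=32,
                             activation=torch.nn.functional.gelu)
    for name, param in mlp.experts.named_parameters():
        print(name, tuple(param.shape))
    # expected: htoh4.weight (4, 32, 16), htoh4.bias (4, 32),
    #           h4toh.weight (4, 16, 32), h4toh.bias (4, 16)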
tests/moe.py

@@ -20,9 +20,15 @@ class BruteForceMoELinear(nn.Module):
         self.weight_htoh4 = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_hidden, d_model)
         )
+        self.bias_htoh4 = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_hidden)
+        )
         self.weight_h4toh = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_model, d_hidden)
         )
+        self.bias_h4toh = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_model)
+        )
         self.top_k = top_k

     def forward(self, inp, gate_idx, gate_score):
@@ -34,8 +40,10 @@ class BruteForceMoELinear(nn.Module):
             idx = (gate_idx == i)
             x = inp[idx]
             x = x @ self.weight_htoh4[i].t()
+            x = x + self.bias_htoh4[i]
             x = self.activation(x)
             x = x @ self.weight_h4toh[i].t()
+            x = x + self.bias_h4toh[i]
             o[idx] = x
         x = torch.bmm(gate_score, o.view(-1, self.top_k,
             self.d_model)).reshape(-1, self.d_model)
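BruteForceMoELinear is the plain-PyTorch reference that the numerical tests compare the fused kernels against; per expert it now computes activation(x @ W1^T + b1) @ W2^T + b2, i.e. an ordinary two-layer MLP including the newly added bias terms. A toy restatement of that per-expert math (relu stands in for the test's activation; illustrative only):

    import torch
    import torch.nn.functional as F

    rows, d_model, d_hidden = 5, 8, 16
    x = torch.randn(rows, d_model)
    w1, b1 = torch.randn(d_hidden, d_model), torch.randn(d_hidden)
    w2, b2 = torch.randn(d_model, d_hidden), torch.randn(d_model)

    ref = torch.relu(x @ w1.t() + b1) @ w2.t() + b2          # BruteForceMoELinear-style
    lin = F.linear(torch.relu(F.linear(x, w1, b1)), w2, b2)  # same computation via F.linear
    assert torch.allclose(ref, lin, atol=1e-6)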
tests/test_dp.py

@@ -10,6 +10,19 @@ from fmoe.transformer import _Expert
 n_devices = int(os.environ.get("N_GPUS", "2"))

+
+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
+        super().__init__(num_expert=num_expert, d_model=d_model, gate=NaiveGate,
+                world_size=1, mp_group=None, top_k=top_k)
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -26,22 +39,12 @@ def test_fmoe_dp(
     torch.manual_seed(42)
     torch.cuda.manual_seed(42)
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=1,
-        mp_group=None,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, top_k, activation).cuda()
     moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices)))
     for i in range(5):
         output = moe_dp(torch.rand(batch_size, d_model).cuda())
+
+
+if __name__ == '__main__':
+    test_fmoe_dp(4, 2, 4, 16, 32)
tests/test_numerical.py

@@ -52,6 +52,20 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
             assert False

+
+class MyMoE(FMoE):
+    def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
+            top_k, activation):
+        super().__init__(num_expert=num_expert, d_model=d_model, gate=NaiveGate,
+                world_size=world_size, mp_group=mp_group, top_k=top_k)
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+
+
 @pytest.mark.parametrize("num_expert", [4, 8])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("batch_size", [4])
@@ -74,20 +88,8 @@ def test_fmoe_linear(
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
-    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()
-
-    def expert_fn(inp, gate):
-        return experts(inp, gate)
-
-    moe = FMoE(
-        num_expert=num_expert,
-        d_model=d_model,
-        gate=NaiveGate,
-        world_size=world_size,
-        mp_group=mp_group,
-        expert_fn=expert_fn,
-        top_k=top_k,
-    ).cuda()
+    moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group,
+            top_k, activation).cuda()
     moe_raw = BruteForceMoELinear(
         activation=activation,
@@ -99,38 +101,54 @@ def test_fmoe_linear(
     ).cuda()

     if world_size == 1:
-        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
-        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
+        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
+        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
+        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
+        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
     else:
-        weight_htoh4_array = [
-            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
-        ]
-        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
+        weight_htoh4_array = [
+            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
+        ]
+        bias_htoh4_array = [
+            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
+        ]
+        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
+        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
         moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
+        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)
-        weight_h4toh_array = [
-            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
-        ]
-        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
+        weight_h4toh_array = [
+            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
+        ]
+        bias_h4toh_array = [
+            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
+        ]
+        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
+        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
         moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
+        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

     moe_out, raw_out = _perform_forward(
         moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
     )
-    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
-    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad
+    moe_out_list = moe_out, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
+    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad

     if world_size > 1:
-        _, htoh4_grad, h4toh_grad = raw_out_list
-        torch.distributed.all_reduce(htoh4_grad)
-        torch.distributed.all_reduce(h4toh_grad)
+        _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
+        torch.distributed.all_reduce(htoh4_w_grad)
+        torch.distributed.all_reduce(h4toh_w_grad)
+        torch.distributed.all_reduce(htoh4_b_grad)
+        torch.distributed.all_reduce(h4toh_b_grad)
         mp_size = mp_group.size() if mp_group else 1
-        htoh4_grad = htoh4_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
-        h4toh_grad = h4toh_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
-        raw_out_list = _, htoh4_grad, h4toh_grad
+        htoh4_w_grad = htoh4_w_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
+        h4toh_w_grad = h4toh_w_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
+        htoh4_b_grad = htoh4_b_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
+        h4toh_b_grad = h4toh_b_grad[rank * num_expert:(rank + 1) * num_expert] / mp_size
+        raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad
-    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
+    names = ["output", "htoh4 weight grad", "h4toh weight grad",
+            "htoh4 bias grad", "h4toh bias grad"]
     _assert_numercial(names, moe_out_list, raw_out_list, rank)