OpenDAS / FastMoE · Commits

"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a075d62971aa140ec72626f25033cf905715530a"
Commit 969ef607, authored Jan 13, 2021 by Jiezhong Qiu
merge topk and only forward moe once
Parent: b9084e90

Showing 1 changed file with 13 additions and 3 deletions:
pytorch/mem_transformer.py (+13, -3)

pytorch/mem_transformer.py @ 969ef607

...
@@ -70,14 +70,23 @@ class CustomizedMoEPositionwiseFF(nn.Module):
         gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False) # [.. x top_k]
         gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
+        # gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1) # (BxL) x 1 x top_k
+        # gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)
         gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1) # (BxL) x 1 x top_k
-        gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)
-        core_out = []
-        inp = inp.view(-1, self.d_model)
+        gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
+        # core_out = []
+        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k, dim=0) # (BxLxtop_k) x d_model
         inp = F.pad(inp, pad=(0, 1), mode='constant', value=1.0)
+        x = self.moe1(inp, gate_top_k_idx)
+        x = self.dropout(F.relu(x))
+        x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
+        x = self.moe2(x, gate_top_k_idx)
+        x = self.dropout(x) # (BxLxtop_k) x d_model
+        core_out = x.view(-1, self.top_k, self.d_model) # (BxL) x top_k x d_model
+        """
         for i in range(self.top_k):
             gate_idx = gate_top_k_idx[:, i].contiguous()
             x = self.moe1(inp, gate_idx)
...
@@ -88,6 +97,7 @@ class CustomizedMoEPositionwiseFF(nn.Module):
             core_out.append(x.unsqueeze(1)) # (BxL) x 1 x d_model
         core_out = torch.cat(core_out, dim=1) # (BxL) x top_k x d_model
+        """
         core_out = torch.bmm(gate_score, core_out) # (BxL) x 1 x d_model
         core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
...
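
In short, the refactor this commit makes: instead of looping over the top_k selected experts and forwarding the MoE feed-forward k times, the new code repeats every token top_k times with repeat_interleave, flattens the (BxL) x top_k expert-index matrix, forwards moe1/moe2 once, and recombines the k expert outputs with the softmaxed gate scores via torch.bmm. Below is a minimal self-contained sketch of that equivalence; naive_expert_ff and the toy shapes are hypothetical stand-ins for illustration, not FastMoE's actual moe1/moe2 API.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
BxL, d_model, n_expert, top_k = 6, 4, 3, 2
expert_w = torch.randn(n_expert, d_model, d_model)  # one weight matrix per expert

def naive_expert_ff(x, expert_idx):
    # Route row i of x through the expert chosen by expert_idx[i].
    return torch.einsum('nd,ndo->no', x, expert_w[expert_idx])

inp = torch.randn(BxL, d_model)
gate = torch.randn(BxL, n_expert)
gate_top_k_val, gate_top_k_idx = torch.topk(
    gate, k=top_k, dim=-1, largest=True, sorted=False)       # (BxL) x top_k
gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)  # (BxL) x 1 x top_k

# Old path: one expert forward per top-k choice, concatenated afterwards.
outs = [naive_expert_ff(inp, gate_top_k_idx[:, i]).unsqueeze(1)
        for i in range(top_k)]
old = torch.bmm(gate_score, torch.cat(outs, dim=1))          # (BxL) x 1 x d_model

# New path: repeat each token top_k times and forward the MoE only once.
inp_rep = inp.repeat_interleave(repeats=top_k, dim=0)        # (BxLxtop_k) x d_model
x = naive_expert_ff(inp_rep, gate_top_k_idx.view(-1))        # single forward pass
new = torch.bmm(gate_score, x.view(-1, top_k, d_model))      # (BxL) x 1 x d_model

print(torch.allclose(old, new, atol=1e-6))  # True: both paths agree

The two paths compute the same gate-weighted mixture; merging them replaces k separate dispatch/forward rounds with a single, k-times-larger one.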
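
A side note on the F.pad(..., pad=(0, 1), mode='constant', value=1.0) calls that this diff keeps: appending a constant 1.0 to every token vector appears to let each expert's weight matrix absorb its bias as an extra input row (the homogeneous-coordinate trick), so an expert can be a single bias-free matmul. A minimal sketch of that identity, with w and b as illustrative placeholders rather than names from the repo:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(5, 4)                                  # n x d_in
w, b = torch.randn(4, 3), torch.randn(3)               # weight and bias

# Append 1.0 to each input row and the bias as an extra weight row,
# so x @ w + b becomes a single matmul with no separate bias term.
x1 = F.pad(x, pad=(0, 1), mode='constant', value=1.0)  # n x (d_in + 1)
w1 = torch.cat([w, b.unsqueeze(0)], dim=0)             # (d_in + 1) x 3
print(torch.allclose(x @ w + b, x1 @ w1, atol=1e-6))   # True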