Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
a8ecd3d7
"src/vscode:/vscode.git/clone" did not exist on "80c00e5451e0ced32043fbb0ed06eb6f3c427f82"
Commit
a8ecd3d7
authored
Feb 02, 2021
by
Rick Ho
Browse files
remove debug output and todo for replicated mp input
parent
01ae2d72
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1 addition
and
5 deletions
+1
-5
fmoe/layers.py
fmoe/layers.py
+0
-5
tests/test.sh
tests/test.sh
+1
-0
No files found.
fmoe/layers.py
View file @
a8ecd3d7
...
...
@@ -24,7 +24,6 @@ class FMoELinear(nn.Module):
class
FMoENaiveGate
(
nn
.
Module
):
def
__init__
(
self
,
d_model
,
num_expert
,
world_size
,
top_k
=
2
):
super
(
FMoENaiveGate
,
self
).
__init__
()
# print(f"gate: {num_expert * world_size}")
self
.
gate
=
nn
.
Linear
(
d_model
,
num_expert
*
world_size
)
self
.
top_k
=
top_k
...
...
@@ -92,7 +91,6 @@ class FMoETransformerMLP(nn.Module):
self
.
htoh4
=
FMoELinear
(
num_expert
,
d_model
,
d_hidden
)
self
.
h4toh
=
FMoELinear
(
num_expert
,
d_hidden
,
d_model
)
# print(f"FMoETransformerMLP world_size: {world_size} num_expert: {num_expert}")
self
.
gate
=
FMoENaiveGate
(
d_model
,
num_expert
,
world_size
,
top_k
)
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
...
...
@@ -107,8 +105,6 @@ class FMoETransformerMLP(nn.Module):
batch_start
=
local_batch_size
*
self
.
model_parallel_rank
batch_end
=
min
(
batch_start
+
local_batch_size
,
B
)
inp
=
inp
[:,
batch_start
:
batch_end
,
:].
contiguous
()
# print(inp.shape)
# print(f"mp_rank: {self.model_parallel_rank}, [{batch_start}, {batch_end})")
residual
=
inp
if
self
.
pre_lnorm
:
...
...
@@ -116,7 +112,6 @@ class FMoETransformerMLP(nn.Module):
gate_top_k_idx
,
gate_score
=
self
.
gate
(
inp
)
# TODO: merge replication into local_scatter
inp
=
inp
.
view
(
-
1
,
self
.
d_model
).
repeat_interleave
(
repeats
=
self
.
top_k
,
dim
=
0
)
# (BxLxtop_k) x d_model
...
...
tests/test.sh
View file @
a8ecd3d7
#!/bin/bash
if
[
!
-z
$OMPI_COMM_WORLD_LOCAL_RANK
]
then
export
CUDA_VISIBLE_DEVICES
=
$OMPI_COMM_WORLD_LOCAL_RANK
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment