OpenDAS / fairscale · Commit 002aae63 (Unverified)

[fix] nn.moe: softmax should be done in FP32 (#668)

Co-authored-by: @myleott

Authored by msbaines on May 08, 2021; committed via GitHub on May 08, 2021.
Parent: 29d81c43
Changes: 2 changed files, with 5 additions and 4 deletions.

  fairscale/nn/moe/top2gate.py   +3 -2
  stubs/torch/nn/functional.pyi  +2 -2
fairscale/nn/moe/top2gate.py

@@ -36,7 +36,8 @@ def one_hot(tensor: torch.Tensor, num_classes: int) -> Tensor:
 def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
-    gates = F.softmax(logits, dim=1)
+    # NOTE(msb) softmax requires FP32: https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/
+    gates = F.softmax(logits, dim=1, dtype=torch.float)
     # gates has shape of SE
     num_tokens = gates.shape[0]

@@ -95,7 +96,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()
-    return l_aux, combine_weights, dispatch_mask
+    return l_aux.to(logits.dtype), combine_weights.to(logits.dtype), dispatch_mask


 class Top2Gate(torch.nn.Module):
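For readers skimming the diff: under mixed precision training, the gating logits arrive in FP16, and computing softmax directly in FP16 can lose precision in the exponentiation and normalization steps (hence the NVIDIA mixed-precision note in the new comment). Passing dtype=torch.float makes F.softmax upcast before computing, and the new return statement casts l_aux and combine_weights back to the caller's dtype. A minimal sketch of this behavior (not part of the commit; the tensor shape is illustrative):

# Minimal sketch (not part of this commit). dtype=torch.float tells
# F.softmax to upcast its input before the exp/normalize steps, so the
# reduction runs in FP32 even when the logits are FP16.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 8, dtype=torch.float16)  # e.g. tokens x experts

gates = F.softmax(logits, dim=1, dtype=torch.float)
assert gates.dtype == torch.float32  # computed and returned in FP32

# Mirrors the new return statement: cast back to the caller's dtype so
# downstream mixed-precision code keeps receiving FP16 tensors.
gates_out = gates.to(logits.dtype)
assert gates_out.dtype == torch.float16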
stubs/torch/nn/functional.pyi

 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from .. import Tensor, _size
+from .. import Tensor, _size, _dtype
 from typing import Any, Optional, Tuple, Dict, List, Callable, Union
 from .common_types import _ratio_any_t
...
@@ -154,7 +154,7 @@ def softsign(input: Any): ...
 def softmin(input: Tensor, dim: Optional[int] = ..., _stacklevel: int = ..., dtype: Optional[int] = ...) -> Tensor: ...
-def softmax(input: Tensor, dim: Optional[int] = ..., _stacklevel: int = ..., dtype: Optional[int] = ...) -> Tensor: ...
+def softmax(input: Tensor, dim: Optional[int] = ..., _stacklevel: int = ..., dtype: Optional[_dtype] = ...) -> Tensor: ...
 def gumbel_softmax(logits: Tensor, tau: float = ..., hard: bool = ..., eps: float = ..., dim: int = ...) -> Tensor: ...
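The stub change keeps type checking in step with the code change: torch.float is a torch.dtype, not an int, so under the old annotation dtype: Optional[int], a checker like mypy would flag the new F.softmax(..., dtype=torch.float) call. Importing _dtype and annotating the parameter as Optional[_dtype] matches the actual runtime signature. A small sketch of the call the updated stub now accepts (illustrative, not from the commit):

# Illustrative only: the call shape the updated stub accepts. Under the
# old stub (dtype: Optional[int]) mypy would reject dtype=torch.float,
# since torch.float is a torch.dtype, not an int.
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
y = F.softmax(x, dim=1, dtype=torch.float)
assert y.dtype == torch.float32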