Unverified Commit 002aae63 authored by msbaines, committed by GitHub

[fix] nn.moe: softmax should be done in FP32 (#668)

Co-authored-by: @myleott
parent 29d81c43
@@ -36,7 +36,8 @@ def one_hot(tensor: torch.Tensor, num_classes: int) -> Tensor:
 def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
-    gates = F.softmax(logits, dim=1)
+    # NOTE(msb) softmax requires FP32: https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/
+    gates = F.softmax(logits, dim=1, dtype=torch.float)
     # gates has shape of SE
     num_tokens = gates.shape[0]
@@ -95,7 +96,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()
-    return l_aux, combine_weights, dispatch_mask
+    return l_aux.to(logits.dtype), combine_weights.to(logits.dtype), dispatch_mask

 class Top2Gate(torch.nn.Module):
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from .. import Tensor, _size
+from .. import Tensor, _size, _dtype
 from typing import Any, Optional, Tuple, Dict, List, Callable, Union
 from .common_types import _ratio_any_t
@@ -154,7 +154,7 @@ def softsign(input: Any): ...
 def softmin(input: Tensor, dim: Optional[int] = ..., _stacklevel: int = ..., dtype: Optional[int] = ...) -> Tensor: ...
-def softmax(input: Tensor, dim: Optional[int] = ..., _stacklevel: int = ..., dtype: Optional[int] = ...) -> Tensor: ...
+def softmax(input: Tensor, dim: Optional[int] = ..., _stacklevel: int = ..., dtype: Optional[_dtype] = ...) -> Tensor: ...
 def gumbel_softmax(logits: Tensor, tau: float = ..., hard: bool = ..., eps: float = ..., dim: int = ...) -> Tensor: ...
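For context, a minimal standalone sketch (not part of the commit; fp32_softmax_demo and the tensor shapes are hypothetical) of the pattern the fix applies: with FP16 gating logits, the softmax is computed in FP32 via the dtype argument of F.softmax, and the result is cast back to the logits' dtype so downstream code keeps operating in the original precision.

import torch
import torch.nn.functional as F

def fp32_softmax_demo(logits: torch.Tensor) -> torch.Tensor:
    # Compute softmax in FP32 even when logits are FP16; the exp/sum
    # reduction can overflow or underflow in half precision.
    gates = F.softmax(logits, dim=1, dtype=torch.float)
    # ... the top-2 gating math would run on the FP32 gates here ...
    # Cast back so callers keep receiving tensors in the input dtype.
    return gates.to(logits.dtype)

logits = torch.randn(8, 4).half()  # hypothetical: 8 tokens, 4 experts, FP16
out = fp32_softmax_demo(logits)
assert out.dtype == torch.float16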