Commit 9c5e287d authored by Jared Casper

Merge branch 'nextlm-fixes' into 'main'

Some quick fixes.

See merge request ADLR/megatron-lm!563
parents cdd2afdf 45084d46
@@ -13,3 +13,6 @@ class AttnType(enum.Enum):
 class AttnMaskType(enum.Enum):
     padding = 1
     causal = 2
+
+# For backward compatibility with old model checkpoints
+from megatron.core.enums import ModelType
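Side note on the re-export above: torch.save pickles an enum member by the module path of its class, so a checkpoint written while ModelType was still defined in this legacy module only loads if that old path keeps resolving. The alias import restores the old path after the class moved to megatron.core.enums. Below is a minimal, self-contained sketch of that mechanism, not the Megatron code; the module names old_enums and new_enums are invented stand-ins.

```python
import enum
import pickle
import sys
import types


def make_module(name):
    # Register an empty module so pickle can import it by name.
    mod = types.ModuleType(name)
    sys.modules[name] = mod
    return mod


# 1) Before the refactor: ModelType lives in the "old" module.
old = make_module("old_enums")


class ModelType(enum.Enum):
    encoder_or_decoder = 1


ModelType.__module__ = "old_enums"
old.ModelType = ModelType
# The pickled bytes reference "old_enums.ModelType".
checkpoint = pickle.dumps(ModelType.encoder_or_decoder)

# 2) After the refactor the class lives in the "new" module ...
new = make_module("new_enums")
new.ModelType = ModelType
del old.ModelType  # old path broken: loading the checkpoint would now fail

# 3) ... so the old module keeps a one-line re-export, analogous to the diff's
#    "from megatron.core.enums import ModelType", and old checkpoints still load.
old.ModelType = new.ModelType
print(pickle.loads(checkpoint))  # ModelType.encoder_or_decoder
```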
@@ -5,6 +5,7 @@ import math
 from contextlib import nullcontext
 import torch
 import torch.nn.functional as F
+from typing import Optional
 from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
@@ -673,7 +674,7 @@ class ParallelAttention(MegatronModule):
 def bias_dropout_add(x, bias, residual, prob, training):
-    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
+    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
     if bias is not None:
         x = x + bias
     out = torch.nn.functional.dropout(x, p=prob, training=training)
@@ -689,7 +690,7 @@ def get_bias_dropout_add(training):
 @torch.jit.script
 def bias_dropout_add_fused_train(x: torch.Tensor,
-                                 bias: torch.Tensor,
+                                 bias: Optional[torch.Tensor],
                                  residual: torch.Tensor,
                                  prob: float) -> torch.Tensor:
     return bias_dropout_add(x, bias, residual, prob, True)
@@ -697,7 +698,7 @@ def bias_dropout_add_fused_train(x: torch.Tensor,
 @torch.jit.script
 def bias_dropout_add_fused_inference(x: torch.Tensor,
-                                     bias: torch.Tensor,
+                                     bias: Optional[torch.Tensor],
                                      residual: torch.Tensor,
                                      prob: float) -> torch.Tensor:
     return bias_dropout_add(x, bias, residual, prob, False)
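For context on why the annotation change is needed: under @torch.jit.script the TorchScript compiler enforces the type annotations, so a parameter typed torch.Tensor cannot be passed None; it has to be declared Optional[torch.Tensor] and narrowed with an explicit "is not None" check before use. The following is a minimal standalone sketch of that pattern, not the Megatron code (fused_train is a made-up name), showing the scripted entry point accepting either a bias tensor or None.

```python
from typing import Optional

import torch


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
    # TorchScript narrows Optional[Tensor] to Tensor inside this branch.
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    return residual + out


@torch.jit.script
def fused_train(x: torch.Tensor,
                bias: Optional[torch.Tensor],
                residual: torch.Tensor,
                prob: float) -> torch.Tensor:
    # Without Optional[...] on bias, scripting would reject a None argument.
    return bias_dropout_add(x, bias, residual, prob, True)


x = torch.randn(4, 8)
residual = torch.randn(4, 8)
print(fused_train(x, None, residual, 0.1).shape)             # no bias
print(fused_train(x, torch.randn(8), residual, 0.1).shape)   # with bias
```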