Commit 9c5e287d authored by Jared Casper

Merge branch 'nextlm-fixes' into 'main'

Some quick fixes.

See merge request ADLR/megatron-lm!563
parents cdd2afdf 45084d46
@@ -13,3 +13,6 @@ class AttnType(enum.Enum):
 class AttnMaskType(enum.Enum):
     padding = 1
     causal = 2
+
+# For backward compatibility with old model checkpoints
+from megatron.core.enums import ModelType
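Aside (not part of the diff): the added line is a re-export, so code written against the old enums module keeps resolving ModelType even though the enum now lives in megatron.core. A minimal sketch of what that buys, assuming a Megatron-LM checkout on PYTHONPATH; the file being patched is not named on this page, so the legacy module path is only referenced in comments:

# Minimal sketch (assumes a Megatron-LM checkout on PYTHONPATH).
# Canonical location of the enum after the core refactor:
from megatron.core.enums import ModelType

# Because the legacy enums module patched above now does
#     from megatron.core.enums import ModelType
# old code that imports ModelType from that module gets this very same class
# object, so existing scripts and checkpoint-loading code keep working.
print(ModelType.encoder_or_decoder)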
@@ -5,6 +5,7 @@ import math
 from contextlib import nullcontext
 import torch
 import torch.nn.functional as F
+from typing import Optional
 from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
@@ -673,7 +674,7 @@ class ParallelAttention(MegatronModule):
 def bias_dropout_add(x, bias, residual, prob, training):
-    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
+    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
     if bias is not None:
         x = x + bias
     out = torch.nn.functional.dropout(x, p=prob, training=training)
@@ -689,7 +690,7 @@ def get_bias_dropout_add(training):
 @torch.jit.script
 def bias_dropout_add_fused_train(x: torch.Tensor,
-                                 bias: torch.Tensor,
+                                 bias: Optional[torch.Tensor],
                                  residual: torch.Tensor,
                                  prob: float) -> torch.Tensor:
     return bias_dropout_add(x, bias, residual, prob, True)
@@ -697,7 +698,7 @@ def bias_dropout_add_fused_train(x: torch.Tensor,
 @torch.jit.script
 def bias_dropout_add_fused_inference(x: torch.Tensor,
-                                     bias: torch.Tensor,
+                                     bias: Optional[torch.Tensor],
                                      residual: torch.Tensor,
                                      prob: float) -> torch.Tensor:
     return bias_dropout_add(x, bias, residual, prob, False)
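The Optional annotations matter because bias_dropout_add_fused_train and bias_dropout_add_fused_inference are compiled with torch.jit.script: TorchScript enforces the annotated types at call time, so a plain torch.Tensor annotation rejects None (for example when the preceding linear layer is built without a bias). Below is a minimal standalone sketch of the pattern, not taken from the MR; the tensor shapes, dropout probability, and hard-coded training=True are illustrative only:

import torch
from typing import Optional

@torch.jit.script
def bias_dropout_add(x: torch.Tensor,
                     bias: Optional[torch.Tensor],
                     residual: torch.Tensor,
                     prob: float) -> torch.Tensor:
    # TorchScript only lets us use `bias` in arithmetic after the None check,
    # which is why the Optional annotation and the `is not None` guard go together.
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=True)
    return residual + out

x = torch.randn(4, 8)
residual = torch.randn(4, 8)
# Accepts a real bias tensor ...
print(bias_dropout_add(x, torch.randn(8), residual, 0.1).shape)
# ... and also None, which a plain `bias: torch.Tensor` annotation would reject
# with a runtime type error once the function is scripted.
print(bias_dropout_add(x, None, residual, 0.1).shape)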