Commit 6fd0b406 authored by zihanl

merge with main branch

parents 492fdf83 60750922
import math
import torch
from torch.nn import LayerNorm
from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func
def test_load_fused_kernels():
    try:
        import fused_mix_prec_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_fused_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_fused_upper_triangle_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)

    attn = gpt.h[0]
    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False,  True,  True,  ...,  True,  True,  True],
              [False, False,  True,  ...,  True,  True,  True],
              [False, False, False,  ...,  True,  True,  True],
              ...,
              [False, False, False,  ..., False,  True,  True],
              [False, False, False,  ..., False, False,  True],
              [False, False, False,  ..., False, False, False]]]
    """

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    torch_layernorm_layer = (
        LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )


if __name__ == "__main__":
    try:
        from transformers import BertTokenizer, GPT2Tokenizer
        from transformers.models.bert.modeling_bert import BertModel
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        import transformers

        transformers.logging.set_verbosity(
            transformers.logging.FATAL,
        )

    except:
        print("\n[Fail] Please install `transformers` package to test fused kernels\n")
        exit(-1)

    test_load_fused_kernels()
    test_fused_softmax()
    test_fused_upper_triangle_mask_softmax()
    test_layer_norm()
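The checks above only run when the file is executed as a script, because the `transformers` imports live inside the `if __name__ == "__main__":` block. A minimal way to drive them from another Python process is sketched below; the path `tests/test_fused_kernels.py` is an assumption about where this file lives in the checkout, and a CUDA GPU, the `transformers` package, and prebuilt Megatron fused kernels are required.

# Hedged sketch: execute the test file as a script so its __main__ block
# (which imports transformers and then runs every check) is triggered.
# The file path below is an assumption, not confirmed by this commit.
import runpy

runpy.run_path("tests/test_fused_kernels.py", run_name="__main__")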
@@ -21,6 +21,7 @@ import time
 import torch

+from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
 from .arguments import parse_args
 from .microbatches import build_num_microbatches_calculator
@@ -31,6 +32,7 @@ _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
+_GLOBAL_SIGNAL_HANDLER = None

 def get_args():
@@ -75,6 +77,14 @@ def get_timers():
     _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
     return _GLOBAL_TIMERS

+def get_signal_handler():
+    _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    return _GLOBAL_SIGNAL_HANDLER
+
+def _set_signal_handler():
+    global _GLOBAL_SIGNAL_HANDLER
+    _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()
+
 def set_global_variables(extra_args_provider=None, args_defaults={},
                          ignore_unknown_args=False):
@@ -89,6 +99,9 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     _set_adlr_autoresume(args)
     _set_timers()

+    if args.exit_signal_handler:
+        _set_signal_handler()
+
 def _parse_args(extra_args_provider=None, defaults={},
                 ignore_unknown_args=False):
...
@@ -21,6 +21,7 @@ import time
 import numpy as np
 import torch
+from datetime import timedelta

 from megatron import fused_kernels
 from megatron import get_adlr_autoresume
@@ -63,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

+    # Set pytorch JIT layer fusion options.
+    _set_jit_fusion_options()
+
     args = get_args()
     if args.lazy_mpu_init:
         args.use_cpu_initialization=True
@@ -77,9 +81,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

-       # Initialize memory buffers.
-       _initialize_mem_buffs()
-
        # Autoresume.
        _init_autoresume()
@@ -175,15 +176,11 @@ def _initialize_distributed():
            else:
                args.local_rank = device
            torch.cuda.set_device(device)

    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
+    torch.distributed.init_process_group(
+        backend=args.distributed_backend,
+        world_size=args.world_size, rank=args.rank,
+        timeout=timedelta(minutes=10))

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
@@ -193,7 +190,8 @@ def _initialize_distributed():
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
-                                         args.virtual_pipeline_model_parallel_size)
+                                         args.virtual_pipeline_model_parallel_size,
+                                         args.pipeline_model_parallel_split_rank)

 def _init_autoresume():
@@ -229,10 +227,24 @@ def write_args_to_tensorboard():
                                 global_step=args.iteration)

-def _initialize_mem_buffs():
-    """Initialize manually allocated static memory."""
-    args = get_args()
+def _set_jit_fusion_options():
+    """Set PyTorch JIT layer fusion options."""
+    # flags required to enable jit fusion kernels
+    TORCH_MAJOR = int(torch.__version__.split('.')[0])
+    TORCH_MINOR = int(torch.__version__.split('.')[1])
+    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+        # nvfuser
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+        torch._C._debug_set_autodiff_subgraph_inlining(False)
+    else:
+        # legacy pytorch fuser
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)

-    # Initialize memory for checkpointed activations.
-    if args.distribute_checkpointed_activations:
-        mpu.init_checkpointed_activations_memory_buffer()
@@ -21,3 +21,4 @@ from .gpt_model import GPTModel
 from .t5_model import T5Model
 from .language_model import get_language_model
 from .module import Float16Module
+from .enums import ModelType
@@ -15,6 +15,10 @@

 import enum

+class ModelType(enum.Enum):
+    encoder_or_decoder = 1
+    encoder_and_decoder = 2
+
 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2
...
@@ -15,10 +15,6 @@

 import torch

-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)

 ###### BIAS GELU FUSION/ NO AUTOGRAD ################
 # 1/sqrt(2*pi)-> 0.3989423
...
@@ -23,6 +23,12 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib

+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+    HAVE_PERSIST_LAYER_NORM = True
+except:
+    HAVE_PERSIST_LAYER_NORM = False
+
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -61,13 +67,23 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):

 class MixedFusedLayerNorm(torch.nn.Module):

-    def __init__(self, normalized_shape, eps=1e-5):
+    def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
         super(MixedFusedLayerNorm, self).__init__()

         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
             "fused_mix_prec_layer_norm_cuda")

+        # List of hiddens sizes supported in the persistent layer norm kernel
+        # If the hidden size is not supported, fall back to the non-persistent
+        # kernel.
+        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
+            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
+            24576, 25600, 30720, 32768, 40960, 49152, 65536]
+        if normalized_shape not in persist_ln_hidden_sizes or \
+                not HAVE_PERSIST_LAYER_NORM:
+            no_persist_layer_norm = True
+
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)
@@ -75,6 +91,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.weight = Parameter(torch.Tensor(*normalized_shape))
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
+        self.no_persist_layer_norm = no_persist_layer_norm

     def reset_parameters(self):
@@ -85,6 +102,10 @@ class MixedFusedLayerNorm(torch.nn.Module):

     def forward(self, input):
-        return FusedLayerNormAffineFunction.apply(
-            input, self.weight, self.bias, self.normalized_shape, self.eps)
+        if self.no_persist_layer_norm:
+            return FusedLayerNormAffineFunction.apply(
+                input, self.weight, self.bias, self.normalized_shape, self.eps)
+        else:
+            return FastLayerNormFN.apply(
+                input, self.weight, self.bias, self.eps)
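The selection logic introduced above can be read in isolation: the persistent (apex FastLayerNormFN) kernel is only used when apex provides it and the hidden size is one of the supported values; otherwise the existing FusedLayerNormAffineFunction path is kept. Below is a small illustrative predicate that mirrors that rule; the names are hypothetical and not part of the commit.

# Hedged sketch of the fallback rule added in this diff (illustrative only).
PERSIST_LN_HIDDEN_SIZES = {1024, 1536, 2048, 2304, 3072, 3840, 4096,
                           5120, 6144, 8192, 10240, 12288, 12800, 15360,
                           16384, 18432, 20480, 24576, 25600, 30720, 32768,
                           40960, 49152, 65536}

def use_persistent_layer_norm(hidden_size: int, have_persist_kernel: bool,
                              no_persist_layer_norm: bool = True) -> bool:
    """Illustrative helper, not Megatron API."""
    if hidden_size not in PERSIST_LN_HIDDEN_SIZES or not have_persist_kernel:
        return False
    return not no_persist_layer_norm

For example, use_persistent_layer_norm(4096, True, no_persist_layer_norm=False) is True, while an unsupported size such as 4000 always falls back to the non-persistent kernel.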
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import torch
+import torch.nn as nn
+
 from megatron.model.enums import AttnMaskType
@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         import scaled_upper_triang_masked_softmax_cuda

         scale_t = torch.tensor([scale])
         softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
             inputs, scale_t[0]
         )

         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         import scaled_upper_triang_masked_softmax_cuda

         softmax_results, scale_t = ctx.saved_tensors
         input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
             output_grads, softmax_results, scale_t[0]
         )

         return input_grads, None
@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         scale_t = torch.tensor([scale])

-        softmax_results = scaled_masked_softmax_cuda.forward(
-            inputs, mask, scale_t[0]
-        )
+        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])

         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return input_grads, None, None

-class FusedScaleMaskSoftmax(torch.nn.Module):
+class FusedScaleMaskSoftmax(nn.Module):
     """
     fused operation: scaling + mask + softmax

     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
+        input_in_bf16: flag to indicate if input in bf16 data format.
         attn_mask_type: attention mask type (pad or causal)
+        scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """

     def __init__(
@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
         self.input_in_bf16 = input_in_bf16
-        assert not (self.input_in_fp16 and self.input_in_bf16),\
-            'both fp16 and bf16 flags cannot be active at the same time.'
+        assert not (
+            self.input_in_fp16 and self.input_in_bf16
+        ), "both fp16 and bf16 flags cannot be active at the same time."
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
         # [b, np, sq, sk]
         assert input.dim() == 4
-        data_size = input.size()
-        query_seq_len = data_size[-2]
-        key_seq_len = data_size[-1]
-        attn_batch_size = data_size[0] * data_size[1]
-
-        # constraints on various tensor dimensions to enable warp based
-        # optimization and upper triangular optimization (for causal mask)
-        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
-            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
-
-        # invoke custom kernel
-        if self.input_in_float16 and mask is not None and \
-            custom_kernel_constraint and self.scaled_masked_softmax_fusion:
-            scale = self.scale if self.scale is not None else 1.0
-            if self.attn_mask_type == AttnMaskType.causal:
-                assert query_seq_len == key_seq_len, \
-                    "causal mask is only for self attention"
-                input = input.view(-1, query_seq_len, key_seq_len)
-                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
-                probs = probs.view(*data_size)
-            else:
-                assert self.attn_mask_type == AttnMaskType.padding
-                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
-        else:
-            if self.input_in_float16 and self.softmax_in_fp32:
-                input = input.float()
-
-            if self.scale is not None:
-                input = input * self.scale
-            mask_output = self.mask_func(input, mask) if mask is not None else input
-            probs = torch.nn.Softmax(dim=-1)(mask_output)
-
-            if self.input_in_float16 and self.softmax_in_fp32:
-                if self.input_in_fp16:
-                    probs = probs.half()
-                else:
-                    probs = probs.bfloat16()
-
-        return probs
+
+        if self.is_kernel_available(mask, *input.size()):
+            return self.forward_fused_softmax(input, mask)
+        else:
+            return self.forward_torch_softmax(input, mask)
+
+    def is_kernel_available(self, mask, b, np, sq, sk):
+        attn_batches = b * np
+
+        if (
+            self.scaled_masked_softmax_fusion  # user want to fuse
+            and self.input_in_float16  # input must be fp16
+            and mask is not None  # mask tensor must not be None
+            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and sq % 4 == 0  # sq must be divisor of 4
+            and attn_batches % 4 == 0  # np * b must be divisor of 4
+        ):
+            if 0 <= sk <= 2048:
+                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+
+                if self.attn_mask_type == AttnMaskType.causal:
+                    if attn_batches % batch_per_block == 0:
+                        return True
+                else:
+                    if sq % batch_per_block == 0:
+                        return True
+        return False
+
+    def forward_fused_softmax(self, input, mask):
+        b, np, sq, sk = input.size()
+        scale = self.scale if self.scale is not None else 1.0
+
+        if self.attn_mask_type == AttnMaskType.causal:
+            assert sq == sk, "causal mask is only for self attention"
+
+            # input is 3D tensor (attn_batches, sq, sk)
+            input = input.view(-1, sq, sk)
+            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
+            return probs.view(b, np, sq, sk)
+        else:
+            # input is 4D tensor (b, np, sq, sk)
+            return ScaledMaskedSoftmax.apply(input, mask, scale)
+
+    def forward_torch_softmax(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()
+
+        return probs
+
+    @staticmethod
+    def get_batch_per_block(sq, sk, b, np):
+        import scaled_masked_softmax_cuda
+
+        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
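The availability checks above reduce to a handful of shape constraints before the extra batch_per_block divisibility test. The sketch below restates just those shape constraints as a standalone predicate; it is illustrative only, the names are hypothetical, and it deliberately omits the get_batch_per_block check that requires the compiled scaled_masked_softmax_cuda extension.

# Hedged sketch of the basic shape gate used by is_kernel_available above.
def fused_softmax_shape_ok(b: int, np: int, sq: int, sk: int,
                           input_in_float16: bool, has_mask: bool,
                           fusion_enabled: bool) -> bool:
    attn_batches = b * np
    return (fusion_enabled              # user asked for fusion
            and input_in_float16        # fp16/bf16 input only
            and has_mask                # a mask tensor must be provided
            and 16 < sk <= 2048         # key sequence length range
            and sq % 4 == 0             # query length divisible by 4
            and attn_batches % 4 == 0)  # b * np divisible by 4

For example, an fp16 score tensor with b=8, np=12, sq=128, sk=128 and a mask passes this gate, while sk=4096 always falls back to the torch softmax path.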
@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal

 def post_language_model_processing(lm_output, labels, logit_weights,
-                                   get_key_value, parallel_output,
-                                   forward_method_parallel_output,
+                                   parallel_output,
                                    fp16_lm_cross_entropy):
-    if get_key_value:
-        lm_output, presents = lm_output

     # Output.
-    if forward_method_parallel_output is not None:
-        parallel_output = forward_method_parallel_output
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)
-    if get_key_value:
-        output = [output, presents]

     if labels is None:
         return output
     else:
@@ -90,23 +82,19 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
+                tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
+            inference_params=inference_params)

         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
-                get_key_value,
                 self.parallel_output,
-                forward_method_parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
             return lm_output
...
@@ -45,7 +45,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

 def get_language_model(num_tokentypes, add_pooler,
                        encoder_attn_mask_type, init_method=None,
-                       scaled_init_method=None, add_decoder=False,
+                       scaled_init_method=None, add_encoder=True,
+                       add_decoder=False,
                        decoder_attn_mask_type=AttnMaskType.causal,
                        pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""
@@ -64,6 +65,7 @@ def get_language_model(num_tokentypes, add_pooler,
         scaled_init_method,
         encoder_attn_mask_type,
         num_tokentypes=num_tokentypes,
+        add_encoder=add_encoder,
         add_decoder=add_decoder,
         decoder_attn_mask_type=decoder_attn_mask_type,
         add_pooler=add_pooler,
@@ -159,6 +161,16 @@ class Embedding(MegatronModule):
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

+    def zero_parameters(self):
+        """Zero out all parameters in embedding."""
+        self.word_embeddings.weight.data.fill_(0)
+        self.word_embeddings.weight.shared = True
+        self.position_embeddings.weight.data.fill_(0)
+        self.position_embeddings.weight.shared = True
+        if self.num_tokentypes > 0:
+            self.tokentype_embeddings.weight.data.fill_(0)
+            self.tokentype_embeddings.weight.shared = True
+
     def add_tokentype_embeddings(self, num_tokentypes):
         """Add token-type embedding. This function is provided so we can add
         token-type embeddings in case the pretrained model does not have it.
@@ -273,6 +285,7 @@ class TransformerLanguageModel(MegatronModule):
                  output_layer_init_method,
                  encoder_attn_mask_type,
                  num_tokentypes=0,
+                 add_encoder=True,
                  add_decoder=False,
                  decoder_attn_mask_type=AttnMaskType.causal,
                  add_pooler=False,
@@ -286,10 +299,12 @@ class TransformerLanguageModel(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.add_encoder = add_encoder
         self.encoder_attn_mask_type = encoder_attn_mask_type
         self.add_decoder = add_decoder
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
+        self.encoder_hidden_state = None

         # Embeddings.
         if self.pre_process:
@@ -302,25 +317,37 @@ class TransformerLanguageModel(MegatronModule):
             self._embedding_key = 'embedding'

         # Transformer.
-        self.encoder = ParallelTransformer(
-            self.init_method,
-            output_layer_init_method,
-            self_attn_mask_type=self.encoder_attn_mask_type,
-            pre_process=self.pre_process,
-            post_process=self.post_process
-        )
-        self._encoder_key = 'encoder'
-
-        # Decoder
+        # Encoder (usually set to True, False if part of an encoder-decoder
+        # architecture and in encoder-only stage).
+        if self.add_encoder:
+            self.encoder = ParallelTransformer(
+                self.init_method,
+                output_layer_init_method,
+                self_attn_mask_type=self.encoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process
+            )
+            self._encoder_key = 'encoder'
+        else:
+            self.encoder = None
+
+        # Decoder (usually set to False, True if part of an encoder-decoder
+        # architecture and in decoder-only stage).
         if self.add_decoder:
+            # Temporary assertion until we verify correctness of pipeline parallelism
+            # implementation of T5.
             assert args.pipeline_model_parallel_size == 1, \
                 'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,
                 layer_type=LayerType.decoder,
-                self_attn_mask_type=self.decoder_attn_mask_type)
+                self_attn_mask_type=self.decoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process)
             self._decoder_key = 'decoder'
+        else:
+            self.decoder = None

         if self.post_process:
             # Pooler.
@@ -330,28 +357,55 @@ class TransformerLanguageModel(MegatronModule):

     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
-        self.encoder.set_input_tensor(input_tensor)
+
+        # This is usually handled in schedules.py but some inference code still
+        # gives us non-lists or None
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+
+        if self.add_encoder and self.add_decoder:
+            assert len(input_tensor) == 1, \
+                'input_tensor should only be length 1 for stage with both encoder and decoder'
+            self.encoder.set_input_tensor(input_tensor[0])
+        elif self.add_encoder:
+            assert len(input_tensor) == 1, \
+                'input_tensor should only be length 1 for stage with only encoder'
+            self.encoder.set_input_tensor(input_tensor[0])
+        elif self.add_decoder:
+            if len(input_tensor) == 2:
+                self.decoder.set_input_tensor(input_tensor[0])
+                self.encoder_hidden_state = input_tensor[1]
+            elif len(input_tensor) == 1:
+                self.decoder.set_input_tensor(None)
+                self.encoder_hidden_state = input_tensor[0]
+            else:
+                raise Exception('input_tensor must have either length 1 or 2')
+        else:
+            raise Exception('Stage must have at least either encoder or decoder')

     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                inference_params=None,
+                pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

-        # Embeddings.
+        # Encoder embedding.
         if self.pre_process:
-            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
-                                              tokentype_ids=tokentype_ids)
-            encoder_input = embedding_output
+            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                            tokentype_ids=tokentype_ids)
         else:
             encoder_input = None

-        # encoder.
+        # Run encoder.
         if enc_hidden_states is None:
-            encoder_output = self.encoder(encoder_input,
-                                          enc_attn_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+            if self.encoder is not None:
+                encoder_output = self.encoder(
+                    encoder_input,
+                    enc_attn_mask,
+                    inference_params=inference_params)
+            else:
+                encoder_output = self.encoder_hidden_state
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)
@@ -369,16 +423,20 @@ class TransformerLanguageModel(MegatronModule):
             else:
                 return encoder_output

-        # Decoder Embedding
-        dec_embedding_output = self.embedding(dec_input_ids,
-                                              dec_position_ids)
-        # decoder
-        decoder_output = self.decoder(dec_embedding_output,
-                                      dec_attn_mask,
-                                      layer_past=layer_past,
-                                      get_key_value=get_key_value,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask)
+        # Decoder embedding.
+        if self.pre_process:
+            decoder_input = self.embedding(dec_input_ids,
+                                           dec_position_ids)
+        else:
+            decoder_input = None
+
+        # Run decoder.
+        decoder_output = self.decoder(
+            decoder_input,
+            dec_attn_mask,
+            encoder_output=encoder_output,
+            enc_dec_attn_mask=enc_dec_attn_mask,
+            inference_params=inference_params)

         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
@@ -394,9 +452,10 @@ class TransformerLanguageModel(MegatronModule):
             state_dict_[self._embedding_key] \
                 = self.embedding.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
-        state_dict_[self._encoder_key] \
-            = self.encoder.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+        if self.add_encoder:
+            state_dict_[self._encoder_key] \
+                = self.encoder.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
         if self.post_process:
             if self.add_pooler:
                 state_dict_[self._pooler_key] \
@@ -425,38 +484,39 @@ class TransformerLanguageModel(MegatronModule):
             self.embedding.load_state_dict(state_dict_, strict=strict)

         # Encoder.
-        if self._encoder_key in state_dict:
-            state_dict_ = state_dict[self._encoder_key]
-        # for backward compatibility.
-        elif 'transformer' in state_dict:
-            state_dict_ = state_dict['transformer']
-        else:
-            # for backward compatibility.
-            state_dict_ = {}
-            for key in state_dict.keys():
-                if 'transformer.' in key:
-                    state_dict_[key.split('transformer.')[1]] = state_dict[key]
-
-        # for backward compatibility.
-        state_dict_self_attention = {}
-        for key in state_dict_.keys():
-            if '.attention.' in key:
-                state_dict_self_attention[key.replace(".attention.",
-                    ".self_attention.")] = state_dict_[key]
-            else:
-                state_dict_self_attention[key] = state_dict_[key]
-        state_dict_ = state_dict_self_attention
-
-        self.encoder.load_state_dict(state_dict_, strict=strict)
+        if self.add_encoder:
+            if self._encoder_key in state_dict:
+                state_dict_ = state_dict[self._encoder_key]
+            # For backward compatibility.
+            elif 'transformer' in state_dict:
+                state_dict_ = state_dict['transformer']
+            else:
+                # For backward compatibility.
+                state_dict_ = {}
+                for key in state_dict.keys():
+                    if 'transformer.' in key:
+                        state_dict_[key.split('transformer.')[1]] = state_dict[key]
+
+            # For backward compatibility.
+            state_dict_self_attention = {}
+            for key in state_dict_.keys():
+                if '.attention.' in key:
+                    state_dict_self_attention[key.replace(".attention.",
+                        ".self_attention.")] = state_dict_[key]
+                else:
+                    state_dict_self_attention[key] = state_dict_[key]
+            state_dict_ = state_dict_self_attention
+
+            self.encoder.load_state_dict(state_dict_, strict=strict)

+        # Pooler.
         if self.post_process:
-            # pooler
             if self.add_pooler:
                 assert 'pooler' in state_dict, \
                     'could not find data for pooler in the checkpoint'
                 self.pooler.load_state_dict(state_dict[self._pooler_key],
                                             strict=strict)

-        # decoder
+        # Decoder.
         if self.add_decoder:
             assert 'decoder' in state_dict, \
                 'could not find data for pooler in the checkpoint'
...
@@ -51,15 +51,14 @@ class MegatronModule(torch.nn.Module):

     def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage(ignore_virtual=True):
+        if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
+                mpu.get_pipeline_model_parallel_world_size() == 1:
             return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage(ignore_virtual=True):
+        else:
             if not self.share_word_embeddings:
                 raise Exception('word_embeddings_weight() called for last '
                                 'stage, but share_word_embeddings is false')
             return self.word_embeddings.weight
-        raise Exception('word_embeddings_weight() should be '
-                        'called for first and last stage only')

     def initialize_word_embeddings(self, init_method_normal):
@@ -69,12 +68,12 @@ class MegatronModule(torch.nn.Module):
                             'share_word_embeddings is false')

         # This function just initializes the word embeddings in the final stage
-        # when we are using pipeline parallelism. If we aren't using pipeline
-        # parallelism there is nothing to do.
+        # when we are using pipeline parallelism. Nothing to do if we aren't
+        # using pipeline parallelism.
         if args.pipeline_model_parallel_size == 1:
             return

-        # Parameters are shared between the word embeddings layer, and the
+        # Parameters are shared between the word embeddings layers, and the
         # heads at the end of the model. In a pipelined setup with more than
         # one stage, the initial embedding layer and the head are on different
         # workers, so we do the following:
@@ -97,12 +96,34 @@ class MegatronModule(torch.nn.Module):
             self.word_embeddings.weight.data.fill_(0)
             self.word_embeddings.weight.shared = True

+        # Zero out initial weights for decoder embedding.
+        # NOTE: We don't currently support T5 with the interleaved schedule.
+        if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
+                not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
+                mpu.is_rank_in_embedding_group():
+            self.language_model.embedding.zero_parameters()
+
         # Ensure that first and last stages have the same initial parameter
         # values.
         if torch.distributed.is_initialized():
-            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
+            if mpu.is_rank_in_embedding_group():
                 torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                              group=mpu.get_embedding_group())
+
+            # All-reduce other embeddings as well as necessary. The last stage
+            # does not have these other embeddings, so just create placeholder
+            # tensors of the right shape with all zeros.
+            # NOTE: We don't currently support T5 with the interleaved schedule.
+            if args.pipeline_model_parallel_split_rank is not None:
+                # TODO: Support tokentype embedding.
+                dimensions = (args.max_position_embeddings, args.hidden_size)
+                if mpu.is_pipeline_last_stage(ignore_virtual=True):
+                    position_embeddings = torch.nn.Embedding(*dimensions).cuda()
+                    position_embeddings.weight.data.fill_(0)
+                else:
+                    self.language_model.embedding.cuda()
+                    position_embeddings = self.language_model.embedding.position_embeddings
+                torch.distributed.all_reduce(position_embeddings.weight.data,
+                                             group=mpu.get_embedding_group())
         else:
             print("WARNING! Distributed processes aren't initialized, so "
                   "word embeddings in the last layer are not initialized. "
@@ -166,6 +187,10 @@ class Float16Module(MegatronModule):

         self.float16_convertor = float16_convertor

+    def set_input_tensor(self, input_tensor):
+        return self.module.set_input_tensor(input_tensor)
+
     def forward(self, *inputs, **kwargs):
         if mpu.is_pipeline_first_stage():
             inputs = fp32_to_float16(inputs, self.float16_convertor)
...
@@ -86,7 +86,13 @@ class T5LMHead(MegatronModule):

 class T5Model(MegatronModule):
     """T5 Language model."""

-    def __init__(self, num_tokentypes=0, parallel_output=True):
+    def __init__(self,
+                 num_tokentypes=0,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True,
+                 add_encoder=True,
+                 add_decoder=True):
         super(T5Model, self).__init__()
         args = get_args()
@@ -95,19 +101,29 @@ class T5Model(MegatronModule):
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.add_encoder = add_encoder
+        self.add_decoder = add_decoder

         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            add_decoder=True,
+            add_encoder=add_encoder,
+            add_decoder=add_decoder,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)

-        self.lm_head = T5LMHead(
-            self.language_model.embedding.word_embeddings.weight.size(0),
-            parallel_output)
-        self._lm_head_key = 'lm_head'
+        self.initialize_word_embeddings(init_method_normal)
+
+        if self.post_process and self.add_decoder:
+            self.lm_head = T5LMHead(
+                self.word_embeddings_weight().size(0),
+                parallel_output)
+            self._lm_head_key = 'lm_head'

     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
@@ -134,22 +150,28 @@ class T5Model(MegatronModule):
             tokentype_ids=tokentype_ids,
             enc_hidden_states=enc_hidden_states)

-        decoder_output, encoder_output = lm_output
-
-        # Output.
-        lm_logits = self.lm_head(decoder_output,
-                                 self.language_model.embedding.word_embeddings.weight)
-
-        if lm_labels is None:
-            return lm_logits, encoder_output
-        else:
-            if self.fp16_lm_cross_entropy:
-                assert lm_logits.dtype == torch.half
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
-            else:
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                           lm_labels)
-            return lm_loss, encoder_output
+        if self.post_process and self.add_decoder:
+            decoder_output, encoder_output = lm_output
+            # Output.
+            lm_logits = self.lm_head(decoder_output,
+                                     self.word_embeddings_weight())
+
+            if lm_labels is None:
+                return lm_logits
+            else:
+                if self.fp16_lm_cross_entropy:
+                    assert lm_logits.dtype == torch.half
+                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+                else:
+                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                               lm_labels)
+                return lm_loss
+        elif self.add_decoder and not self.add_encoder:
+            decoder_output, encoder_output = lm_output
+            return decoder_output
+        else:
+            encoder_output = lm_output
+            return encoder_output
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
@@ -160,9 +182,14 @@ class T5Model(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._lm_head_key] \
-            = self.lm_head.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+        if self.post_process and self.add_decoder:
+            state_dict_[self._lm_head_key] \
+                = self.lm_head.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
+        # Save word_embeddings.
+        if self.post_process and not self.pre_process and self.add_decoder:
+            state_dict_[self._word_embeddings_for_head_key] \
+                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_
     def load_state_dict(self, state_dict, strict=True):
@@ -170,5 +197,10 @@ class T5Model(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        self.lm_head.load_state_dict(state_dict[self._lm_head_key],
-                                     strict=strict)
+        if self.post_process and self.add_decoder:
+            self.lm_head.load_state_dict(state_dict[self._lm_head_key],
+                                         strict=strict)
+        # Load word embeddings.
+        if self.post_process and not self.pre_process and self.add_decoder:
+            self.word_embeddings.load_state_dict(
+                state_dict[self._word_embeddings_for_head_key], strict=strict)
@@ -21,17 +21,12 @@ import torch.nn.functional as F

 from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
-from megatron.model.enums import AttnMaskType, LayerType, AttnType
+from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

-# flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)

 """ We use the following notation throughout this file:
     h: hidden size
@@ -53,8 +48,7 @@ class ParallelMLP(MegatronModule):

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
-   state back into h hidden dimension. At the end, dropout is also
-   applied.
+   state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
@@ -84,7 +78,6 @@ class ParallelMLP(MegatronModule):
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
@@ -125,6 +118,7 @@ class ParallelAttention(MegatronModule):
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
+        self.params_dtype = args.params_dtype

         projection_size = args.kv_channels * args.num_attention_heads
@@ -185,10 +179,40 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)

-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None):
+    def _allocate_memory(self, inference_max_sequence_len, batch_size):
+        return torch.empty(
+            inference_max_sequence_len,
+            batch_size,
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+            dtype=self.params_dtype,
+            device=torch.cuda.current_device())
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, inference_params=None):
         # hidden_states: [sq, b, h]

+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        if inference_params:
+            if self.layer_number not in inference_params.key_value_memory_dict:
+                inf_max_seq_len = inference_params.max_sequence_len
+                inf_max_batch_size = inference_params.max_batch_size
+                inference_key_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                inference_value_memory = self._allocate_memory(
+                    inf_max_seq_len, inf_max_batch_size)
+                inference_params.key_value_memory_dict[self.layer_number] = (
+                    inference_key_memory, inference_value_memory)
+            else:
+                inference_key_memory, inference_value_memory = \
+                    inference_params.key_value_memory_dict[self.layer_number]
+
         # =====================
         # Query, Key, and Value
         # =====================
...@@ -229,18 +253,28 @@ class ParallelAttention(MegatronModule): ...@@ -229,18 +253,28 @@ class ParallelAttention(MegatronModule):
self.hidden_size_per_attention_head) self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape) query_layer = query_layer.view(*new_tensor_shape)
# ================================== # ==================================
# Adjust key and value for inference # Adjust key and value for inference
# ================================== # ==================================
if layer_past is not None: if inference_params:
past_key, past_value = layer_past batch_start = inference_params.batch_size_offset
key_layer = torch.cat((past_key.type_as(key_layer), batch_end = batch_start + key_layer.size(1)
key_layer), dim=0) assert batch_end <= inference_key_memory.size(1)
value_layer = torch.cat((past_value.type_as(value_layer), sequence_start = inference_params.sequence_len_offset
value_layer), dim=0) sequence_end = sequence_start + key_layer.size(0)
if get_key_value: assert sequence_end <= inference_key_memory.size(0)
present = (key_layer, value_layer) # Copy key and values.
inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = key_layer
inference_value_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[
:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
# =================================== # ===================================
# Raw attention scores. [b, np, s, s] # Raw attention scores. [b, np, s, s]
...@@ -277,22 +311,6 @@ class ParallelAttention(MegatronModule): ...@@ -277,22 +311,6 @@ class ParallelAttention(MegatronModule):
# change view to [b, np, sq, sk] # change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size) attention_scores = matmul_result.view(*output_size)
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
...@@ -348,9 +366,6 @@ class ParallelAttention(MegatronModule):
output, bias = self.dense(context_layer)
if get_key_value:
output = [output, present]
return output, bias
...@@ -368,14 +383,18 @@ def get_bias_dropout_add(training):
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False)
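The fused variants above wrap a plain bias_dropout_add helper defined earlier in this file (not shown in this hunk). A minimal sketch of that dropout-then-residual pattern, for context only; the exact definition may differ:

def bias_dropout_add(x, bias, residual, prob, training):
    # dropout(x + bias) followed by the residual addition (assumes torch is
    # already imported in this module).
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out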
...@@ -404,7 +423,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
...@@ -419,7 +439,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
...@@ -430,7 +451,8 @@ class ParallelTransformerLayer(MegatronModule):
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
self.mlp = ParallelMLP(init_method,
...@@ -438,20 +460,17 @@ class ParallelTransformerLayer(MegatronModule):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
...@@ -521,9 +540,6 @@ class ParallelTransformerLayer(MegatronModule):
residual,
self.hidden_dropout)
if get_key_value:
output = [output, presents]
return output
...@@ -544,13 +560,13 @@ class ParallelTransformer(MegatronModule):
self.input_tensor = None
# Store activation checkpointing flags.
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
# Number of layers.
self.num_layers = mpu.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
# Transformer layers.
def build_layer(layer_number):
...@@ -589,7 +605,8 @@ class ParallelTransformer(MegatronModule):
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
def _get_layer(self, layer_number):
return self.layers[layer_number]
...@@ -609,14 +626,32 @@ class ParallelTransformer(MegatronModule):
return x_
return custom_forward
if self.activations_checkpoint_method == 'uniform':
# Uniformly divide the total number of Transformer layers and
# checkpoint the input activation of each divided chunk.
# This trades extra re-computation for a further reduction in memory usage.
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
self.distribute_checkpointed_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# This makes fuller use of device memory while avoiding redundant re-computation.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + 1),
self.distribute_checkpointed_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation checkpoint method.")
return hidden_states
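To make the two strategies concrete, here is a small standalone sketch (a hypothetical helper, not part of Megatron) of which layer ranges end up checkpointed under each method, assuming num_layers=8 and activations_checkpoint_num_layers=2:

def checkpointed_chunks(method, num_layers=8, num_ckpt_layers=2):
    if method == 'uniform':
        # one checkpoint per chunk of `num_ckpt_layers` consecutive layers
        return [(l, l + num_ckpt_layers) for l in range(0, num_layers, num_ckpt_layers)]
    if method == 'block':
        # only the first `num_ckpt_layers` layers are checkpointed, one by one
        return [(l, l + 1) for l in range(num_ckpt_layers)]
    raise ValueError("Invalid activation checkpoint method.")

print(checkpointed_chunks('uniform'))  # [(0, 2), (2, 4), (4, 6), (6, 8)]
print(checkpointed_chunks('block'))    # [(0, 1), (1, 2)]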
...@@ -630,18 +665,14 @@ class ParallelTransformer(MegatronModule):
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks.
if inference_params:
assert self.activations_checkpoint_method is None, \
'inference does not work with activation checkpointing'
if self.pre_process:
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
...@@ -658,28 +689,21 @@ class ParallelTransformer(MegatronModule):
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
# Final layer norm.
if self.post_process:
...@@ -688,7 +712,5 @@ class ParallelTransformer(MegatronModule):
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
if get_key_value:
output = [output, presents]
return output
...@@ -31,6 +31,10 @@ from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import is_rank_in_embedding_group
from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
...@@ -56,9 +60,7 @@ from .mappings import scatter_to_tensor_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
...
...@@ -34,6 +34,7 @@ _DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
...@@ -41,8 +42,11 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
def is_unitialized():
...@@ -52,13 +56,19 @@ def is_unitialized():
def initialize_model_parallel(tensor_model_parallel_size_=1,
pipeline_model_parallel_size_=1,
virtual_pipeline_model_parallel_size_=None,
pipeline_model_parallel_split_rank_=None):
"""
Initialize model data parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
...@@ -101,6 +111,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
if pipeline_model_parallel_split_rank_ is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_
rank = torch.distributed.get_rank()
# Build the data-parallel groups.
...@@ -148,6 +162,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, \
'embedding group is already initialized'
for i in range(num_pipeline_model_parallel_groups):
...@@ -161,11 +176,18 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
if pipeline_model_parallel_split_rank_ is not None and \
pipeline_model_parallel_split_rank_ not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1]]
else:
embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
def model_parallel_is_initialized():
...@@ -268,6 +290,30 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_num_layers(args, is_encoder_and_decoder_model):
"""Compute the number of transformer layers resident on the current rank."""
if get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder
assert args.num_layers % num_ranks_in_encoder == 0, \
'num_layers must be divisible by number of ranks given to encoder'
assert args.num_layers % num_ranks_in_decoder == 0, \
'num_layers must be divisible by number of ranks given to decoder'
if is_pipeline_stage_before_split():
num_layers = args.num_layers // num_ranks_in_encoder
else:
num_layers = args.num_layers // num_ranks_in_decoder
else:
assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
'num_layers must be divisible by pipeline_model_parallel_size'
num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
else:
num_layers = args.num_layers
return num_layers
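As a worked example (hypothetical settings, with a standalone helper that mirrors the encoder-and-decoder branch above): num_layers=8, a pipeline world size of 4 and pipeline_model_parallel_split_rank=2 give 2 encoder ranks and 2 decoder ranks, so every stage holds 8 // 2 = 4 layers.

def layers_on_this_rank(num_layers, world_size, split_rank, stage_is_encoder):
    # Standalone sketch of the encoder-and-decoder branch of get_num_layers().
    if world_size == 1:
        return num_layers
    num_ranks_in_encoder = split_rank
    num_ranks_in_decoder = world_size - split_rank
    ranks = num_ranks_in_encoder if stage_is_encoder else num_ranks_in_decoder
    assert num_layers % ranks == 0
    return num_layers // ranks

assert layers_on_this_rank(8, 4, 2, stage_is_encoder=True) == 4
assert layers_on_this_rank(8, 4, 2, stage_is_encoder=False) == 4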
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
...@@ -290,6 +336,61 @@ def is_pipeline_last_stage(ignore_virtual=False):
get_pipeline_model_parallel_world_size() - 1)
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _EMBEDDING_GLOBAL_RANKS
if ignore_virtual:
return rank in _EMBEDDING_GLOBAL_RANKS
if rank in _EMBEDDING_GLOBAL_RANKS:
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
return is_pipeline_first_stage(ignore_virtual=False)
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
return is_pipeline_last_stage(ignore_virtual=False)
else:
return True
return False
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and \
is_pipeline_stage_after_split(rank+1)
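# Example (assumed configuration): with a 4-stage pipeline and
# _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = 2, stages 0-1 are "before" the split
# (encoder) and stages 2-3 are "after" it (decoder). Only stage 1 satisfies
# is_pipeline_stage_at_split(), since its successor, stage 2, is the first
# decoder stage.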
def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
...@@ -356,9 +457,13 @@ def get_data_parallel_rank():
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
...@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
...@@ -198,6 +199,37 @@ class VocabParallelEmbedding(torch.nn.Module):
return output
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
"""
Column-parallel linear layer execution with asynchronous all-reduce
execution in backprop.
"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Asynchronous all-reduce.
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of the weight-gradient computation briefly (~3 us) so that
# the all-reduce is scheduled first and gets its GPU resources allocated.
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return grad_input, grad_weight, grad_bias
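For intuition, a self-contained sketch of the same overlap pattern, assuming (purely for illustration) a single-process gloo process group rather than Megatron's tensor-model-parallel group:

import torch
import torch.distributed as dist

def overlapped_backward_sketch(grad_output, inp, weight):
    # Schedule the grad_input all-reduce asynchronously, then compute the
    # weight gradient while the communication is (potentially) in flight.
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, async_op=True)  # default (WORLD) group
    grad_weight = grad_output.t().matmul(inp)
    handle.wait()
    return grad_input, grad_weight

if __name__ == "__main__":
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    x, w, go = torch.randn(8, 16), torch.randn(32, 16), torch.randn(8, 32)
    gi, gw = overlapped_backward_sketch(go, x, w)
    dist.destroy_process_group()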
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
...@@ -256,7 +288,7 @@ class ColumnParallelLinear(torch.nn.Module):
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if bias:
if args.use_cpu_initialization:
self.bias = Parameter(torch.empty(
...@@ -272,21 +304,35 @@ class ColumnParallelLinear(torch.nn.Module):
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
not args.no_async_tensor_model_parallel_allreduce and
world_size > 1)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce:
input_shape = input_.shape
input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
# Matrix multiply with asynchronous all-reduce execution.
output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias)
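# Note: the async path flattens [s, b, h] to [s*b, h] because the autograd
# function above computes grad_weight with a 2-D matmul; the result is then
# reshaped back to [s, b, partition-output] before the optional all-gather below.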
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
...@@ -316,8 +362,8 @@ class RowParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
"""
...
...@@ -20,7 +20,7 @@ from .utils import split_tensor_along_last_dim
def _reduce(input_):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
...
...@@ -24,7 +24,6 @@ from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from megatron import get_args
from megatron.memory import allocate_mem_buff
from .initialize import get_data_parallel_rank
...@@ -37,37 +36,6 @@ from .initialize import get_tensor_model_parallel_world_size
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
# Whether to apply model parallelism to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
def init_checkpointed_activations_memory_buffer():
"""Initializ the memory buffer for the checkpointed activations."""
args = get_args()
per_layer = args.micro_batch_size * args.max_position_embeddings * \
args.hidden_size // args.tensor_model_parallel_size
assert args.num_layers % args.checkpoint_num_layers == 0, \
'number of layers is not divisible by checkpoint-num-layers'
num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
numel = per_layer * num_checkpointer_layers
dtype = torch.half
if not args.fp16:
dtype = torch.float
global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
'checkpointed activations memory buffer is already allocated.'
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
'checkpointed activations', numel, dtype, track_usage=False)
def reset_checkpointed_activations_memory_buffer():
"""Reset the memory used for checkpointing."""
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
...@@ -101,14 +69,21 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks."""
partition_size = torch.numel(tensor) // \
get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
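A quick standalone illustration (using an assumed rank and world size instead of the real tensor-model-parallel group) of how each rank ends up with one contiguous 1-D slice:

import torch

def split_1d_sketch(tensor, rank, world_size, new_buffer=False):
    # Mirrors the partitioning logic above for a hypothetical rank/world size.
    partition_size = torch.numel(tensor) // world_size
    start, end = partition_size * rank, partition_size * (rank + 1)
    if new_buffer:
        data = torch.empty(partition_size, dtype=tensor.dtype)
        data.copy_(tensor.view(-1)[start:end])
        return data
    return tensor.view(-1)[start:end]

# A [2, 4] tensor split across 2 ranks: rank 1 holds the last 4 elements.
print(split_1d_sketch(torch.arange(8.).view(2, 4), rank=1, world_size=2))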
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
...@@ -250,8 +225,10 @@ class CheckpointFunction(torch.autograd.Function):
tracked/set/reset.
"""
@staticmethod
def forward(ctx, run_function, distribute_checkpointed_activations, *args):
ctx.run_function = run_function
ctx.distribute_checkpointed_activations \
= distribute_checkpointed_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
...@@ -263,16 +240,14 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if distribute_checkpointed_activations:
ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
new_buffer=True)
# Store everything.
ctx.save_for_backward(*args)
return outputs
@staticmethod
...@@ -281,7 +256,7 @@ class CheckpointFunction(torch.autograd.Function):
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
inputs = ctx.saved_tensors
if ctx.distribute_checkpointed_activations:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
...@@ -310,10 +285,11 @@ class CheckpointFunction(torch.autograd.Function):
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
return (None, None) + grads
def checkpoint(function, distribute_checkpointed_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function,
distribute_checkpointed_activations, *args)
...@@ -100,10 +100,12 @@ def get_megatron_optimizer(model):
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp,
args.bf16,
grad_scaler)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp)