Commit cb00a196 authored by Jared Casper

Merge branch 'main' into t5_pipeline_parallelism

parents 38a774e9 5ab64637
@@ -97,6 +97,13 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
args.activations_checkpoint_method = 'uniform'
if args.rank == 0:
print('--checkpoint-activations is no longer valid, '
'use --activations-checkpoint-method instead. '
'Defaulting to activations-checkpoint-method=uniform.')
del args.checkpoint_activations
# Set input defaults.
for key in defaults:
@@ -154,16 +161,15 @@ def parse_args(extra_args_provider=None, defaults={},
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
-# If we do accumulation and all-reduces in fp32, we need to have
-# local DDP and we should set the use-contiguous-buffers-in-ddp.
# If we do accumulation and all-reduces in fp32, we need to have local DDP
# and we should make sure use-contiguous-buffers-in-local-ddp is not off.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
-args.use_contiguous_buffers_in_ddp = True
assert args.use_contiguous_buffers_in_local_ddp
-# If we use a contiguous buffer to hold main grads, we need to have
-# local DDP.
-if args.use_contiguous_buffers_in_ddp:
-assert args.DDP_impl == 'local'
# For torch DDP, we do not use contiguous buffer
if args.DDP_impl == 'torch':
args.use_contiguous_buffers_in_local_ddp = False
if args.dataloader_type is None:
args.dataloader_type = 'single'
@@ -240,9 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
-assert args.checkpoint_activations, \
-'for distribute-checkpointed-activations to work you '\
-'need to enable checkpoint-activations'
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'checkpointed activations only across tensor model ' \
'parallel groups'
assert args.activations_checkpoint_method is not None, \
'for distribute-checkpointed-activations to work you '\
'need to use an activations-checkpoint method'
assert args.num_layers_per_virtual_pipeline_stage is None, \
'currently distributed checkpointed activations are only supported for ' \
'non-interleaved pipeline parallelism'
_print_args(args)
return args
@@ -408,8 +420,20 @@ def _add_training_args(parser):
action='store_true',
help='If set, distribute checkpointed activations '
'across model parallel group.')
-group.add_argument('--checkpoint-num-layers', type=int, default=1,
-help='chunk size (number of layers) for checkpointing.')
group.add_argument('--activations-checkpoint-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and checkpoint the input activation of '
'each divided chunk, '
'2) block: checkpoint the input activations of only a set '
'number of individual Transformer layers per pipeline stage '
'and do the rest without any checkpointing, '
'default) do not apply activations checkpointing to any layers')
group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided checkpoint unit, '
'2) block: the number of individual Transformer layers '
'to checkpoint within each pipeline stage.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
@@ -444,6 +468,11 @@ def _add_training_args(parser):
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_true',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient computation of a column-linear layer.')
return parser
@@ -593,9 +622,10 @@ def _add_distributed_args(parser):
choices=['local', 'torch'],
help='which DistributedDataParallel implementation '
'to use.')
-group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
-help='If set, use contiguous buffer in DDP. Note that '
-'this option only works woth local DDP.' )
group.add_argument('--no-contiguous-buffers-in-local-ddp',
action='store_false', help='If set, don\'t use '
'contiguous buffer in local DDP.',
dest='use_contiguous_buffers_in_local_ddp')
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='Use scatter/gather to optimize communication of tensors in pipeline',
dest='scatter_gather_tensors_in_pipeline')
...
@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);
torch::Tensor fwd(
torch::Tensor const& input,
torch::Tensor const& mask,
@@ -63,6 +69,14 @@ torch::Tensor bwd(
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size."
);
}
@@ -310,9 +310,22 @@ __global__ void scaled_masked_softmax_warp_backward(
}
}
}
} // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
return batches_per_block;
}
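For reference, a Python transcription of this launch-geometry calculation (only an illustration; it assumes C10_WARP_SIZE is the usual CUDA warp size of 32):

    # Mirrors get_batch_per_block above, assuming a warp size of 32.
    def get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads):
        log2_elements = (key_seq_len - 1).bit_length()   # log2_ceil(key_seq_len)
        next_power_of_two = 1 << log2_elements
        warp_size = min(next_power_of_two, 32)
        batches_per_warp = 2 if next_power_of_two <= 128 else 1
        threads_per_block = 128
        warps_per_block = threads_per_block // warp_size
        return warps_per_block * batches_per_warp

    print(get_batch_per_block(128, 128, 4, 16))  # -> 8 when key_seq_len == 128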
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
...
@@ -28,6 +28,11 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
...
@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
@@ -451,6 +452,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
...
import math
import torch
from torch.nn import LayerNorm
from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func
def test_load_fused_kernels():
try:
import fused_mix_prec_layer_norm_cuda
import scaled_masked_softmax_cuda
import scaled_upper_triang_masked_softmax_cuda
import torch
print("[Success] load_fused_kernels")
except ImportError as e:
print("[Fail] load_fused_kernels")
raise e
def test_fused_softmax():
bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
)
tokens = tokenizer(
[test_text] * 4,
return_tensors="pt",
)
embedding_output = bert.embeddings(
input_ids=tokens["input_ids"].cuda(),
position_ids=None,
token_type_ids=tokens["token_type_ids"].cuda(),
inputs_embeds=None,
past_key_values_length=0,
)
# (bsz, 1, 1, seq_len)
mask = bert.get_extended_attention_mask(
attention_mask=tokens["attention_mask"].cuda(),
input_shape=tokens["input_ids"].shape,
device=bert.device,
)
# (bsz, 1, seq_len, seq_len)
mask = mask.repeat(1, 1, mask.size()[-1], 1)
attention = bert.encoder.layer[0].attention.self
key_layer = attention.transpose_for_scores(attention.key(embedding_output))
query_layer = attention.transpose_for_scores(attention.query(embedding_output))
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores /= math.sqrt(key_layer.size()[-1])
fused_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.padding,
scaled_masked_softmax_fusion=True,
)
.cuda()
.half()
)
fused_softmax_output = fused_softmax(
attention_scores,
(mask != 0),
)
torch_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.padding,
scaled_masked_softmax_fusion=False,
)
.cuda()
.half()
)
torch_softmax_output = torch_softmax(
attention_scores,
(mask != 0),
)
test_result = (fused_softmax_output - torch_softmax_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_softmax"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_softmax"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
def test_fused_upper_triangle_mask_softmax():
gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi" # 24
)
tokens = tokenizer(
[test_text] * 4,
return_tensors="pt",
)
attention_mask = tokens["attention_mask"].cuda()
attention_mask = attention_mask.view(attention_mask.size(0), -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
attn = gpt.h[0]
hidden_states = gpt.wte(tokens["input_ids"].cuda())
q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
attn_weights = torch.matmul(q, k.transpose(-1, -2))
sq, sk = q.size(-2), k.size(-2)
causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
total_mask = ~(causal_mask & (attention_mask == 0))
"""
tensor([[[[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, True, True],
[False, False, False, ..., False, False, True],
[False, False, False, ..., False, False, False]]]
"""
fused_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
)
.cuda()
.half()
)
fused_softmax_output = fused_softmax(
attn_weights,
total_mask,
)
torch_softmax = (
FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=False,
)
.cuda()
.half()
)
torch_softmax_output = torch_softmax(
attn_weights,
total_mask,
)
test_result = (fused_softmax_output - torch_softmax_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_upper_triangle_mask_softmax"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_upper_triangle_mask_softmax"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
)
def test_layer_norm():
bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
)
tokens = tokenizer(
[test_text] * 4,
return_tensors="pt",
)
# [bsz, seq_len, d_model]
embedding_output = (
bert.embeddings(
input_ids=tokens["input_ids"].cuda(),
position_ids=None,
token_type_ids=tokens["token_type_ids"].cuda(),
inputs_embeds=None,
past_key_values_length=0,
)
.cuda()
.half()
)
fused_layernorm_layer = (
MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
)
torch_layernorm_layer = (
LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
)
fused_output = fused_layernorm_layer(embedding_output)
torch_output = torch_layernorm_layer(embedding_output)
test_result = (fused_output - torch_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_layer_norm"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
)
else:
print(
f"\n[Fail] test_layer_norm"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
)
if __name__ == "__main__":
try:
from transformers import BertTokenizer, GPT2Tokenizer
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
import transformers
transformers.logging.set_verbosity(
transformers.logging.FATAL,
)
except:
print("\n[Fail] Please install `transformers` package to test fused kernels\n")
exit(-1)
test_load_fused_kernels()
test_fused_softmax()
test_fused_upper_triangle_mask_softmax()
test_layer_norm()
@@ -21,6 +21,7 @@ import time
import numpy as np
import torch
from datetime import timedelta
from megatron import fused_kernels
from megatron import get_adlr_autoresume
@@ -63,6 +64,9 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
print('> setting random seeds to {} ...'.format(args.seed))
_set_random_seed(args.seed)
# Set pytorch JIT layer fusion options.
_set_jit_fusion_options()
args = get_args()
if args.lazy_mpu_init:
args.use_cpu_initialization=True
@@ -77,9 +81,6 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# Megatron's MPU is the master. Complete initialization right away.
finish_mpu_init()
-# Initialize memory buffers.
-_initialize_mem_buffs()
# Autoresume.
_init_autoresume()
@@ -175,11 +176,11 @@ def _initialize_distributed():
else:
args.local_rank = device
torch.cuda.set_device(device)
# Call the init process
torch.distributed.init_process_group(
backend=args.distributed_backend,
-world_size=args.world_size, rank=args.rank)
world_size=args.world_size, rank=args.rank,
timeout=timedelta(days=7))
# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
@@ -226,10 +227,24 @@ def write_args_to_tensorboard():
global_step=args.iteration)
-def _initialize_mem_buffs():
-"""Initialize manually allocated static memory."""
-args = get_args()
def _set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
-# Initialize memory for checkpointed activations.
-if args.distribute_checkpointed_activations:
-mpu.init_checkpointed_activations_memory_buffer()
@@ -15,10 +15,6 @@
import torch
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
...
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from megatron.model.enums import AttnMaskType
@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
scale_t = torch.tensor([scale])
-softmax_results = scaled_masked_softmax_cuda.forward(
-inputs, mask, scale_t[0]
-)
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
return input_grads, None, None
-class FusedScaleMaskSoftmax(torch.nn.Module):
class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate user wants to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
-assert not (self.input_in_fp16 and self.input_in_bf16),\
-'both fp16 and bf16 flags cannot be active at the same time.'
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
-data_size = input.size()
-query_seq_len = data_size[-2]
-key_seq_len = data_size[-1]
-attn_batch_size = data_size[0] * data_size[1]
-# constraints on various tensor dimensions to enable warp based
-# optimization and upper triangular optimization (for causal mask)
-custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
-query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
-# invoke custom kernel
-if self.input_in_float16 and mask is not None and \
-custom_kernel_constraint and self.scaled_masked_softmax_fusion:
-scale = self.scale if self.scale is not None else 1.0
-if self.attn_mask_type == AttnMaskType.causal:
-assert query_seq_len == key_seq_len, \
-"causal mask is only for self attention"
-input = input.view(-1, query_seq_len, key_seq_len)
-probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
-probs = probs.view(*data_size)
-else:
-assert self.attn_mask_type == AttnMaskType.padding
-probs = ScaledMaskedSoftmax.apply(input, mask, scale)
-else:
-if self.input_in_float16 and self.softmax_in_fp32:
-input = input.float()
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion  # user wants to fuse
and self.input_in_float16 # input must be fp16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0  # sq must be divisible by 4
and attn_batches % 4 == 0  # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
-if self.scale is not None:
-input = input * self.scale
-mask_output = self.mask_func(input, mask) if mask is not None else input
-probs = torch.nn.Softmax(dim=-1)(mask_output)
-if self.input_in_float16 and self.softmax_in_fp32:
-if self.input_in_fp16:
-probs = probs.half()
-else:
-probs = probs.bfloat16()
-return probs
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal
def post_language_model_processing(lm_output, labels, logit_weights,
-get_key_value, parallel_output,
-forward_method_parallel_output,
parallel_output,
fp16_lm_cross_entropy):
-if get_key_value:
-lm_output, presents = lm_output
# Output.
-if forward_method_parallel_output is not None:
-parallel_output = forward_method_parallel_output
output = parallel_lm_logits(
lm_output,
logit_weights,
parallel_output)
-if get_key_value:
-output = [output, presents]
if labels is None:
return output
else:
@@ -90,23 +82,22 @@ class GPTModel(MegatronModule):
self.language_model.set_input_tensor(input_tensor)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
-tokentype_ids=None, layer_past=None, get_key_value=False,
-forward_method_parallel_output=None):
tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
lm_output = self.language_model(
input_ids,
position_ids,
attention_mask,
-layer_past=layer_past,
-get_key_value=get_key_value)
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
if self.post_process:
return post_language_model_processing(
lm_output, labels,
self.word_embeddings_weight(),
-get_key_value,
self.parallel_output,
-forward_method_parallel_output,
self.fp16_lm_cross_entropy)
else:
return lm_output
...
@@ -379,8 +379,10 @@ class TransformerLanguageModel(MegatronModule):
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-get_key_value=False, pooling_sequence_index=0,
enc_dec_attn_mask=None, tokentype_ids=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None,
pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
# Encoder embedding.
@@ -393,10 +395,11 @@ class TransformerLanguageModel(MegatronModule):
# Run encoder.
if enc_hidden_states is None:
if self.encoder is not None:
-encoder_output = self.encoder(encoder_input,
-enc_attn_mask,
-layer_past=layer_past,
-get_key_value=get_key_value)
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
else:
encoder_output = self.encoder_hidden_state
else:
@@ -424,12 +427,13 @@ class TransformerLanguageModel(MegatronModule):
decoder_input = None
# Run decoder.
-decoder_output = self.decoder(decoder_input,
-dec_attn_mask,
-layer_past=layer_past,
-get_key_value=get_key_value,
-encoder_output=encoder_output,
-enc_dec_attn_mask=enc_dec_attn_mask)
decoder_output = self.decoder(
decoder_input,
dec_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output
...
@@ -27,11 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
-# flags required to enable jit fusion kernels
-torch._C._jit_set_profiling_mode(False)
-torch._C._jit_set_profiling_executor(False)
-torch._C._jit_override_can_fuse_on_cpu(True)
-torch._C._jit_override_can_fuse_on_gpu(True)
""" We use the following notation throughout this file:
h: hidden size
@@ -123,6 +118,7 @@ class ParallelAttention(MegatronModule):
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads
@@ -183,10 +179,53 @@ class ParallelAttention(MegatronModule):
init_method=output_layer_init_method,
skip_bias_add=True)
# Inference key-value memory
self.inference_key_memory = None
self.inference_value_memory = None
self.inference_current_sequence_len = 0
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
-def forward(self, hidden_states, attention_mask, layer_past=None,
-get_key_value=False, encoder_output=None):
def forward(self, hidden_states, attention_mask,
encoder_output=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if set_inference_key_value_memory:
assert inference_max_sequence_len and inference_max_sequence_len > 0
self.inference_key_memory = self._allocate_memory(
inference_max_sequence_len, hidden_states.size(1))
self.inference_value_memory = self._allocate_memory(
inference_max_sequence_len, hidden_states.size(1))
self.inference_current_sequence_len = 0
# Some consistency check.
if inference_max_sequence_len:
assert self.inference_current_sequence_len < \
self.inference_key_memory.size(0)
assert inference_max_sequence_len == \
self.inference_key_memory.size(0)
# This is added for safety. In case inference_max_sequence_len
# is not provided, make sure there is no potential memory left
# from previous inference.
if not inference_max_sequence_len:
self.inference_key_memory = None
self.inference_value_memory = None
# =====================
# Query, Key, and Value
# =====================
@@ -227,18 +266,24 @@ class ParallelAttention(MegatronModule):
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
-# ==================================
-# Adjust key and value for inference
-# ==================================
-if layer_past is not None:
-past_key, past_value = layer_past
-key_layer = torch.cat((past_key.type_as(key_layer),
-key_layer), dim=0)
-value_layer = torch.cat((past_value.type_as(value_layer),
-value_layer), dim=0)
-if get_key_value:
-present = (key_layer, value_layer)
# ===================================================
# Adjust key, value, and attention mask for inference
# ===================================================
if inference_max_sequence_len:
# Adjust the range variables.
start = self.inference_current_sequence_len
self.inference_current_sequence_len += key_layer.size(0)
end = self.inference_current_sequence_len
# Copy key and values.
self.inference_key_memory[start:end, ...] = key_layer
self.inference_value_memory[start:end, ...] = value_layer
key_layer = self.inference_key_memory[:end, ...]
value_layer = self.inference_value_memory[:end, ...]
# Adjust attention mask
attention_mask = attention_mask[..., start:end, :end]
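To make the pre-allocated key/value memory above easier to picture, here is a small standalone PyTorch sketch (hypothetical sizes, CPU tensors, not part of the commit): each step writes its new keys into a fixed slot of the buffer, and attention then reads the filled prefix:

    import torch

    # Buffer shape [max_seq, batch, heads, head_dim], as in _allocate_memory.
    max_seq, batch, heads, head_dim = 16, 2, 4, 8
    inference_key_memory = torch.empty(max_seq, batch, heads, head_dim)
    inference_current_sequence_len = 0

    for step_len in (4, 1, 1):  # e.g. a 4-token prompt, then two decode steps
        key_layer = torch.randn(step_len, batch, heads, head_dim)
        start = inference_current_sequence_len
        inference_current_sequence_len += key_layer.size(0)
        end = inference_current_sequence_len
        inference_key_memory[start:end, ...] = key_layer
        key_layer = inference_key_memory[:end, ...]  # what attention consumes
        print(key_layer.shape)  # sequence dimension grows: 4, then 5, then 6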
# ===================================
# Raw attention scores. [b, np, s, s]
@@ -275,22 +320,6 @@ class ParallelAttention(MegatronModule):
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
-# ==================================================
-# Update attention mask for inference. [b, np, sq, sk]
-# ==================================================
-if get_key_value:
-with torch.no_grad():
-if layer_past is not None:
-attention_mask = attention_mask[
-...,
-attention_scores.size(3) - 1,
-:attention_scores.size(3)].unsqueeze(2)
-else:
-attention_mask = attention_mask[
-...,
-:attention_scores.size(3),
-:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
@@ -346,9 +375,6 @@ class ParallelAttention(MegatronModule):
output, bias = self.dense(context_layer)
-if get_key_value:
-output = [output, present]
return output, bias
@@ -435,21 +461,21 @@ class ParallelTransformerLayer(MegatronModule):
output_layer_init_method)
def forward(self, hidden_states, attention_mask,
-encoder_output=None, enc_dec_attn_mask=None,
-layer_past=None, get_key_value=False):
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
-self.self_attention(layernorm_output,
-attention_mask,
-layer_past=layer_past,
-get_key_value=get_key_value)
self.self_attention(
layernorm_output,
attention_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
-if get_key_value:
-attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
@@ -519,9 +545,6 @@ class ParallelTransformerLayer(MegatronModule):
residual,
self.hidden_dropout)
-if get_key_value:
-output = [output, presents]
return output
@@ -542,8 +565,9 @@ class ParallelTransformer(MegatronModule):
self.input_tensor = None
# Store activation checkpointing flag.
-self.checkpoint_activations = args.checkpoint_activations
-self.checkpoint_num_layers = args.checkpoint_num_layers
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
# Number of layers.
self.num_layers = mpu.get_num_layers(
@@ -606,14 +630,49 @@ class ParallelTransformer(MegatronModule):
return x_
return custom_forward
-# Make sure memory is freed.
-mpu.reset_checkpointed_activations_memory_buffer()
-l = 0
-while l < self.num_layers:
-hidden_states = mpu.checkpoint(
-custom(l, l + self.checkpoint_num_layers),
-hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-l += self.checkpoint_num_layers
def distribute_checkpointed_activations_helper(layer_number):
"""Distribute checkpointed activations across the tensor model
parallel ranks if `distribute-checkpointed-activations` is on and
either of the following conditions is met:
- it is not the first layer in the pipeline stage. The first layer
is used in pipeline parallelism and changing its shape throws an
error in the backward pass.
- we are at the first pipeline stage, so the input tensor is not
used in pipeline parallelism. Note that no pipeline parallelism
is a special case of this.
"""
not_first_layer_in_pipeline_stage = (layer_number > 0)
is_first_pipeline_stage = (
mpu.get_pipeline_model_parallel_rank() == 0)
return self.distribute_checkpointed_activations and \
(not_first_layer_in_pipeline_stage or is_first_pipeline_stage)
if self.activations_checkpoint_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
distribute_checkpointed_activations_helper(l),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method to fully use the device memory while removing redundant re-computation.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + 1),
distribute_checkpointed_activations_helper(l),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation checkpoint method.")
return hidden_states
@@ -627,18 +686,16 @@ class ParallelTransformer(MegatronModule):
forward_step_func"""
self.input_tensor = input_tensor
-def forward(self, hidden_states, attention_mask, layer_past=None,
-get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
def forward(self, hidden_states, attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
set_inference_key_value_memory=False,
inference_max_sequence_len=None):
# Checks.
-if layer_past is not None:
-assert get_key_value, \
-'for not None values in layer_past, ' \
-'expected get_key_value to be set'
-if get_key_value:
-assert not self.checkpoint_activations, \
-'get_key_value does not work with ' \
-'activation checkpointing'
if inference_max_sequence_len:
assert self.activations_checkpoint_method is None, \
'inference does not work with activation checkpointing'
if self.pre_process:
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
@@ -655,28 +712,21 @@ class ParallelTransformer(MegatronModule):
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
-if self.checkpoint_activations:
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
-if get_key_value:
-presents = []
for index in range(self.num_layers):
layer = self._get_layer(index)
-past = None
-if layer_past is not None:
-past = layer_past[index]
-hidden_states = layer(hidden_states,
-attention_mask,
-encoder_output=encoder_output,
-enc_dec_attn_mask=enc_dec_attn_mask,
-layer_past=past,
-get_key_value=get_key_value)
-if get_key_value:
-hidden_states, present = hidden_states
-presents.append(present)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len)
# Final layer norm.
if self.post_process:
@@ -685,7 +735,5 @@ class ParallelTransformer(MegatronModule):
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
-if get_key_value:
-output = [output, presents]
return output
...
from .random import checkpoint from .random import checkpoint
from .random import get_cuda_rng_tracker from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks from .random import split_tensor_into_1d_equal_chunks
......
@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
@@ -198,6 +199,37 @@ class VocabParallelEmbedding(torch.nn.Module):
return output
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
"""
Column-parallel linear layer execution with asynchronous all-reduce
execution in backprop.
"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return grad_input, grad_weight, grad_bias
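As a sanity check of the hand-written backward above, the matmul gradients can be verified on a single process, where the tensor-model-parallel all-reduce is effectively a no-op (a minimal sketch with assumed toy shapes, not part of the commit):

    import torch

    x = torch.randn(8, 16, requires_grad=True)   # flattened [s*b, h] input
    w = torch.randn(32, 16, requires_grad=True)  # column-parallel weight shard
    b = torch.randn(32, requires_grad=True)

    out = torch.matmul(x, w.t()) + b             # same math as forward() above
    out.sum().backward()

    g = torch.ones_like(out)                     # grad_output produced by sum()
    assert torch.allclose(x.grad, g.matmul(w))      # grad_input
    assert torch.allclose(w.grad, g.t().matmul(x))  # grad_weight
    assert torch.allclose(b.grad, g.sum(dim=0))     # grad_bias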
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
@@ -272,16 +304,30 @@ class ColumnParallelLinear(torch.nn.Module):
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
not args.no_async_tensor_model_parallel_allreduce and
world_size > 1)
def forward(self, input_):
-# Set up backprop all-reduce.
-input_parallel = copy_to_tensor_model_parallel_region(input_)
-# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
-output_parallel = F.linear(input_parallel, self.weight, bias)
if self.async_tensor_model_parallel_allreduce:
input_shape = input_.shape
input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
# Matrix multiply with asynchronous all-reduce execution
output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
...
@@ -37,37 +37,6 @@ from .initialize import get_tensor_model_parallel_world_size
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
-# Whether apply model parallelsim to checkpointed hidden states.
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
-def init_checkpointed_activations_memory_buffer():
-"""Initializ the memory buffer for the checkpointed activations."""
-args = get_args()
-per_layer = args.micro_batch_size * args.max_position_embeddings * \
-args.hidden_size // args.tensor_model_parallel_size
-assert args.num_layers % args.checkpoint_num_layers == 0, \
-'number of layers is not divisible by checkpoint-num-layers'
-num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
-numel = per_layer * num_checkpointer_layers
-dtype = torch.half
-if not args.fp16:
-dtype = torch.float
-global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
-assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
-'checkpointed activations memory buffer is already allocated.'
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
-'checkpointed activations', numel, dtype, track_usage=False)
-def reset_checkpointed_activations_memory_buffer():
-"""Reset the memory used for checkpointing."""
-if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
-_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
@@ -101,14 +70,21 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
-def split_tensor_into_1d_equal_chunks(tensor):
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks."""
-data = tensor.view(-1)
-partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
partition_size = torch.numel(tensor) // \
get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
-return data[start_index:end_index]
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
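A standalone illustration of the new `new_buffer` flag (assumed tensor-parallel size and rank, CPU tensors, not part of the commit): with new_buffer=False the chunk is a view into the original storage, while new_buffer=True copies it into freshly allocated memory, so later in-place updates to the source no longer show up in the saved chunk:

    import torch

    world_size, rank = 4, 1                       # hypothetical TP size and rank
    tensor = torch.arange(16.)
    partition_size = torch.numel(tensor) // world_size
    start, end = rank * partition_size, (rank + 1) * partition_size

    view_chunk = tensor.view(-1)[start:end]       # new_buffer=False: shares storage
    copy_chunk = torch.empty(partition_size, dtype=tensor.dtype)
    copy_chunk.copy_(tensor.view(-1)[start:end])  # new_buffer=True: private buffer

    tensor += 1
    print(view_chunk)  # tensor([5., 6., 7., 8.]) -- reflects the in-place update
    print(copy_chunk)  # tensor([4., 5., 6., 7.]) -- unchanged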
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
@@ -250,8 +226,10 @@ class CheckpointFunction(torch.autograd.Function):
tracked/set/reset.
"""
@staticmethod
-def forward(ctx, run_function, *args):
def forward(ctx, run_function, distribute_checkpointed_activations, *args):
ctx.run_function = run_function
ctx.distribute_checkpointed_activations \
= distribute_checkpointed_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -263,16 +241,14 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep # Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank. # the chunk corresponding to the current rank.
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: if distribute_checkpointed_activations:
ctx.input_0_shape = args[0].data.shape ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data) args[0].data = split_tensor_into_1d_equal_chunks(args[0].data,
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add( new_buffer=True)
args[0].data)
# Store everything. # Store everything.
ctx.save_for_backward(*args) ctx.save_for_backward(*args)
return outputs return outputs
@staticmethod @staticmethod
...@@ -281,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -281,7 +257,7 @@ class CheckpointFunction(torch.autograd.Function):
raise RuntimeError("Checkpointing is not compatible with .grad(), " raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible") "please use .backward() if possible")
inputs = ctx.saved_tensors inputs = ctx.saved_tensors
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: if ctx.distribute_checkpointed_activations:
inputs[0].data = gather_split_1d_tensor(inputs[0].data) inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape) inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
...@@ -310,10 +286,11 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -310,10 +286,11 @@ class CheckpointFunction(torch.autograd.Function):
torch.autograd.backward(outputs, args) torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs) for inp in detached_inputs)
return (None,) + grads return (None, None) + grads
def checkpoint(function, *args): def checkpoint(function, distribute_checkpointed_activations, *args):
"""Checkpoint a model or part of the model. """Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint.""" This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, *args) return CheckpointFunction.apply(function,
distribute_checkpointed_activations, *args)
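The shard-or-not decision now travels through the checkpoint call itself rather than through a module-level buffer. A hedged sketch of the new call pattern; the wrapper function and its layer_chunk argument are illustrative stand-ins, not code from this commit.

from megatron import get_args, mpu

def forward_chunk_with_checkpoint(layer_chunk, hidden_states, attention_mask):
    """Hypothetical wrapper; layer_chunk is a list of transformer layers."""
    args = get_args()

    def custom_forward(hidden_states, attention_mask):
        for layer in layer_chunk:
            hidden_states = layer(hidden_states, attention_mask)
        return hidden_states

    # New signature: the distribute flag is the second positional argument,
    # consumed by CheckpointFunction before the real inputs.
    return mpu.checkpoint(custom_forward,
                          args.distribute_checkpointed_activations,
                          hidden_states, attention_mask)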
...@@ -100,7 +100,7 @@ def get_megatron_optimizer(model): ...@@ -100,7 +100,7 @@ def get_megatron_optimizer(model):
args.clip_grad, args.clip_grad,
args.log_num_zeros_in_grad, args.log_num_zeros_in_grad,
params_have_main_grad, params_have_main_grad,
args.use_contiguous_buffers_in_ddp, args.use_contiguous_buffers_in_local_ddp,
args.bf16, args.bf16,
grad_scaler) grad_scaler)
...@@ -108,4 +108,4 @@ def get_megatron_optimizer(model): ...@@ -108,4 +108,4 @@ def get_megatron_optimizer(model):
return FP32Optimizer(optimizer, args.clip_grad, return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad, args.log_num_zeros_in_grad,
params_have_main_grad, params_have_main_grad,
args.use_contiguous_buffers_in_ddp) args.use_contiguous_buffers_in_local_ddp)
...@@ -69,7 +69,7 @@ class MegatronOptimizer(ABC): ...@@ -69,7 +69,7 @@ class MegatronOptimizer(ABC):
def __init__(self, optimizer, clip_grad, def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad, log_num_zeros_in_grad,
params_have_main_grad, params_have_main_grad,
use_contiguous_buffers_in_ddp): use_contiguous_buffers_in_local_ddp):
"""Input optimizer is the base optimizer for example Adam.""" """Input optimizer is the base optimizer for example Adam."""
self.optimizer = optimizer self.optimizer = optimizer
...@@ -78,9 +78,9 @@ class MegatronOptimizer(ABC): ...@@ -78,9 +78,9 @@ class MegatronOptimizer(ABC):
self.clip_grad = clip_grad self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad self.log_num_zeros_in_grad = log_num_zeros_in_grad
self.params_have_main_grad = params_have_main_grad self.params_have_main_grad = params_have_main_grad
self.use_contiguous_buffers_in_ddp = use_contiguous_buffers_in_ddp self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
if self.use_contiguous_buffers_in_ddp: if self.use_contiguous_buffers_in_local_ddp:
assert self.params_have_main_grad, \ assert self.params_have_main_grad, \
"use of contiguous buffer requires that params have main grad" "use of contiguous buffer requires that params have main grad"
...@@ -193,12 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): ...@@ -193,12 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
""" """
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_ddp, params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler): bf16, grad_scaler):
super(Float16OptimizerWithFloat16Params, self).__init__( super(Float16OptimizerWithFloat16Params, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_ddp) params_have_main_grad, use_contiguous_buffers_in_local_ddp)
self.bf16 = bf16 self.bf16 = bf16
self.grad_scaler = grad_scaler self.grad_scaler = grad_scaler
...@@ -323,7 +323,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): ...@@ -323,7 +323,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# persist and therefore should not be deallocated.) # persist and therefore should not be deallocated.)
model_param.grad = None model_param.grad = None
if self.params_have_main_grad and \ if self.params_have_main_grad and \
not self.use_contiguous_buffers_in_ddp: not self.use_contiguous_buffers_in_local_ddp:
model_param.main_grad = None model_param.main_grad = None
# For fp32 grads, we need to reset the grads to main grad. # For fp32 grads, we need to reset the grads to main grad.
...@@ -335,7 +335,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): ...@@ -335,7 +335,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# Safe to de-reference model's main_grad after copying. # Safe to de-reference model's main_grad after copying.
# (If using contiguous buffers, main_grad's memory should # (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.) # persist and therefore should not be deallocated.)
if not self.use_contiguous_buffers_in_ddp: if not self.use_contiguous_buffers_in_local_ddp:
model_param.main_grad = None model_param.main_grad = None
def _unscale_main_grads_and_check_for_nan(self): def _unscale_main_grads_and_check_for_nan(self):
...@@ -491,11 +491,11 @@ class FP32Optimizer(MegatronOptimizer): ...@@ -491,11 +491,11 @@ class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad, def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad, log_num_zeros_in_grad,
params_have_main_grad, params_have_main_grad,
use_contiguous_buffers_in_ddp): use_contiguous_buffers_in_local_ddp):
super(FP32Optimizer, self).__init__( super(FP32Optimizer, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_ddp) params_have_main_grad, use_contiguous_buffers_in_local_ddp)
self._scale = torch.cuda.FloatTensor([1.0]) self._scale = torch.cuda.FloatTensor([1.0])
...@@ -525,7 +525,7 @@ class FP32Optimizer(MegatronOptimizer): ...@@ -525,7 +525,7 @@ class FP32Optimizer(MegatronOptimizer):
# Safe to de-reference model's main_grad after copying. # Safe to de-reference model's main_grad after copying.
# (If using contiguous buffers, main_grad's memory should # (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.) # persist and therefore should not be deallocated.)
if not self.use_contiguous_buffers_in_ddp: if not self.use_contiguous_buffers_in_local_ddp:
param.main_grad = None param.main_grad = None
# Clip gradients. # Clip gradients.
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import torch
import json
import threading
from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api
from megatron import get_args
from megatron import mpu
from megatron.text_generation_utils import generate
GENERATE_NUM = 0
lock = threading.Lock()
class MegatronGenerate(Resource):
def __init__(self, model):
self.model = model
@staticmethod
def send_do_generate():
choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice, 0)
def put(self):
args = get_args()
print("request IP: " + str(request.remote_addr))
print(json.dumps(request.get_json()),flush=True)
print("current time: ", datetime.datetime.now())
sentences = request.get_json()["sentences"]
if len(sentences) > 128:
return "Maximum number of sentences is 128", 400
tokens_to_generate = 64  # A hopefully sane default; generating the full sequence is slow
if "tokens_to_generate" in request.get_json():
tokens_to_generate = request.get_json()["tokens_to_generate"]
if not isinstance(tokens_to_generate, int):
return "tokens_to_generate must be an integer greater than 0"
if tokens_to_generate < 1:
return "tokens_to_generate must be an integer greater than 0"
all_probs = False
if "all_probs" in request.get_json():
all_probs = request.get_json()["all_probs"]
if not isinstance(all_probs, bool):
return "all_probs must be a boolean value"
temperature = args.temperature
if "temperature" in request.get_json():
temperature = request.get_json()["temperature"]
if not isinstance(temperature, float) or not \
0.0 < temperature <= 100.0:
return "temperature must be a positive float less than or equal to 100.0"
add_BOS = False
if "add_BOS" in request.get_json():
add_BOS = request.get_json()["add_BOS"]
if not isinstance(add_BOS, bool):
return "add_BOS must be a boolean value"
with lock:  # Hold the lock so concurrent requests do not run generation at the same time
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)
if all_probs:
return jsonify({"sentences": resp_sentences,
"segments": resp_sentences_seg,
"logits": output_logits,
"all_logits": full_logits,
"tokens": tokens})
return jsonify({"sentences": resp_sentences,
"segments": resp_sentences_seg,
"logits": output_logits})
class MegatronServer(object):
def __init__(self, model):
self.app = Flask(__name__, static_url_path='')
api = Api(self.app)
api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
def run(self, url):
self.app.run(url, threaded=True, debug=False)
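For completeness, a hypothetical client call against the new endpoint. The host and port depend on the url passed to MegatronServer.run() (Flask's default port is 5000), so both are assumptions here, as is the use of the requests package.

import requests

resp = requests.put('http://localhost:5000/generate',       # host/port assumed
                    json={'sentences': ['Megatron-LM is'],
                          'tokens_to_generate': 32,
                          'temperature': 0.9})
print(resp.json()['sentences'])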