Commit 051f58f1 authored by liangjing

v1

parent 0024a5c6
Pipeline #829 passed
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module"""
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
# def __init__(self, config: TransformerConfig, share_word_embeddings=True):
def __init__(self, config: TransformerConfig):
super().__init__()
self.config = config
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints.
"""
return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def sharded_state_dict(self, prefix=''):
""" Override sharded_state_dict when using distributed checkpointing.
keep_vars must always be set to True so that optimizer states
can be sharded.
"""
return self.state_dict(prefix=prefix, keep_vars=True)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
#is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = float16_convertor(val)
return val
return conversion_helper(val, half_conversion)
def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float()
return val
return conversion_helper(val, float_conversion)
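# Minimal usage sketch for the conversion helpers above (illustrative only; the
# tensor values and the lambda are assumptions, not part of this module). It is
# defined but never called, and expects a CUDA device since _FLOAT_TYPES and
# _HALF_TYPES include the torch.cuda tensor types.
def _example_half_roundtrip():
    vals = (torch.randn(2, 2, device='cuda'), [torch.randn(3, device='cuda')])
    halved = fp32_to_float16(vals, lambda t: t.half())  # nested fp32 -> fp16
    restored = float16_to_fp32(halved)                  # nested fp16 -> fp32
    return restored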
class Float16Module(MegatronModule):
def __init__(self, config: TransformerConfig, module: torch.nn.Module):
super(Float16Module, self).__init__(config)
self.config = config
self.fp16 = config.fp16
self.bf16 = config.bf16
if self.fp16:
self.add_module('module', module.half())
def float16_convertor(val):
return val.half()
elif self.bf16:
self.add_module('module', module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception('Either config.fp16 or config.bf16 should be True.')
self.float16_convertor = float16_convertor
def set_input_tensor(self, input_tensor):
return self.module.set_input_tensor(input_tensor)
def forward(self, *inputs, **kwargs):
if parallel_state.is_pipeline_first_stage():
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if parallel_state.is_pipeline_last_stage():
outputs = float16_to_fp32(outputs)
return outputs
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
""" Retrieve state_dict from the module being wrapped."""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def sharded_state_dict(self, prefix=''):
""" Retrieve sharded_state_dict from the module being wrapped.
"""
return self.module.sharded_state_dict(prefix=prefix)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
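# Wrapping sketch (hedged example; `config` and `model` are assumed to already
# exist, with config.fp16 or config.bf16 set to True):
#
#   wrapped = Float16Module(config, model)
#   output = wrapped(tokens, attention_mask)  # inputs are cast to fp16/bf16 on the
#                                             # first pipeline stage, outputs cast
#                                             # back to fp32 on the last stage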
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import re
from contextlib import nullcontext
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.custom_layers.transformer_engine import TENorm
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor
class TransformerBlock(MegatronModule):
"""Transformer class."""
def __init__(
self,
config: TransformerConfig,
self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True,
pre_process=True,
post_process=True,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.self_attn_mask_type = self_attn_mask_type
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
# required for pipeline parallel schedules
self.input_tensor = None
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
self.num_layers_per_pipeline_rank = (
self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
self._build_layers()
def _build_layers(self):
# Transformer layers.
# @jcasper can we improve how we deal with layer_number?
# currently it's only used in CoreAttention?
# if self.apply_query_key_layer_scaling:
# coeff = self.layer_number
# self.norm_factor *= coeff
def build_layer(layer_number):
layer = TransformerLayer(
config=self.config,
layer_number=layer_number,
self_attn_mask_type=self.self_attn_mask_type,
)
return layer
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
# Interleaved pipeline parallelism:
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
num_layers_per_virtual_rank = self.num_layers_per_pipeline_rank // vp_size
num_layers_to_build = num_layers_per_virtual_rank
else:
# Non-interleaved pipeline parallelism:
# Each stage gets a contiguous set of layers.
num_layers_to_build = self.num_layers_per_pipeline_rank
# offset is implicit in TransformerLayer
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(num_layers_to_build)])
# # TODO: add back standalone_embedding_stage
# if self.num_layers == 0:
# # When a standalone embedding stage is used (e.g.,
# # args.standalone_embedding_stage == True), virtual pipeline ranks
# # on pipeline rank 0 will have zero transformer layers assigned to
# # them. This results in the model's input and output tensors to be
# # the same, which will cause failure for certain output tensor
# # optimizations (e.g., pipeline output deallocation). To remedy
# # this, we assign a 'no-op' layer on these ranks, which will
# # disconnect the input tensor from the output tensor.
# self.num_layers = 1
# self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
# else:
# self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_layernorm = TENorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
persist_layer_norm=self.config.persist_layer_norm,
sequence_parallel=self.config.sequence_parallel,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
normalization=self.config.normalization,
)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask, rotary_pos_emb):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*args, **kwargs):
x_, *args = args
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, *args, **kwargs)
return x_
return custom_forward
if self.config.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
            # A method to further reduce memory usage by storing fewer activation checkpoints.
l = 0
while l < self.num_layers_per_pipeline_rank:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.config.recompute_num_layers),
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
rotary_pos_emb,
)
l += self.config.recompute_num_layers
elif self.config.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
            # A method to make full use of device memory by removing redundant re-computation.
for l in range(self.num_layers_per_pipeline_rank):
if l < self.config.recompute_num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.config.distribute_saved_activations,
hidden_states,
attention_mask,
rotary_pos_emb,
)
else:
hidden_states = custom(l, l + 1)(hidden_states, attention_mask, rotary_pos_emb)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
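    # For example, with num_layers_per_pipeline_rank = 8 and
    # config.recompute_num_layers = 2: 'uniform' checkpoints the input of each
    # 2-layer chunk ([0,1], [2,3], [4,5], [6,7]), while 'block' checkpoints only
    # layers 0 and 1 and runs layers 2..7 without recomputation.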
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask, inference_params=None, rotary_pos_emb=None):
# hidden_states (float): [s, b, h]
# attention_mask (bool): [1, 1, s, s]
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = make_viewless_tensor(
inp=hidden_states, requires_grad=True, keep_graph=True,
)
if self.config.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
if self.config.fp8:
            import transformer_engine  # Imported here to avoid a TE dependency when not training in fp8
if self.config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif self.config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=self.config.fp8_margin,
interval=self.config.fp8_interval,
fp8_format=fp8_format,
amax_compute_algo=self.config.fp8_amax_compute_algo,
amax_history_len=self.config.fp8_amax_history_len,
override_linear_precision=(False, False, not self.config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group()
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
fp8_context = nullcontext()
        with rng_context, fp8_context:  # enter both the RNG fork and the FP8 autocast contexts
# Forward pass.
if self.config.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
)
else:
for layer in self.layers:
hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
inference_params=inference_params,
)
# Final layer norm.
if self.post_process and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def sharded_state_dict(self, prefix=''):
sharded_state_dict = {}
layer_prefix = f'{prefix}layers.'
for layer in self.layers:
sharded_state_dict.update(layer.sharded_state_dict(prefix=layer_prefix))
if self.post_process and self.post_layer_norm:
state_dict = self.state_dict(keep_vars=True)
tensor = state_dict['final_layernorm.weight']
layer_name = f'{prefix}final_layernorm.weight'
sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint(tensor, layer_name)
# RMSNorm doesn't have bias.
if 'final_layernorm.bias' in state_dict.keys():
tensor = state_dict['final_layernorm.bias']
layer_name = f'{prefix}final_layernorm.bias'
sharded_state_dict[layer_name] = make_sharded_tensor_for_checkpoint(
tensor, layer_name
)
return sharded_state_dict
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable
import torch
import torch.nn.functional as F
from megatron.core import ModelParallelConfig
from megatron.core.utils import init_method_normal, scaled_init_method_normal
@dataclass
class TransformerConfig(ModelParallelConfig):
"""Configuration object for megatron-core transformers.
Attributes:
# model architecture
num_layers (int): Number of transformer layers in a transformer block.
hidden_size (int): Transformer hidden size.
ffn_hidden_size (int): Transformer Feed-Forward Network hidden size.
            This is set to 4*hidden_size if not provided. Defaults to None.
num_attention_heads (int): Number of transformer attention heads.
kv_channels (int): Projection weights dimension in multi-head attention.
This is set to hidden_size // num_attention_heads if not provided.
Defaults to None.
num_query_groups (int): Number of query groups for group query attention. If None, normal attention is used.
hidden_dropout (float): Dropout probability for transformer hidden state. Defaults to 0.1.
attention_dropout (float): Post attention dropout probability. Defaults to 0.1.
fp32_residual_connection (bool): If true, move residual connections to fp32.
        apply_residual_connection_post_layernorm (bool): If true, uses the original BERT residual connection ordering.
Defaults to False.
layernorm_epsilon (float): Layernorm epsilon. Defaults to 1e-5.
layernorm_zero_centered_gamma (bool): if set to 'True', the LayerNorm is adjusted to center the gamma values
around 0. This improves numerical stability. Defaults to False.
add_bias_linear (bool): Include a bias term in all linear layers (QKV projections, after core attention, and two
in MLP layer). Default is True.
gated_linear_unit (bool): Use a gated linear unit for the first linear layer in the MLP. Defaults to False.
activation_func (Callable): Activation function to use for the non-linearity in the MLP. Defaults to F.gelu.
# initialization
init_method (Callable): Method to initialize weights. Note that bias is always set to
zero. Should be a function that takes a single Tensor and
initializes it. Defaults to
megatron.core.utils.init_method_normal(init_method_std) which is
            torch.nn.init.normal_ with mean=0.0 and std=init_method_std.
output_layer_init_method (Callable): Method to initialize weights of the output layer of
both attention and MLP blocks. Defaults to
megatron.core.utils.scaled_init_method_normal(init_method_std)
which is torch.nn.init.normal_ with mean=0.0 and
std=init_method_std / math.sqrt(2.0 * num_layers).
init_method_std (float): Standard deviation of the zero mean normal for the default
initialization method, not used if init_method and
output_layer_init_method are provided. Defaults to 0.02.
# mixed-precision
apply_query_key_layer_scaling (bool): If true, scale Q * K^T by 1 / layer-number. Defaults to True.
attention_softmax_in_fp32 (bool): If true, run attention masking and softmax in fp32.
This should be true if apply_query_key_layer_scaling is true.
# fusion
        bias_gelu_fusion (bool): If true, fuses bias and gelu. Defaults to False.
masked_softmax_fusion (bool): If true, uses softmax fusion.
persist_layer_norm (bool): If true, uses the persistent fused layer norm kernel.
This kernel only supports a fixed set of hidden sizes.
Defaults to False.
bias_dropout_fusion (bool): If true, uses bias dropout fusion.
# activation recomputation
recompute_granularity (str): megatron-core supports 'selective' activation checkpointing where only the memory
intensive part of attention is checkpointed. These memory intensive activations
are also less compute intensive which makes activation checkpointing more efficient
for LLMs (20B+). See Reducing Activation Recomputation in Large Transformer
Models: https://arxiv.org/abs/2205.05198 for more details. 'full' will checkpoint
the entire transformer layer. Must be 'selective' or 'full'. 'selective' always uses all layers.
Defaults to None.
recompute_method (str): uniform will uniformly divide the total number of transformer layers in a transformer
block and recompute the input activation of each divided chunk at the specified
granularity. block will recompute the input activations for only a set number of
transformer layers per pipeline stage. The rest of the layers in the pipeline stage
will not have any activations recomputed. Must be 'uniform' or 'block'. Defaults to
None.
recompute_num_layers (int): When recompute_method is uniform, recompute_num_layers is the number of transformer
layers in each uniformly divided recompute unit. When recompute_method is block,
recompute_num_layers is the number of transformer layers to recompute within each
pipeline stage. Must be None for 'selective' activation checkpointing. Defaults to None.
distribute_saved_activations (bool): If true, distribute recomputed activations across the model parallel
group. Defaults to None.
        # fp8 related (via Transformer Engine). For detailed info, refer to the Transformer Engine docs at
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html
fp8 (str): If set, enables the use of FP8 precision through Transformer Engine. There are 2 predefined choices: (1) 'e4m3'
uniformly uses e4m3 for all FP8 tensors, (2) 'hybrid' uses e4m3 for all FP8 activation and weight tensors and
e5m2 for all FP8 output activation gradient tensors. Defaults to None.
fp8_margin (int): Margin for the scaling factor computation.
fp8_interval (int): Controls how often the scaling factor is recomputed.
fp8_amax_history_len (int): The length of the amax history window used for scaling factor computation.
fp8_amax_compute_algo (str): Algorithm used for choosing the `amax` value for the scaling factor computation.
There are 2 predefined choices: `max` chooses the largest `amax` in the history
window, while `most_recent` always chooses the most recently seen value.
fp8_wgrad (bool): When set to False, override FP8 config options and do the wgrad computation in higher precision.
Defaults to True.
# Experimental
        normalization (str): Switch between `LayerNorm` and `RMSNorm` as normalization layers. For now, these are primarily
used by Transformer-Engine's layers like `LayerNormLinear`. Default value is `LayerNorm`.
"""
# model architecture
num_layers: int = 0
hidden_size: int = 0
num_attention_heads: int = 0
num_query_groups: int = None
ffn_hidden_size: int = None
kv_channels: int = None
hidden_dropout: float = 0.1
attention_dropout: float = 0.1
fp32_residual_connection: bool = False
# @jcasper should we keep this option?
apply_residual_connection_post_layernorm: bool = False
layernorm_epsilon: float = 1e-5
layernorm_zero_centered_gamma: bool = False
add_bias_linear: bool = True
gated_linear_unit: bool = False
activation_func: Callable = F.gelu
# initialization
init_method: Callable = None
output_layer_init_method: Callable = None
init_method_std: float = 0.02
# mixed-precision
apply_query_key_layer_scaling: bool = True
attention_softmax_in_fp32: bool = True
# communication
# fusion
bias_gelu_fusion: bool = False # TODO: this should be bias_activation_fusion ?
masked_softmax_fusion: bool = False
persist_layer_norm: bool = False
bias_dropout_fusion: bool = False # TODO: this should be bias_dropout_add_fusion?
# activation recomputation
recompute_granularity: str = None
recompute_method: str = None
recompute_num_layers: int = None
distribute_saved_activations: bool = None
# fp8 related
fp8: str = None
fp8_margin: int = 0
fp8_interval: int = 1
fp8_amax_history_len: int = 1
fp8_amax_compute_algo: str = "most_recent"
fp8_wgrad: bool = True
# experimental section (TODO: move to apt. section above once stable)
    normalization: str = "LayerNorm"  # alt value supported by TE: "RMSNorm"
def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
"""
super().__post_init__()
if self.fp16 and self.bf16:
raise ValueError(
f'Only one of self.fp16: {self.fp16} and self.bf16 {self.bf16} should be True.'
)
if self.num_attention_heads % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_attention_heads ({self.num_attention_heads}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.ffn_hidden_size is None:
self.ffn_hidden_size = 4 * self.hidden_size
if self.kv_channels is None:
self.kv_channels = self.hidden_size // self.num_attention_heads
if self.num_query_groups is None:
self.num_query_groups = self.num_attention_heads
if self.num_query_groups % self.tensor_model_parallel_size != 0:
raise ValueError(
f"num_query_groups ({self.num_query_groups}) must be a multiple of "
f"tensor_model_parallel_size ({self.tensor_model_parallel_size})."
)
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
if self.recompute_granularity is not None:
if not self.recompute_granularity in ['full', 'selective']:
raise ValueError(
                    f'recompute_granularity: {self.recompute_granularity} must be "full" or "selective".'
)
if self.recompute_method is not None:
if not self.recompute_method in ['block', 'uniform']:
raise ValueError(
f'recompute_method: {self.recompute_method} must be "block" or "uniform".'
)
elif self.recompute_granularity != 'selective':
raise ValueError(
f'Using recompute_granularity: {self.recompute_granularity} so recompute_method must be "block" or "uniform"'
)
if self.recompute_granularity != 'selective' and self.recompute_num_layers is None:
raise ValueError(
f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be between '
f'1 and num_layers_per_pipeline_rank: {self.num_layers // self.pipeline_model_parallel_size}'
)
elif (
self.recompute_granularity == 'selective' and self.recompute_num_layers is not None
):
raise ValueError(
f'When using recompute_granularity: {self.recompute_granularity} recompute_num_layers must be None.'
)
if self.distribute_saved_activations and self.sequence_parallel:
raise ValueError(
f'distribute_saved_activations: {self.distribute_saved_activations} must be false when sequence parallel is enabled: {self.sequence_parallel}'
)
if self.virtual_pipeline_model_parallel_size is not None:
if not self.num_layers % self.virtual_pipeline_model_parallel_size == 0:
raise ValueError(
                    f'num_layers: {self.num_layers} must be divisible by virtual_pipeline_model_parallel_size {self.virtual_pipeline_model_parallel_size}'
)
if self.bias_gelu_fusion:
if not self.add_bias_linear:
raise ValueError(
"When bias_gelu_fusion is True, add_bias_linear must also be True."
)
if self.activation_func != F.gelu:
raise ValueError(f'When bias_gelu_fusion is True, activation_func must be F.gelu.')
if self.init_method is None:
self.init_method = init_method_normal(self.init_method_std)
if self.output_layer_init_method is None:
self.output_layer_init_method = scaled_init_method_normal(
self.init_method_std, self.num_layers
)
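# A minimal configuration sketch (illustrative values; fields inherited from
# ModelParallelConfig, e.g. tensor_model_parallel_size, are assumed to keep
# their defaults). Defined here only as an example, never called.
def _example_transformer_config():
    config = TransformerConfig(num_layers=12, hidden_size=768, num_attention_heads=12)
    # __post_init__ derives ffn_hidden_size = 4 * hidden_size and
    # kv_channels = hidden_size // num_attention_heads when they are left as None.
    return config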
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import re
from functools import partial
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import (
ShardedObject,
ShardedTensor,
ShardedTensorFactory,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.custom_layers.transformer_engine import TENorm
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
class TransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
config: TransformerConfig,
layer_number: int = 1,
self_attn_mask_type=AttnMaskType.padding,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.layer_number = layer_number + self._get_layer_offset()
self.self_attn_mask_type = self_attn_mask_type
# Layernorm on the input data.
# TODO: add pytorch only layernorm
self.input_layernorm = IdentityOp(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
persist_layer_norm=self.config.persist_layer_norm,
sequence_parallel=self.config.sequence_parallel,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
normalization=self.config.normalization,
)
# Self attention.
self.self_attention = SelfAttention(
config=self.config, layer_number=layer_number, attn_mask_type=self_attn_mask_type,
)
# Layernorm on the attention output
self.post_self_attn_layernorm = IdentityOp(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
persist_layer_norm=self.config.persist_layer_norm,
sequence_parallel=self.config.sequence_parallel,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
normalization=self.config.normalization,
)
# MLP
self.mlp = MLP(config=self.config)
# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
# TORCH_MAJOR = int(torch.__version__.split('.')[0])
# TORCH_MINOR = int(torch.__version__.split('.')[1])
# use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
# self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
self.bias_dropout_add_exec_handler = torch.enable_grad
def _get_layer_offset(self):
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
num_layers_per_pipeline_rank = (
self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
total_num_layers = self.config.num_layers
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
total_virtual_chunks = total_num_layers // vp_size
offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
else:
# Each stage gets a contiguous set of layers.
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
offset = 0
return offset
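    # Worked example: with config.num_layers = 8, a pipeline-parallel world size
    # of 2 and a virtual-pipeline world size of 2, each virtual rank holds 2
    # layers and total_virtual_chunks = 4; pipeline rank 1 / virtual rank 1 gets
    # offset = 1 * 4 + 1 * 2 = 6, i.e. the 1-based layer_numbers 7 and 8
    # (matching chunk [6, 7] in the interleaving comment in TransformerBlock).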
def forward(
self,
hidden_states,
attention_mask,
encoder_output=None,
enc_dec_attn_mask=None,
inference_params=None,
rotary_pos_emb=None,
):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
# Residual connection.
if self.config.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
bias_dropout_add_func = get_bias_dropout_add(self.training, self.config.bias_dropout_fusion)
# bias_dropout_add fusion returning fp32 instead of bf16
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output_with_bias, residual, self.config.hidden_dropout
)
# Layer norm post the self attention.
layernorm_output = self.post_self_attn_layernorm(layernorm_input)
# MLP.
mlp_output_with_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.config.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output_with_bias, residual, self.config.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=output, requires_grad=output.requires_grad, keep_graph=True
)
return output
def sharded_state_dict(self, prefix=''):
# state_dict = self.state_dict(prefix=prefix, keep_vars=True)
state_dict = self.state_dict(keep_vars=True)
tensor_parallel_layers_axis_map = {
'self_attention.linear_qkv.weight': 0,
'self_attention.linear_qkv.bias': 0,
'self_attention.linear_proj.weight': 1,
'mlp.linear_fc1.weight': 0,
'mlp.linear_fc1.bias': 0,
'mlp.linear_fc2.weight': 1,
}
offset = self._get_layer_offset()
num_layers = self.config.num_layers
sharded_state_dict = {}
for layer_name in state_dict.keys():
tensor = state_dict[layer_name]
global_layer_offset = self.layer_number - 1 # self.layer_number starts at 1
layer_key = f'{prefix}{global_layer_offset - offset}.{layer_name}' # module list index in TransformerBlock
sharded_offsets = [(0, global_layer_offset, num_layers)] # PP sharding
# TODO: move it to MLP after merging the "sharded_state_dict modularization" MR
is_glu_weight = (
layer_name == 'mlp.linear_fc1.weight' and self.mlp.config.gated_linear_unit
)
if layer_name in tensor_parallel_layers_axis_map:
tp_axis = tensor_parallel_layers_axis_map[layer_name]
# TP sharding
if not is_glu_weight:
sharded_offsets.append(
[
tp_axis + 1, # +1 for PP dimension
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_tensor_model_parallel_world_size(),
]
)
replica_id = parallel_state.get_data_parallel_rank()
else:
replica_id = (
parallel_state.get_data_parallel_rank()
* parallel_state.get_data_parallel_world_size()
+ parallel_state.get_tensor_model_parallel_rank()
)
if layer_name.endswith('._extra_state'):
sharded_state_dict[layer_key] = ShardedObject(
f'{prefix}{layer_name}',
tensor,
(num_layers,),
(global_layer_offset,),
replica_id,
)
elif is_glu_weight:
# We must split the tensor into 2 parts, each sharded separately.
# This requires a ShardedTensorFactory which `chunk`s during saving
# and `cat`s during loading
assert tp_axis == 0, f'TP axis for GLU weight should be 0, got: {tp_axis}'
tp_rank = parallel_state.get_tensor_model_parallel_rank()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
sh_ten_builder = partial(
ShardedTensor.from_rank_offsets, replica_id=replica_id, prepend_axis_num=1
) # for PP sharding
# NOTE: passing `tp_axis` as argument due to late binding in closures
def sh_ten_build_fn(key: str, t: torch.Tensor, tp_axis=tp_axis):
offset_w = (tp_axis + 1, tp_rank, tp_size * 2)
offset_v = (tp_axis + 1, tp_size + tp_rank, tp_size * 2)
with torch.no_grad():
tensor_w, tensor_v = torch.chunk(t, 2, dim=tp_axis)
return [
sh_ten_builder(key, tensor_w, *sharded_offsets, offset_w),
sh_ten_builder(key, tensor_v, *sharded_offsets, offset_v),
]
def sh_ten_merge_fn(sub_state_dict):
with torch.no_grad():
return torch.cat(sub_state_dict)
sharded_state_dict[layer_key] = ShardedTensorFactory(
f'{prefix}{layer_name}', tensor, sh_ten_build_fn, sh_ten_merge_fn
)
else:
sharded_state_dict[layer_key] = ShardedTensor.from_rank_offsets(
f'{prefix}{layer_name}',
tensor,
*sharded_offsets,
replica_id=replica_id,
prepend_axis_num=1, # for PP sharding
)
return sharded_state_dict
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Utilities for transformer layers."""
import torch
from megatron import get_args
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
if get_args().perform_initialization:
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
return gelu_impl(x)
# This is actually the Python equivalent of torch.nn.functional.gelu(), with type hints for the ONNX exporter
@torch.jit.script
def erf_gelu(x):
return (
x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
)
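# Note: gelu_impl above is the tanh approximation of GELU (0.7978845608... is
# sqrt(2/pi)), while erf_gelu is the exact erf formulation matching
# torch.nn.functional.gelu; the two agree closely over typical activation ranges.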
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Utility functions used throughout Megatron core"""
from functools import reduce
import math
import operator
from functools import reduce
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedTensor
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
......@@ -20,22 +22,38 @@ def divide(numerator, denominator):
ensure_divisibility(numerator, denominator)
return numerator // denominator
def get_attr_wrapped_model(model, attr):
def get_attr_wrapped_model(model, attr, allow_none=True):
"""Get an attribute from a wrapped model"""
if isinstance(model, list):
raise RuntimeError("_get_attr_wrapped_model given a list of models")
while not hasattr(model, attr):
if allow_none:
def condition(model, attr):
return not hasattr(model, attr)
else:
def condition(model, attr):
return getattr(model, attr, None) is None
while condition(model, attr):
if not hasattr(model, "module"):
raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
model = model.module
return getattr(model, attr)
def get_model_type(model):
return get_attr_wrapped_model(model, 'model_type')
def get_model_config(model):
return get_attr_wrapped_model(model, 'config', allow_none=False)
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
......@@ -46,16 +64,17 @@ class GlobalMemoryBuffer:
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
if (
self.buffer.get((name, dtype), None) is None
or self.buffer[(name, dtype)].numel() < required_len
):
self.buffer[(name, dtype)] = torch.empty(
required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
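    # Usage sketch (hedged example; assumes a CUDA device): repeated calls that
    # share the same (name, dtype) pair reuse one flat buffer, growing it only
    # when a larger shape is requested, e.g.
    #   buf = GlobalMemoryBuffer()
    #   scores = buf.get_tensor((4, 1024), torch.float32, "attn_scores")
    # The returned view is uninitialized and may be overwritten by later calls
    # that use the same name.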
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
......@@ -65,15 +84,11 @@ def _kernel_make_viewless_tensor(inp, requires_grad):
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad,)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
......@@ -83,13 +98,16 @@ class MakeViewlessTensor(torch.autograd.Function):
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
......@@ -110,11 +128,12 @@ def make_viewless_tensor(inp, requires_grad, keep_graph):
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
def assert_viewless_tensor(tensor, extra_msg=None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
[assert_viewless_tensor(t) for t in tensor]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
......@@ -125,11 +144,64 @@ def assert_viewless_tensor(tensor, extra_msg = None):
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
assert_viewless_tensor(
tensor,
extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
% ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape),
)
tensor.data = new_data_tensor
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
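# Usage sketch for the two initializers above (illustrative; `weight` is any
# parameter tensor supplied by the caller). Defined but not called here.
def _example_init(weight, num_layers=24):
    init_method_normal(0.02)(weight)                     # N(0, 0.02)
    scaled_init_method_normal(0.02, num_layers)(weight)  # N(0, 0.02 / sqrt(2 * num_layers))
    return weight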
def make_tp_sharded_tensor_for_checkpoint(tensor, key, tp_axis=0, replica_id=None, **kwargs):
""" Helper for instantiating a ShardedTensor where the `tp_axis` dimension is sharded across TP group. """
return ShardedTensor.from_rank_offsets(
key,
tensor,
(
tp_axis,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_tensor_model_parallel_world_size(),
),
replica_id=parallel_state.get_data_parallel_rank() if replica_id is None else replica_id,
**kwargs,
)
def make_sharded_tensor_for_checkpoint(tensor, key, **kwargs):
""" Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group). """
return ShardedTensor.from_rank_offsets(
key,
tensor,
replica_id=parallel_state.get_data_parallel_rank()
* parallel_state.get_data_parallel_world_size()
+ parallel_state.get_tensor_model_parallel_rank(),
**kwargs,
)
......@@ -2,25 +2,27 @@
"""Blendable dataset."""
import hashlib
import os
import time
import numpy as np
import torch
from megatron import print_rank_0
from megatron.core import mpu
class BlendableDataset(torch.utils.data.Dataset):
def __init__(self, datasets, weights):
def __init__(self, datasets, weights, size, *,
data_cache_path=None):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)
self.size = 0
for dataset in self.datasets:
self.size += len(dataset)
self.size = size
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
......@@ -28,19 +30,85 @@ class BlendableDataset(torch.utils.data.Dataset):
assert sum_weights > 0.0
weights /= sum_weights
# Build indecies.
# Build indicies.
def _build_indices():
start_time = time.time()
assert num_datasets < 255
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
dataset_index = np.zeros(self.size, dtype=np.uint8)
dataset_sample_index = np.zeros(self.size, dtype=np.int64)
from megatron.data import helpers
helpers.build_blending_indices(self.dataset_index,
self.dataset_sample_index,
helpers.build_blending_indices(dataset_index, dataset_sample_index,
weights, num_datasets, self.size,
torch.distributed.get_rank() == 0)
print_rank_0('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))
return dataset_index, dataset_sample_index
desc = "Blendable dataset\n\n"
desc += "Datasets:\n"
for dataset in datasets:
desc += dataset.desc + "\n\n"
desc += f"Weights: {weights}\n"
desc += f"Size: {size}\n"
self.desc = desc
if data_cache_path:
desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest()
desc_path = os.path.join(data_cache_path, desc_hash + ".dsc")
index_path = os.path.join(data_cache_path, desc_hash + "_index.npy")
sample_index_path = os.path.join(data_cache_path, desc_hash + "_sample_index.npy")
cache_hit = os.path.isfile(index_path) and os.path.isfile(sample_index_path)
cache_success = True
if torch.distributed.get_rank() == 0 and not cache_hit:
print(' > WARNING: could not find index map files for blendable'
' dataset, building indices on rank 0 ...', flush=True)
dataset_index, dataset_sample_index = _build_indices()
try:
os.makedirs(os.path.dirname(index_path), exist_ok=True)
with open(desc_path, 'wt') as fd:
fd.write(desc)
np.save(index_path, dataset_index, allow_pickle=True)
np.save(sample_index_path, dataset_sample_index,
allow_pickle=True)
except OSError:
print(f'There was an error trying to create the data cache directory ({data_cache_path})')
print('or a file in it. This is set with the --data-cache-path argument. Please')
print('ensure you have write access to this directory or specify one that you do have')
print('write access to.')
cache_success = False
counts = torch.cuda.LongTensor([cache_success])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
if counts[0].item() != (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())):
print_rank_0("Data index creation unsuccessful, exiting.")
exit()
# Load on all ranks.
print_rank_0(f'> loading blendable dataset index: {index_path}')
self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode='r')
assert self.dataset_index.size == self.size
print_rank_0(f'> loading blendable dataset sample index: {sample_index_path}')
self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode='r')
assert self.dataset_sample_index.size == self.size
else:
self.dataset_index, self.dataset_sample_index = _build_indices()
# Check size
_ = self.__getitem__(self.size - 1)
try:
_ = self.__getitem__(self.size)
raise RuntimeError('BlendedDataset size is improperly bounded')
except IndexError:
pass
print_rank_0('> size of blendable dataset: '
'{} samples'.format(self.size))
def __len__(self):
......
......@@ -37,8 +37,9 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_MULTIMODAL = 'multimodal'
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5, DSET_TYPE_MULTIMODAL]
def get_datasets_weights_and_num_samples(data_prefix,
......@@ -419,10 +420,48 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
def build_train_valid_test_datasets_with_prefixes(data_impl,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob, short_seq_prob, seed,
seed,
skip_warmup,
train_data_prefix=None,
valid_data_prefix=None,
test_data_prefix=None,
binary_head=False,
max_seq_length_dec=None,
dataset_type='standard_bert'):
print_rank_0("Separate data paths provided for train, valid & test.")
train_dataset, valid_dataset, test_dataset = None, None, None
# Single dataset.
if train_data_prefix is not None:
train_dataset = build_dataset("train", train_data_prefix, data_impl,
train_valid_test_num_samples[0],
max_seq_length, seed, skip_warmup,
binary_head, max_seq_length_dec,
dataset_type=dataset_type)
if valid_data_prefix is not None:
valid_dataset = build_dataset("valid", valid_data_prefix, data_impl,
train_valid_test_num_samples[1],
max_seq_length, seed, False,
binary_head, max_seq_length_dec,
dataset_type=dataset_type)
if test_data_prefix is not None:
test_dataset = build_dataset("test", test_data_prefix, data_impl,
train_valid_test_num_samples[2],
max_seq_length, seed, False,
binary_head, max_seq_length_dec,
dataset_type=dataset_type)
return (train_dataset, valid_dataset, test_dataset)
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, seed,
skip_warmup, binary_head=False,
max_seq_length_dec=None,
dataset_type='standard_bert'):
......@@ -431,8 +470,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
return _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed,
max_seq_length, seed,
skip_warmup,
binary_head,
max_seq_length_dec,
......@@ -442,6 +480,10 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
train_num_samples, valid_num_samples, test_num_samples = map(
sum,
zip(*datasets_train_valid_test_num_samples)
)
# Build individual datasets.
train_datasets = []
......@@ -451,9 +493,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, binary_head, max_seq_length_dec,
dataset_type=dataset_type)
max_seq_length, seed, skip_warmup, binary_head,
max_seq_length_dec, dataset_type=dataset_type)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
......@@ -464,13 +505,13 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_train_dataset = BlendableDataset(train_datasets, weights, train_num_samples)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_num_samples)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
blending_test_dataset = BlendableDataset(test_datasets, weights, test_num_samples)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
......@@ -478,24 +519,15 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob, short_seq_prob, seed,
max_seq_length, seed,
skip_warmup, binary_head,
max_seq_length_dec,
dataset_type='standard_bert'):
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
title_dataset = get_indexed_dataset_(args.titles_data_path,
data_impl,
dataset_type,
skip_warmup)
# Get start and end indices of train/valid/train into doc-idx
......@@ -521,10 +553,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
from megatron.data.bert_dataset import BertDataset
from megatron.data.ict_dataset import ICTDataset
from megatron.data.t5_dataset import T5Dataset
def build_split_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
......@@ -535,18 +564,65 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
dataset = build_dataset(
name, data_prefix, data_impl,
train_valid_test_num_samples[index], max_seq_length,
seed, skip_warmup, binary_head, max_seq_length_dec,
dataset_type, indexed_dataset)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_split_dataset(0, 'train')
valid_dataset = build_split_dataset(1, 'valid')
test_dataset = build_split_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def build_dataset(name, data_prefix, data_impl, max_num_samples,
max_seq_length, seed, skip_warmup, binary_head,
max_seq_length_dec, dataset_type='standard_bert',
indexed_dataset=None):
from megatron.data.bert_dataset import BertDataset
from megatron.data.ict_dataset import ICTDataset
from megatron.data.t5_dataset import T5Dataset
from megatron.data.multimodal_dataset import MultiModalDataset
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
if indexed_dataset is None:
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
dataset_type,
skip_warmup)
kwargs = dict(
name=name,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_num_samples=max_num_samples,
max_seq_length=max_seq_length,
seed=seed,
)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
title_dataset = get_indexed_dataset_(
args.titles_data_path,
data_impl,
dataset_type,
skip_warmup)
dataset = ICTDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
......@@ -556,47 +632,51 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
**kwargs
)
elif dataset_type == DSET_TYPE_T5:
args = get_args()
dataset = T5Dataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
masked_lm_prob=args.mask_prob,
max_seq_length_dec=max_seq_length_dec,
short_seq_prob=short_seq_prob,
short_seq_prob=args.short_seq_prob,
**kwargs
)
elif dataset_type == DSET_TYPE_BERT:
args = get_args()
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
short_seq_prob=short_seq_prob,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
binary_head=binary_head,
**kwargs
)
elif dataset_type == DSET_TYPE_MULTIMODAL:
args = get_args()
dataset = MultiModalDataset(
name=name,
data_prefix=data_prefix,
indexed_dataset=indexed_dataset,
num_samples=max_num_samples,
seq_length=max_seq_length,
seed=seed,
img_h=args.img_h,
img_w=args.img_w,
)
else:
raise NotImplementedError("Dataset type not fully implemented.")
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
def get_indexed_dataset_(data_prefix, data_impl, dataset_type, skip_warmup):
print_rank_0(' > building dataset index ...')
start_time = time.time()
multimodal = dataset_type == DSET_TYPE_MULTIMODAL
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
skip_warmup,
multimodal)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
......
......@@ -2,6 +2,7 @@
"""GPT style dataset."""
import hashlib
import os
import time
......@@ -22,7 +23,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_data_prefix=None,
valid_data_prefix=None,
test_data_prefix=None,
return_doc_ids=False):
return_doc_ids=False, *,
data_cache_path=None):
"""Build train, valid, and test datasets."""
if data_prefix:
......@@ -33,13 +35,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
return _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup)
seq_length, seed, skip_warmup,
data_cache_path=data_cache_path)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
train_num_samples, valid_num_samples, test_num_samples = map(
sum,
zip(*datasets_train_valid_test_num_samples)
)
# Build individual datasets.
train_datasets = []
......@@ -50,7 +57,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
seq_length, seed, skip_warmup,
return_doc_ids)
return_doc_ids,
data_cache_path=data_cache_path)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
......@@ -61,13 +69,16 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_train_dataset = BlendableDataset(train_datasets, weights, train_num_samples,
data_cache_path=data_cache_path)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_num_samples,
data_cache_path=data_cache_path)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
blending_test_dataset = BlendableDataset(test_datasets, weights, test_num_samples,
data_cache_path=data_cache_path)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
......@@ -79,18 +90,25 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
# Single dataset.
if train_data_prefix is not None:
train_dataset = build_dataset("train", train_data_prefix, data_impl,
splits_string,
train_valid_test_num_samples[0],
seq_length, seed, skip_warmup)
seq_length, seed, skip_warmup,
data_cache_path=data_cache_path)
if valid_data_prefix is not None:
valid_dataset = build_dataset("valid", valid_data_prefix, data_impl,
splits_string,
train_valid_test_num_samples[1],
seq_length, seed, False)
seq_length, seed, False,
data_cache_path=data_cache_path)
if test_data_prefix is not None:
test_dataset = build_dataset("test", test_data_prefix, data_impl,
splits_string,
train_valid_test_num_samples[2],
seq_length, seed, False)
seq_length, seed, False,
data_cache_path=data_cache_path)
return (train_dataset, valid_dataset, test_dataset)
......@@ -98,7 +116,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup,
return_doc_ids=False):
return_doc_ids=False, *,
data_cache_path=None):
"""Build train, valid, and test datasets."""
# Indexed dataset.
......@@ -126,11 +145,12 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index + 1],
step=1, dtype=np.int32)
dataset = GPTDataset(name, data_prefix,
documents, indexed_dataset,
dataset = GPTDataset(name, data_prefix, documents, indexed_dataset,
splits_string,
train_valid_test_num_samples[index],
seq_length, seed,
return_doc_ids)
return_doc_ids,
data_cache_path=data_cache_path)
return dataset
train_dataset = build_dataset(0, 'train')
......@@ -140,37 +160,45 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
return (train_dataset, valid_dataset, test_dataset)
def build_dataset(dataset_name, data_prefix, data_impl, num_samples,
seq_length, seed, skip_warmup):
def build_dataset(dataset_name, data_prefix, data_impl,
splits_string, num_samples,
seq_length, seed, skip_warmup,
*,
data_cache_path=None):
dataset = None
if len(data_prefix) == 1:
dataset = _build_dataset(dataset_name,
data_prefix[0], data_impl,
num_samples, seq_length,
seed, skip_warmup)
dataset = _build_dataset(dataset_name, data_prefix[0], data_impl,
splits_string, num_samples, seq_length,
seed, skip_warmup,
data_cache_path=data_cache_path)
else:
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix, num_samples)
prefixes, weights, dataset_num_samples = output
num_samples = sum(dataset_num_samples)
# Build individual datasets.
datasets = []
for i in range(len(prefixes)):
ds = _build_dataset(dataset_name, prefixes[i],
data_impl, dataset_num_samples[i],
seq_length, seed, skip_warmup)
ds = _build_dataset(dataset_name, prefixes[i], data_impl,
splits_string, dataset_num_samples[i],
seq_length, seed, skip_warmup,
data_cache_path=data_cache_path)
if ds:
datasets.append(ds)
if datasets:
dataset = BlendableDataset(datasets, weights)
dataset = BlendableDataset(datasets, weights, num_samples,
data_cache_path=data_cache_path)
return dataset
def _build_dataset(dataset_name, data_prefix, data_impl,
num_samples, seq_length, seed, skip_warmup):
def _build_dataset(dataset_name, data_prefix, data_impl, splits_string,
num_samples, seq_length, seed, skip_warmup,
*,
data_cache_path=None):
"""
Build dataset. This method is called when individual
train, valid, test datasets are provided
......@@ -190,9 +218,9 @@ def _build_dataset(dataset_name, data_prefix, data_impl,
documents = np.arange(start=0, stop=total_num_of_documents,
step=1, dtype=np.int32)
dataset = GPTDataset(dataset_name, data_prefix,
documents, indexed_dataset,
num_samples, seq_length, seed)
dataset = GPTDataset(dataset_name, data_prefix, documents, indexed_dataset,
splits_string, num_samples, seq_length, seed,
data_cache_path=data_cache_path)
return dataset
......@@ -216,8 +244,9 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
class GPTDataset(torch.utils.data.Dataset):
def __init__(self, name, data_prefix, documents, indexed_dataset,
num_samples, seq_length, seed,
return_doc_ids=False):
splits_string, num_samples, seq_length, seed,
return_doc_ids=False, *,
data_cache_path=None):
self.name = name
self.indexed_dataset = indexed_dataset
......@@ -228,10 +257,11 @@ class GPTDataset(torch.utils.data.Dataset):
assert np.max(documents) < indexed_dataset.sizes.shape[0]
# Build index mappings.
self.doc_idx, self.sample_idx, self.shuffle_idx, self.index_prefix = \
self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = \
_build_index_mappings(self.name, data_prefix,
documents, self.indexed_dataset.sizes,
num_samples, seq_length, seed)
splits_string, num_samples, seq_length, seed,
data_cache_path=data_cache_path)
def __len__(self):
......@@ -278,7 +308,9 @@ class GPTDataset(torch.utils.data.Dataset):
def _build_index_mappings(name, data_prefix, documents, sizes,
num_samples, seq_length, seed):
splits_string, num_samples, seq_length, seed,
*,
data_cache_path):
"""Build doc-idx, sample-idx, and shuffle-idx.
doc-idx: is an array (ordered) of documents to be used in training.
sample-idx: is the start document index and document offset for each
......@@ -293,21 +325,45 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
np_rng = np.random.RandomState(seed=seed)
# Filename of the index mappings.
index_prefix = '{}_indexmap'.format(name)
index_prefix += '_{}ns'.format(num_samples)
index_prefix += '_{}sl'.format(seq_length)
index_prefix += '_{}s'.format(seed)
_filename = data_prefix + '_' + index_prefix
doc_idx_filename = _filename + '_doc_idx.npy'
sample_idx_filename = _filename + '_sample_idx.npy'
shuffle_idx_filename = _filename + '_shuffle_idx.npy'
desc = "GPT Dataset\n\n"
desc += f"Data prefix {data_prefix}\n"
desc += f"Dataset name {name}\n"
desc += f"Number of samples {num_samples}\n"
desc += f"Sequence length {seq_length}\n"
desc += f"Random seed {seed}\n"
desc += f"Split {splits_string}\n"
desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest()
desc_filename = desc_hash + ".dsc"
doc_idx_filename = desc_hash + '_doc_idx.npy'
sample_idx_filename = desc_hash + '_sample_idx.npy'
shuffle_idx_filename = desc_hash + '_shuffle_idx.npy'
    # Look for the cache in the main data dir first to avoid unnecessary
    # duplication, then look in data-cache-path if specified.
    # If nothing is found, use the last path looked in.
build_indices = True
prefixes = [os.path.join(os.path.dirname(data_prefix), 'index-cache')]
if data_cache_path is not None:
prefixes.append(data_cache_path)
for prefix in prefixes:
idx_path = {
'desc': os.path.join(prefix, desc_filename),
'doc': os.path.join(prefix, doc_idx_filename),
'sample': os.path.join(prefix, sample_idx_filename),
'shuffle': os.path.join(prefix, shuffle_idx_filename)
}
for f in idx_path.values():
if not os.path.isfile(f):
break
else:
# Found our files!
build_indices = False
break
data_cache_dir = os.path.dirname(idx_path['desc'])
data_cache_success = True
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0:
if (not os.path.isfile(doc_idx_filename)) or \
(not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):
if build_indices and torch.distributed.get_rank() == 0:
print_rank_0(' > WARNING: could not find index map files, building '
'the indices on rank 0 ...')
......@@ -330,7 +386,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
assert last_epoch_num_samples >= 0, \
'last epoch number of samples should be non-negative.'
num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
assert last_epoch_num_samples <= (num_samples_per_epoch + 1), \
'last epoch number of samples exceeded max value.'
# If we have less than 80% of the samples for the last epoch,
            # separate out the epoch and treat it differently.
......@@ -349,11 +405,19 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
print(string.format(last_epoch_num_samples,
num_samples_per_epoch), flush=True)
try:
os.makedirs(data_cache_dir, exist_ok=True)
# description
with open(idx_path['desc'], 'wt') as fd:
fd.write(desc)
# doc-idx.
start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
separate_last_epoch)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
np.save(idx_path['doc'], doc_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save doc-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
# sample-idx.
......@@ -365,7 +429,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
assert sizes.dtype == np.int32
sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
np.save(idx_path['sample'], sample_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save sample-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
# shuffle-idx.
......@@ -378,38 +442,44 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
num_samples_ = sample_idx.shape[0] - 1
shuffle_idx = _build_shuffle_idx(num_samples_,
sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
np.save(idx_path['shuffle'], shuffle_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))
    # This should be a barrier, but the NCCL barrier assumes
    # device_index == rank, which does not hold in the
    # model-parallel case.
counts = torch.cuda.LongTensor([1])
except OSError:
print(f'There was an error trying to create the data cache directory ({data_cache_dir})')
print('or a file in it. This defaults to a directory "index-cache" within the directory')
print('the data files are in and can be set with the --data-cache-path argument. Please')
print('ensure you have write access to this directory or specify one that you do have')
print('write access to.')
data_cache_success = False
counts = torch.cuda.LongTensor([data_cache_success])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
if counts[0].item() != (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())):
print_rank_0("Data index creation unsuccessful, exiting.")
exit()
# Load mappings.
start_time = time.time()
print_rank_0(' > loading doc-idx mapping from {}'.format(
doc_idx_filename))
doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' > loading sample-idx mapping from {}'.format(
sample_idx_filename))
sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' > loading shuffle-idx mapping from {}'.format(
shuffle_idx_filename))
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(f" > loading doc-idx mapping from {idx_path['doc']}")
doc_idx = np.load(idx_path['doc'], allow_pickle=True, mmap_mode='r')
print_rank_0(f" > loading sample-idx mapping from {idx_path['sample']}")
sample_idx = np.load(idx_path['sample'], allow_pickle=True, mmap_mode='r')
print_rank_0(f" > loading shuffle-idx mapping from {idx_path['shuffle']}")
shuffle_idx = np.load(idx_path['shuffle'], allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
sample_idx.shape[0]))
print_rank_0(' total number of epochs: {}'.format(num_epochs))
return doc_idx, sample_idx, shuffle_idx, index_prefix
return doc_idx, sample_idx, shuffle_idx, desc, desc_hash
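For reference, a minimal standalone sketch of the cache-key scheme above: it mirrors the description string and md5 hash built inside _build_index_mappings so the cached index files can be located outside of training, and the search order matches the loop over prefixes above. The concrete values passed in are whatever was used for training.

import hashlib
import os

def index_cache_paths(data_prefix, name, num_samples, seq_length, seed, splits_string,
                      data_cache_path=None):
    desc = "GPT Dataset\n\n"
    desc += f"Data prefix {data_prefix}\n"
    desc += f"Dataset name {name}\n"
    desc += f"Number of samples {num_samples}\n"
    desc += f"Sequence length {seq_length}\n"
    desc += f"Random seed {seed}\n"
    desc += f"Split {splits_string}\n"
    desc_hash = hashlib.md5(desc.encode('utf-8')).hexdigest()
    # Same search order as above: the corpus-local index-cache first, then --data-cache-path.
    dirs = [os.path.join(os.path.dirname(data_prefix), 'index-cache')]
    if data_cache_path is not None:
        dirs.append(data_cache_path)
    return [os.path.join(d, desc_hash + suffix)
            for d in dirs
            for suffix in ('.dsc', '_doc_idx.npy', '_sample_idx.npy', '_shuffle_idx.npy')]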
def _num_tokens(documents, sizes):
......@@ -517,3 +587,4 @@ def _build_shuffle_idx(num_samples, total_size, np_rng):
np_rng.shuffle(shuffle_idx_last)
return np.concatenate((shuffle_idx_first, shuffle_idx_last))
......@@ -55,7 +55,7 @@ def make_builder(out_file, impl, vocab_size=None):
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, skip_warmup=False):
def make_dataset(path, impl, skip_warmup=False, multimodal=False):
if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
......@@ -67,7 +67,7 @@ def make_dataset(path, impl, skip_warmup=False):
elif impl == 'cached' and IndexedDataset.exists(path):
return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
return MMapIndexedDataset(path, skip_warmup, multimodal)
print(f"Unknown dataset implementation: {impl}")
return None
......@@ -95,9 +95,9 @@ dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float,
7: np.double,
8: np.uint16
6: np.float64,
7: np.float32,
8: np.uint16,
}
......@@ -268,8 +268,8 @@ class IndexedDatasetBuilder(object):
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float: 4,
np.double: 8
np.float32: 4,
np.float64: 8,
}
def __init__(self, out_file, dtype=np.int32):
......@@ -365,7 +365,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
return pointers
def write(self, sizes, doc_idx):
def write(self, sizes, modes, doc_idx):
pointers = self._get_pointers(sizes)
self._file.write(struct.pack('<Q', len(sizes)))
......@@ -375,6 +375,11 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
self._file.write(sizes.tobytes(order='C'))
del sizes
if modes is not None:
modes = np.array(modes, dtype=np.int32)
self._file.write(modes.tobytes(order='C'))
del modes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order='C'))
del pointers
......@@ -387,7 +392,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
return _Writer()
def __init__(self, path, skip_warmup=False):
def __init__(self, path, skip_warmup=False, multimodal=False):
with open(path, 'rb') as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
......@@ -400,6 +405,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
dtype_code, = struct.unpack('<B', stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self.multimodal = multimodal
self._len = struct.unpack('<Q', stream.read(8))[0]
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
......@@ -417,12 +423,21 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
dtype=np.int32,
count=self._len,
offset=offset)
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes)
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
self._modes = None
if multimodal:
print_rank_0(" reading modes...")
self._modes = np.frombuffer(
self._bin_buffer,
dtype=np.int8,
count=self._len,
offset=offset + self._sizes.nbytes + self._pointers.nbytes + self._doc_idx.nbytes)
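A hedged sketch of the offset arithmetic implied by the reader above: after the fixed header, the index file stores int32 sizes, int64 pointers, int64 doc_idx and, for multimodal indices, int8 modes, back to back. The header_size argument below stands in for the offset at which the sizes array starts (stream.tell() after the header); the exact header length is not shown in this hunk.

import numpy as np

def multimodal_index_offsets(header_size, num_sequences, num_docs):
    # Byte offsets of each section, mirroring the np.frombuffer calls above.
    sizes_nbytes = num_sequences * np.dtype(np.int32).itemsize
    pointers_nbytes = num_sequences * np.dtype(np.int64).itemsize
    doc_idx_nbytes = num_docs * np.dtype(np.int64).itemsize
    return {
        'sizes': header_size,
        'pointers': header_size + sizes_nbytes,
        'doc_idx': header_size + sizes_nbytes + pointers_nbytes,
        'modes': header_size + sizes_nbytes + pointers_nbytes + doc_idx_nbytes,
    }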
def __del__(self):
self._bin_buffer_mmap._mmap.close()
......@@ -436,35 +451,40 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def sizes(self):
return self._sizes
@property
def modes(self):
return self._modes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
return self._pointers[i], self._sizes[i], (self._modes[i] if self.multimodal else None)
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
def __init__(self, path, skip_warmup=False, multimodal=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self.multimodal = multimodal
self._do_init(path, skip_warmup)
self._do_init(path, skip_warmup, multimodal)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state, skip_warmup=True)
self._do_init(state, skip_warmup=True, multimodal=False)
def _do_init(self, path, skip_warmup):
def _do_init(self, path, skip_warmup, multimodal):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
self._index = self.Index(index_file_path(self._path), skip_warmup, multimodal)
if not skip_warmup:
print_rank_0(" warming up data mmap file...")
......@@ -485,22 +505,23 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, (int, np.integer)):
ptr, size = self._index[idx]
ptr, size, mode = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=size, offset=ptr)
return np_array
return (np_array, mode) if mode is not None else np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
modes = self._index._modes[idx] if self.multimodal else None
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=total_size, offset=ptr)
sents = np.split(np_array, offsets[:-1])
return sents
return (sents, modes) if modes is not None else sents
else:
raise TypeError("Unexpected type received for idx: {}".format(type(idx)))
......@@ -510,18 +531,23 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
ptr, size, mode = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=length, offset=ptr)
return np_array
return (np_array, mode) if mode is not None else np_array
@property
def sizes(self):
return self._index.sizes
@property
def modes(self):
return self._index.modes
@property
def doc_idx(self):
return self._index.doc_idx
......@@ -544,35 +570,48 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
def __init__(self, out_file, dtype=np.int64, multimodal=False):
self._data_file = open(out_file, 'wb')
self._dtype = dtype
self._multimodal = multimodal
self._sizes = []
self._doc_idx = [0]
self._modes = [] if self._multimodal else None
def add_item(self, tensor):
def add_item(self, tensor, mode=0):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.append(np_array.size)
def add_doc(self, tensor, sizes):
if self._multimodal:
self._modes.append(mode)
def add_doc(self, tensor, sizes, modes=None):
np_array = np.array(tensor, dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.extend(sizes)
self._doc_idx.append(len(self._sizes))
if self._multimodal:
            self._modes.extend(modes if modes is not None else [0] * len(sizes))
def end_document(self):
self._doc_idx.append(len(self._sizes))
def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedDataset.Index(index_file_path(another_file))
index = MMapIndexedDataset.Index(
index_file_path(another_file),
multimodal=self._multimodal)
assert index.dtype == self._dtype
offset = len(self._sizes)
self._sizes.extend(index.sizes)
self._doc_idx.extend((offset + index.doc_idx)[1:])
if self._multimodal:
self._modes.extend(index.modes)
# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
shutil.copyfileobj(f, self._data_file)
......@@ -581,4 +620,4 @@ class MMapIndexedDatasetBuilder(object):
self._data_file.close()
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes, self._doc_idx)
index.write(self._sizes, self._modes, self._doc_idx)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from PIL import Image, UnidentifiedImageError
import numpy as np
import io
import torch
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
from torchvision.transforms import Compose, ToTensor, Normalize, ToPILImage, RandomResizedCrop, Resize
def _convert_image_to_rgb(image):
return image.convert("RGB")
def _transform(img_h, img_w):
return Compose([
ToPILImage(),
RandomResizedCrop((img_h, img_w), scale=(0.5, 1.0), interpolation=BICUBIC),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
class MultiModalDataset(torch.utils.data.Dataset):
def __init__(self, name, data_prefix, indexed_dataset,
num_samples, seq_length, seed, img_h, img_w):
self.name = name
self.indexed_dataset = indexed_dataset
self.doc_idx = indexed_dataset.get_doc_idx()
self.visual_transform = _transform(img_h, img_w)
def __len__(self):
return self.indexed_dataset.sizes.shape[0]
def __getitem__(self, idx):
text_sample, mode = self.indexed_dataset.get(self.doc_idx[idx])
assert mode == 0
img_sample, mode = self.indexed_dataset.get(self.doc_idx[idx]+1)
assert mode == 1
img_pad = img_sample[0].item()
xs = img_sample[1:].tobytes(order='C')
xs = xs[:len(xs)-img_pad]
img_sample = np.array(Image.open(io.BytesIO(xs)))
img_sample = self.visual_transform(img_sample).reshape(-1)
return {'text': np.array(text_sample, dtype=np.int64),
'img': np.array(img_sample, dtype=np.float32)}
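The tool that writes these image records is not part of this diff; the helper below is a hypothetical inverse of the unpacking in __getitem__ above, padding the encoded image bytes to a whole number of dtype elements and storing the pad length as the first element so it can be stripped on read.

import numpy as np

def pack_image_bytes(raw_bytes, dtype=np.int32):
    # dtype must match the indexed dataset's dtype; the first element records
    # how many zero bytes of padding were appended.
    itemsize = np.dtype(dtype).itemsize
    pad = (-len(raw_bytes)) % itemsize
    body = np.frombuffer(raw_bytes + b'\x00' * pad, dtype=np.uint8).view(dtype)
    return np.concatenate((np.array([pad], dtype=dtype), body))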
......@@ -19,17 +19,18 @@ def load(args):
# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
cpp_extension.CUDA_HOME)
cpp_extension.CUDA_HOME
)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
if int(bare_metal_minor) >= 7:
if int(bare_metal_minor) >= 8:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_90,code=sm_90')
# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
buildpath = srcpath / "build"
_create_build_dir(buildpath)
# Helper function to build the kernels.
......@@ -38,46 +39,25 @@ def load(args):
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'--use_fast_math'] + extra_cuda_flags + cc_flag,
verbose=(args.rank == 0)
extra_cflags=[
"-O3",
],
extra_cuda_cflags=[
"-O3",
"-gencode",
"arch=compute_70,code=sm_70",
"--use_fast_math",
]
+ extra_cuda_flags
+ cc_flag,
verbose=(args.rank == 0),
)
# ==============
# Fused softmax.
# ==============
if args.masked_softmax_fusion:
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda']
# Upper triangular softmax.
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_upper_triang_masked_softmax_cuda",
sources, extra_cuda_flags)
# Masked softmax.
sources=[srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu']
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags)
# Softmax
sources=[srcpath / 'scaled_softmax.cpp',
srcpath / 'scaled_softmax_cuda.cu']
scaled_softmax_cuda = _cpp_extention_load_helper(
"scaled_softmax_cuda", sources, extra_cuda_flags)
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
......
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);
torch::Tensor fwd(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size."
);
}
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
 * Extended softmax (from native aten pytorch) with the following additional features
* 1) input scaling
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
/*
 * Extended softmax (from native aten pytorch) with the following additional features
* 1) input scaling
* 2) Explicit masking
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
int element_count,
int pad_batches)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
int pad_first_batch = 0;
if (pad_batches != 1) { // bert style
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
} else { // gpt2 style
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -10000.0;
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
// compute scale value to account for full mask
acc_t scale_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0;
}
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] * scale_value[i] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
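As a plausibility check (not part of the extension itself), a numpy sketch that mirrors the arithmetic of scaled_masked_softmax_warp_forward above: masked positions are replaced by -10000 instead of being scaled, and fully masked rows are zeroed out rather than producing a uniform distribution (the scale_value trick in the kernel).

import numpy as np

def scaled_masked_softmax_reference(x, mask, scale):
    # x, mask: arrays whose last axis is the key sequence; mask == 1 means "masked out".
    y = np.where(mask == 1, -10000.0, x * scale)
    row_max = y.max(axis=-1, keepdims=True)
    probs = np.exp(y - row_max)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    # Rows where every position was masked have row_max == -10000 and are zeroed.
    return np.where(row_max == -10000.0, 0.0, probs)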
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
}
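Similarly, a numpy sketch of the backward arithmetic in scaled_masked_softmax_warp_backward above: with y the saved softmax output and dy the incoming gradient, the kernel computes scale * (dy * y - y * sum(dy * y)) along the softmax axis.

import numpy as np

def scaled_masked_softmax_backward_reference(grad_output, softmax_out, scale):
    # grad_reg = dY * Y; sum is reduced over the softmax axis; dX = scale * (dY * Y - Y * sum).
    prod = grad_output * softmax_out
    return scale * (prod - softmax_out * prod.sum(axis=-1, keepdims=True))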
} // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
return batches_per_block;
}
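A worked restatement of the arithmetic above in Python, assuming the usual C10_WARP_SIZE of 32: for key_seq_len = 1024 this gives 128 // 32 = 4 warps per block with one batch per warp, i.e. 4 batches per block, while key_seq_len = 128 gives 8.

def get_batch_per_block_reference(key_seq_len, warp_size_cap=32, threads_per_block=128):
    # Python mirror of get_batch_per_block above; warp_size_cap stands in for C10_WARP_SIZE.
    log2_elements = 0
    while (1 << log2_elements) < key_seq_len:
        log2_elements += 1
    next_power_of_two = 1 << log2_elements
    warp_size = min(next_power_of_two, warp_size_cap)
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    return (threads_per_block // warp_size) * batches_per_warp

assert get_batch_per_block_reference(1024) == 4
assert get_batch_per_block_reference(128) == 8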
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
        // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 12: // 4096
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
        // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 12: // 4096
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
        // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count/batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 12: // 4096
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
auto act_options = output_grads.options().requires_grad(false);
torch::Tensor input_grads =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
void* input_grads_ptr = static_cast<void*>(input_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(input_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
return input_grads;
}
}
}
}
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(
torch::Tensor const& input,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_softmax::fwd,
"Self Multihead Attention scaled, softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_softmax::bwd,
"Self Multihead Attention scaled, softmax -- Backward.");
}
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_softmax_forward",
dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
return softmax_results;
}
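// Shape note: unlike the upper-triangular variant, query_seq_len and key_seq_len may
// differ here. Conceptually the kernel handles batches * attn_heads * query_seq_len
// independent rows, each softmaxed over key_seq_len elements (key_seq_len <= 4096 and
// query_seq_len > 1 per the asserts above).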
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
  // output_grads is a 4D tensor with dimensions [batches, attn_heads, query_seq_len, key_seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
  // backward pass is completely in-place
return output_grads;
}
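// In-place note: gradInput and grad both point at output_grads_ptr, so the gradient
// overwrites output_grads instead of allocating a separate buffer. Because
// output_grads_ is made contiguous first, the write lands either in the caller's
// tensor (if it was already contiguous) or in a private contiguous copy, and that
// tensor is what gets returned.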
}
}
}
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
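// Vectorization note: the <T, 4> specializations above move four 16-bit elements
// (8 bytes) per access by type-punning through float2, and the <uint8_t, 4>
// specialization moves four bytes through half2; the <T, 1> versions are the scalar
// fallback used when ELEMENTS_PER_LDG_STG is 1.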
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
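// Example: with WARP_SIZE = 32 the loop above performs a butterfly reduction over
// lane offsets 16, 8, 4, 2, 1; after the final step every lane holds the full
// reduction (Max or Add) for each of its WARP_BATCH rows, so no extra broadcast
// is required.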
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
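/*
 * Concretely, for a row at query position q (0-based) only keys k <= q are kept:
 *   y_k = exp(scale * x_k - m) / sum_{j <= q} exp(scale * x_j - m)   for k <= q
 *   y_k = 0                                                           for k >  q
 * with m = max_{j <= q} (scale * x_j) subtracted for numerical stability.
 */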
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it+element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
} else if (element_index < element_count) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int stride,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst,
const input_t *src,
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
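    // Worked example (illustrative): softmax_elements = 1024 gives log2_elements = 10,
    // next_power_of_two = 1024, warp_size = 32, batches_per_warp = 1,
    // warps_per_block = 128 / 32 = 4, batches_per_block = 4, so the launch is
    // dim3(1024, attn_batches / 4, 1) blocks of dim3(32, 4, 1) threads.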
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}