# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn.functional as F
from megatron.core.jit import jit_fuser
###### BIAS SWIGLU FUSION / NO AUTOGRAD ################
@jit_fuser
def swiglu(y):
y_1, y_2 = torch.chunk(y, 2, -1)
return F.silu(y_1) * y_2
@jit_fuser
def bias_swiglu(y, bias):
y = y + bias
return swiglu(y)
# Gradient of SwiGLU: for y = [y_1, y_2] and out = silu(y_1) * y_2,
#   d(out)/d(y_1) = y_2 * sigmoid(y_1) * (1 + y_1 * (1 - sigmoid(y_1)))
#   d(out)/d(y_2) = silu(y_1)
@jit_fuser
def swiglu_back(g, y):
y_1, y_2 = torch.chunk(y, 2, -1)
return torch.cat(
(g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, g * F.silu(y_1)), -1
)
@jit_fuser
def bias_swiglu_back(g, y, bias):
y = y + bias
return swiglu_back(g, y)
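# --- Illustrative check (not part of the original module) ---------------------
# A minimal sketch, assuming small CPU tensors, that compares swiglu_back with
# the gradient autograd derives for the reference swiglu above. The helper
# name and the shapes are hypothetical and used for illustration only.
def _example_swiglu_back_matches_autograd():
    y = torch.randn(4, 16, requires_grad=True)  # 4 tokens, 2 * 8 features
    g = torch.randn(4, 8)                       # upstream gradient
    # Reference gradient via autograd on the forward definition.
    (auto_grad,) = torch.autograd.grad(swiglu(y), y, grad_outputs=g)
    # Hand-derived gradient from the fused backward above.
    manual_grad = swiglu_back(g, y.detach())
    torch.testing.assert_close(auto_grad, manual_grad)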
class BiasSwiGLUFunction(torch.autograd.Function):
@staticmethod
# bias is required here; the optional-bias dispatch happens in bias_swiglu_impl
def forward(ctx, input, bias, fp8_input_store):
input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
ctx.save_for_backward(input_for_backward, bias)
ctx.ori_input_dtype = input.dtype
ctx.fp8_input_store = fp8_input_store
return bias_swiglu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
tmp = bias_swiglu_back(grad_output, input, bias)
return tmp, tmp, None
class SwiGLUFunction(torch.autograd.Function):
@staticmethod
# bias-free variant; see BiasSwiGLUFunction for the biased path
def forward(ctx, input, fp8_input_store):
input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
ctx.save_for_backward(input_for_backward)
ctx.ori_input_dtype = input.dtype
ctx.fp8_input_store = fp8_input_store
return swiglu(input)
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors[0]
input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
tmp = swiglu_back(grad_output, input)
return tmp, None
def bias_swiglu_impl(input, bias, fp8_input_store=False):
ori_shape = input.shape
assert len(ori_shape) in [2, 3]
input = input.view(-1, ori_shape[-1])
if bias is not None:
output = BiasSwiGLUFunction.apply(input, bias, fp8_input_store)
else:
output = SwiGLUFunction.apply(input, fp8_input_store)
return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
# bias_swiglu_impl = BiasSwiGLUFunction.apply
# swiglu_impl = SwiGLUFunction.apply
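# --- Illustrative usage (not part of the original module) ---------------------
# A minimal sketch of calling bias_swiglu_impl on a [sequence, batch, 2 * ffn]
# projection, as an MLP would after its first linear layer. The sizes and the
# helper name are hypothetical; the GLU halves the last dimension.
def _example_bias_swiglu_impl_usage():
    s, b, two_ffn = 8, 2, 32
    intermediate = torch.randn(s, b, two_ffn)
    bias = torch.randn(two_ffn)
    out = bias_swiglu_impl(intermediate, bias)          # fused bias-add + SwiGLU
    assert out.shape == (s, b, two_ffn // 2)
    out_no_bias = bias_swiglu_impl(intermediate, None)  # bias-free fused path
    assert out_no_bias.shape == (s, b, two_ffn // 2)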
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import torch
from megatron.core.jit import jit_fuser
from megatron.core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from megatron.core.tensor_parallel.utils import VocabUtility
@jit_fuser
def calculate_logits_max(vocab_parallel_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculates the maximum of the logits along the vocab dimension (for numerical stability).
"""
vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max(
vocab_parallel_logits
)
return vocab_parallel_logits, logits_max
@jit_fuser
def calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
logits_max: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Calculates the predicted logits for the tokens.
"""
(target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = (
VocabParallelCrossEntropy.calculate_predicted_logits(
vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
)
)
predicted_logits_sum_exp_logits = torch.cat((predicted_logits, sum_exp_logits))
return target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits
@jit_fuser
def calculate_cross_entropy_loss(
exp_logits: torch.Tensor, predicted_logits_sum_exp_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculates the final cross entropy loss for the tokens.
"""
split_val = predicted_logits_sum_exp_logits.size()[0] // 2
predicted_logits, sum_exp_logits = torch.split(predicted_logits_sum_exp_logits, split_val)
exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss(
exp_logits, predicted_logits, sum_exp_logits
)
return exp_logits, loss
@jit_fuser
def calculate_gradients(
softmax: torch.Tensor,
grad_output: torch.Tensor,
target_mask: torch.Tensor,
masked_target_1d: torch.Tensor,
) -> torch.Tensor:
"""
Calculates the gradient of the logits, scaled by the incoming cross-entropy loss gradient.
"""
(grad_2d, arange_1d, softmax_update, grad_input) = (
VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
)
grad_input = VocabParallelCrossEntropy.calculate_gradients(
grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
)
grad_input = grad_input.to(torch.bfloat16)
return grad_input
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target, tp_group):
"""
Forward implementation for the cross entropy loss.
"""
vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits)
torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=tp_group)
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
vocab_start_index, vocab_end_index = get_vocab_range(
partition_vocab_size, tp_group.rank(), tp_group.size()
)
(target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits) = (
calculate_predicted_logits(
vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
)
)
# All reduce is needed to get the chunks from other GPUs.
# In the fused case, tensors are batched to invoke a single
# AllReduce call
torch.distributed.all_reduce(
predicted_logits_sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group
)
exp_logits, loss = calculate_cross_entropy_loss(exp_logits, predicted_logits_sum_exp_logits)
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
"""
Backward implementation for the cross entropy loss.
"""
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
grad_input = calculate_gradients(softmax, grad_output, target_mask, masked_target_1d)
return grad_input, None, None
def fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group):
"""
Performs cross entropy loss when logits are split across tensor parallel ranks
Args:
vocab_parallel_logits: logits split across tensor parallel ranks
dimension is [sequence_length, batch_size, vocab_size / num_tensor_parallel_ranks]
target: correct vocab ids of dimension [sequence_length, micro_batch_size]
tp_group: the tensor parallel group over which to all reduce
"""
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, tp_group)
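# --- Illustrative reference (not part of the original module) -----------------
# A minimal sketch of the unfused, single-rank computation the fused loss above
# should reduce to: per-token cross entropy with no reduction. It does not
# exercise the tensor-parallel code path; the helper name and shapes are
# hypothetical and for illustration only.
def _example_unfused_cross_entropy_reference(logits: torch.Tensor, target: torch.Tensor):
    # logits: [sequence_length, batch_size, vocab_size]; target: [sequence_length, batch_size]
    s, b, v = logits.shape
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(s * b, v), target.reshape(s * b), reduction="none"
    )
    return loss.view(s, b)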
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib
import inspect
import numbers
import torch
from torch import Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from megatron.core.transformer import TransformerConfig
from megatron.core.utils import make_viewless_tensor
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
HAVE_PERSIST_LAYER_NORM = True
except ImportError:
HAVE_PERSIST_LAYER_NORM = False
try:
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
HAVE_FUSED_LAYER_NORM = True
except ImportError:
HAVE_FUSED_LAYER_NORM = False
class FusedLayerNorm(torch.nn.Module):
"""Layer Norm, fused into a single CUDA kernel.
Args:
hidden_size (int): Transformer hidden dimension.
eps (float): Epsilon added to denominator, for numerical stability.
persist_layer_norm (bool): Use persistent fused layer norm kernel.
This kernel supports only a fixed set of hidden sizes; check
persist_ln_hidden_sizes to see whether your hidden size is supported.
zero_centered_gamma (bool): Adjust LayerNorm weights such that they are
centered around zero. This improves numerical stability.
config (TransformerConfig): Transformer config. Include to match custom
layer norm interfaces.
normalization (str): Normalization type, used for Transformer Engine.
Must equal 'LayerNorm' here.
"""
def __init__(
self,
config: TransformerConfig,
hidden_size: int,
eps: float = 1e-5,
persist_layer_norm: bool = True,
zero_centered_gamma: bool = False,
normalization: str = "LayerNorm", # included to match TE interface
):
super().__init__()
self.config = config
self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma
assert (
self.config.normalization == "LayerNorm"
), f'({self.config.normalization}) is not supported in FusedLayerNorm'
# List of hidden sizes supported by the persistent layer norm kernel
# If the hidden size is not supported, fall back to the non-persistent
# kernel.
persist_ln_hidden_sizes = [
1024,
1536,
2048,
2304,
3072,
3840,
4096,
5120,
6144,
8192,
10240,
12288,
12800,
15360,
16384,
18432,
20480,
24576,
25600,
30720,
32768,
40960,
49152,
65536,
]
persist_layer_norm = self.config.persist_layer_norm
if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
persist_layer_norm = False
if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM:
# TODO: Add pytorch only layer norm
raise ValueError('Apex must be installed to use FusedLayerNorm.')
if isinstance(hidden_size, numbers.Integral):
hidden_size = (hidden_size,)
self.hidden_size = torch.Size(hidden_size)
self.eps = eps
# Parameters need to be initialized with torch.empty rather than torch.Tensor for correct device placement with nemo2.
self.weight = Parameter(torch.empty(*hidden_size))
self.bias = Parameter(torch.empty(*hidden_size))
self.reset_parameters()
self.persist_layer_norm = persist_layer_norm
self.sequence_parallel = self.config.sequence_parallel
# set sequence parallelism flag on weight and bias parameters
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
def reset_parameters(self):
if self.zero_centered_gamma:
init.zeros_(self.weight)
init.zeros_(self.bias)
else:
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input: Tensor) -> Tensor:
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
if self.persist_layer_norm:
if 'memory_efficient' in inspect.getfullargspec(FastLayerNormFN.forward).args:
output = FastLayerNormFN.apply(
input, weight, self.bias, self.eps, self.config.memory_efficient_layer_norm
)
else:
output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
# a populated '_base' field). This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
output = make_viewless_tensor(
inp=output, requires_grad=input.requires_grad, keep_graph=True
)
else:
if (
'memory_efficient'
in inspect.getfullargspec(FusedLayerNormAffineFunction.forward).args
):
return FusedLayerNormAffineFunction.apply(
input,
weight,
self.bias,
self.hidden_size,
self.eps,
self.config.memory_efficient_layer_norm,
)
else:
return FusedLayerNormAffineFunction.apply(
input, weight, self.bias, self.hidden_size, self.eps
)
return output
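# --- Illustrative reference (not part of the original module) -----------------
# A minimal sketch of the unfused computation FusedLayerNorm.forward performs,
# equivalent up to numerics, including the zero_centered_gamma shift. The
# helper name is hypothetical and for illustration only.
def _example_unfused_layer_norm_reference(
    input: Tensor, weight: Tensor, bias: Tensor, eps: float, zero_centered_gamma: bool
) -> Tensor:
    effective_weight = weight + 1 if zero_centered_gamma else weight
    return torch.nn.functional.layer_norm(
        input, effective_weight.shape, effective_weight, bias, eps
    )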
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
import torch
import torch.nn as nn
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.utils import get_default_causal_mask
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence:
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence:
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
class ScaledSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following two operations in sequence:
1. Scale the tensor.
2. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Args:
input_in_fp16: flag to indicate if the input is in fp16 data format.
input_in_bf16: flag to indicate if the input is in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate whether the user wants to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor]):
"""Forward pass of softmax with masked input.
If attn_mask_type is causal, the causal mask is generated internally and None can be passed.
A user-defined mask is only needed when attn_mask_type is not causal.
"""
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion  # user wants to fuse
and self.input_in_float16  # input must be fp16 or bf16
and 16 < sk <= 4096  # sk must be 16 ~ 4096
and sq % 4 == 0  # sq must be a multiple of 4
and sk % 4 == 0  # sk must be a multiple of 4
and attn_batches % 4 == 0  # np * b must be a multiple of 4
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
assert sq == sk, "causal mask is only for self attention"
# view input as a 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
if mask is not None:
return ScaledMaskedSoftmax.apply(input, mask, scale)
else:
return ScaledSoftmax.apply(input, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
# Generate causal mask if not given
sq, sk = input.size(2), input.size(3)
if self.attn_mask_type == AttnMaskType.causal and mask is None and sq > 1:
# If sq == 1 then either KV cache is used or one-element context is passed
# so keeping mask=None in this case; subsequent code should handle it
assert sq == sk, "causal mask is only for self attention"
mask = get_default_causal_mask(sq)
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
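# --- Illustrative usage (not part of the original module) ---------------------
# A minimal sketch of constructing FusedScaleMaskSoftmax with the fused kernel
# disabled, so the pure-PyTorch fallback path runs on CPU. The sizes, the
# masked_fill-based mask_func, and the helper name are hypothetical.
def _example_fused_scale_mask_softmax_usage():
    softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False,
        input_in_bf16=False,
        attn_mask_type=AttnMaskType.padding,
        scaled_masked_softmax_fusion=False,  # force forward_torch_softmax
        mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
        softmax_in_fp32=True,
        scale=None,
    )
    scores = torch.randn(2, 4, 16, 16)                  # [b, np, sq, sk]
    mask = torch.zeros(2, 1, 16, 16, dtype=torch.bool)  # True = masked out
    probs = softmax(scores, mask)
    assert probs.shape == scores.shape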
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright 2025 The vLLM authors.
#
# This code was adapted from https://github.com/vllm-project/vllm/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
from megatron.core.inference.inference_request import InferenceRequest
STOP_ITERATION = Exception()
class AsyncStream:
"""
Class for encapsulating an asynchronous stream of InferenceRequest outputs.
Adapted from https://github.com/vllm-project/vllm/blob/eb881ed006ca458b052905e33f0d16dbb428063a/vllm/v1/engine/async_stream.py # pylint: disable=line-too-long
"""
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self._request_id = request_id
self._cancel = cancel
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
self._loop = asyncio.get_running_loop()
def put(self, item: Union[InferenceRequest, Exception]) -> None:
"""Adds a new value to the stream"""
if not self._finished:
self._loop.call_soon_threadsafe(self._queue.put_nowait, item)
def finish(self, exception: Optional[Union[BaseException, Type[BaseException]]] = None) -> None:
"""Completes the stream by adding a sentinel value"""
if not self._finished:
self._finished = True
self._loop.call_soon_threadsafe(
self._queue.put_nowait,
exception if self._is_raisable(exception) else STOP_ITERATION,
)
@property
def finished(self) -> bool:
"""Whether the stream has finished"""
return self._finished
async def generator(self) -> AsyncGenerator[InferenceRequest, None]:
"""Creates an AsyncGenerator over the stream queue"""
try:
while True:
result = await self._queue.get()
if self._is_raisable(result):
if result == STOP_ITERATION:
return
raise result
yield result
except GeneratorExit:
self._cancel()
raise asyncio.CancelledError from None
@staticmethod
def _is_raisable(value: Any):
return isinstance(value, BaseException) or (
isinstance(value, type) and issubclass(value, BaseException)
)
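# --- Illustrative usage (not part of the original module) ---------------------
# A minimal sketch of driving an AsyncStream from a single asyncio task. The
# request id, the no-op cancel callback, and the use of plain strings in place
# of InferenceRequest objects are hypothetical and for illustration only.
async def _example_async_stream_usage():
    stream = AsyncStream(request_id="req-0", cancel=lambda request_id: None)
    stream.put("partial output 1")  # a producer would push InferenceRequest updates
    stream.put("partial output 2")
    stream.finish()                 # enqueues the sentinel; generator() ends cleanly
    return [item async for item in stream.generator()]
# Could be driven with: asyncio.run(_example_async_stream_usage())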
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.sampling_params import ( # noqa: F401 # pylint: disable=unused-import
SamplingParams as CommonInferenceParams,
)