Commit 27ddce40 authored by wenjh

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
......@@ -4,7 +4,7 @@
"""Context Parallelism."""
import os
from typing import List, Union
from typing import List, Union, Tuple
import torch
import transformer_engine_torch as tex
......@@ -358,7 +358,7 @@ def get_fa_args(
max_seqlen_q,
max_seqlen_kv,
*[None]
* 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
* 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale
]
return [
*[None]
......@@ -366,7 +366,7 @@ def get_fa_args(
max_seqlen_q,
max_seqlen_kv,
*[None]
* 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
* 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale
]
if qkv_format == "thd":
return [
......@@ -829,6 +829,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
if not enable_mla:
                            # If MHA, split the KV into k_part and v_part.
                            # Otherwise (MLA), k_part and v_part have already been split.
k_part = (
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
)
v_part = (
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
)
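                            # Packing note (from the indexing above): for "bshd"/"sbhd" the KV
                            # input carries a size-2 axis ([..., 2, h, d]) where index 0 selects K
                            # and index 1 selects V; for "thd" the KV input is indexed as a
                            # (k, v) pair along its first dimension.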
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
......@@ -838,19 +851,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
k_part,
v_part,
*fa_forward_args_thd,
causal=True,
**fa_forward_kwargs,
......@@ -985,6 +989,22 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
if enable_mla:
k_part = k_part.contiguous()
v_part = v_part.contiguous()
else:
                            # If MHA, split the KV into k_part and v_part.
                            # Otherwise (MLA), k_part and v_part have already been split.
k_part = (
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
)
v_part = (
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
)
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
......@@ -1001,19 +1021,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
k_part,
v_part,
*fa_forward_args_thd,
causal=False,
**fa_forward_kwargs,
......@@ -1144,6 +1155,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
if not enable_mla:
                            # If MHA, split the KV into k_part and v_part.
                            # Otherwise (MLA), k_part and v_part have already been split.
k_part = (
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
)
v_part = (
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
)
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
......@@ -1160,19 +1184,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
k_part,
v_part,
*fa_forward_args_thd,
causal=False,
**fa_forward_kwargs,
......@@ -1269,6 +1284,19 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
if not enable_mla:
                            # If MHA, split the KV into k_part and v_part.
                            # Otherwise (MLA), k_part and v_part have already been split.
k_part = (
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
)
v_part = (
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
)
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
......@@ -1278,19 +1306,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd(
q,
(
kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][0]
),
(
kv_inputs[i % 2][..., 1, :, :]
if qkv_format in ["bshd", "sbhd"]
else kv_inputs[i % 2][1]
),
k_part,
v_part,
*fa_forward_args_thd,
causal=False,
**fa_forward_kwargs,
......@@ -1865,7 +1884,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dv_ = dv_._data
else:
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
if ctx.enable_mla:
dk_ = torch.empty_like(k_part)
dv_ = torch.empty_like(v_part)
else:
k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
dkv_ = torch.empty_like(kv_)
dk_ = (
dkv_[..., 0, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[0]
)
dv_ = (
dkv_[..., 1, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[1]
)
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
......@@ -1875,16 +1914,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_kv=ctx.max_seqlen_kv,
dq=dq_,
dk=(
dkv_[..., 0, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[0]
),
dv=(
dkv_[..., 1, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[1]
),
dk=dk_,
dv=dv_,
)
if ctx.use_flash_attn_3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
......@@ -1895,12 +1926,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = 0
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd(
dout_,
q_,
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
k_part,
v_part,
out_,
softmax_lse,
*fa_backward_args_thd,
......@@ -2016,7 +2046,29 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dv_ = dv_._data
else:
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
if ctx.enable_mla:
k_part = k_part.contiguous()
v_part = v_part.contiguous()
dk_ = torch.empty_like(k_part)
dv_ = torch.empty_like(v_part)
else:
k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
dkv_ = torch.empty_like(kv_)
dk_ = (
dkv_[..., 0, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[0]
)
dv_ = (
dkv_[..., 1, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[1]
)
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
......@@ -2026,16 +2078,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_kv=ctx.max_seqlen_kv // 2,
dq=dq_,
dk=(
dkv_[..., 0, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[0]
),
dv=(
dkv_[..., 1, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[1]
),
dk=dk_,
dv=dv_,
)
if ctx.use_flash_attn_3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
......@@ -2046,12 +2090,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd(
dout_,
q_,
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
k_part,
v_part,
out_,
softmax_lse,
*fa_backward_args_thd,
......@@ -2160,7 +2203,27 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dv_ = dv_._data
else:
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
if ctx.enable_mla:
dk_ = torch.empty_like(k_part)
dv_ = torch.empty_like(v_part)
else:
k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
dkv_ = torch.empty_like(kv_)
dk_ = (
dkv_[..., 0, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[0]
)
dv_ = (
dkv_[..., 1, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[1]
)
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
......@@ -2170,16 +2233,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=ctx.max_seqlen_q // 2,
max_seqlen_kv=ctx.max_seqlen_kv,
dq=dq_,
dk=(
dkv_[..., 0, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[0]
),
dv=(
dkv_[..., 1, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[1]
),
dk=dk_,
dv=dv_,
)
if ctx.use_flash_attn_3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
......@@ -2190,12 +2245,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd(
dout_,
q_,
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0],
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
k_part,
v_part,
out_,
softmax_lse_,
*fa_backward_args_thd,
......@@ -2267,7 +2321,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dq_ = torch.empty_like(q)
dkv_ = torch.empty_like(kv)
if ctx.enable_mla:
dk_ = torch.empty_like(k_part)
dv_ = torch.empty_like(v_part)
else:
k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
dkv_ = torch.empty_like(kv)
dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0]
dv_ = dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1]
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
......@@ -2277,8 +2339,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_kv=ctx.max_seqlen_kv,
dq=dq_,
dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
dk=dk_,
dv=dv_,
)
if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_kwargs["window_size"] = (-1, -1)
......@@ -2287,12 +2349,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd(
dout,
q,
kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0],
kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
k_part,
v_part,
out,
softmax_lse,
*fa_backward_args_thd,
......@@ -3927,3 +3988,212 @@ def attn_forward_func_with_cp(
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
return out
def pad_thd_sequences_for_cp(
input_ids: torch.Tensor,
labels: torch.Tensor,
cu_seqlens: torch.Tensor,
divisibility_factor: int,
padding_token_id: int = 0,
padding_label_id: int = -100,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Pads sequences to be divisible by the divisibility factor.
Args:
input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences
labels: Tensor of shape (1, N) or (N,) containing labels for each token
cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths
divisibility_factor: Each sequence length must be divisible by this factor
padding_token_id: Token ID to use for padding (default: 0)
padding_label_id: Label ID to use for padding (default: -100)
Returns:
Tuple of:
- input_ids_padded: Padded input_ids tensor
- labels_padded: Padded labels tensor
- cu_seqlens_padded: Cumulative sequence lengths accounting for padding
"""
# Flatten input_ids and labels if needed
if input_ids.dim() == 2:
input_ids = input_ids.squeeze(0)
if labels.dim() == 2:
labels = labels.squeeze(0)
# Compute the sequence lengths from cu_seqlens
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
# List: amount of padding needed for each sequence (make length a multiple of divisibility_factor)
padding_amounts = [
((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor
- l.item()
for l in seqlens
]
# Extract sequences and labels for each batch item
batch_sequences = [
input_ids[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
]
batch_labels = [
labels[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])
]
# Pad sequences and labels to required length
input_ids_padded = torch.cat(
[
(
torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)])
if pad > 0
else seq
)
for seq, pad in zip(batch_sequences, padding_amounts)
]
)
labels_padded = torch.cat(
[
(
torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)])
if pad > 0
else seq
)
for seq, pad in zip(batch_labels, padding_amounts)
]
)
# Compute cumulative padded sequence lengths, starting from 0
padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype)
cu_seqlens_padded = torch.cumsum(
torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), padded_lengths]), dim=0
)
return input_ids_padded, labels_padded, cu_seqlens_padded
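# Illustrative usage sketch (values are hypothetical): with 2 context-parallel ranks, each
# sequence must be divisible by 2 * cp_size = 4 (see get_batch_on_this_cp_rank below), so packed
# lengths [5, 3] are padded to [8, 4].
#
#   input_ids = torch.tensor([11, 12, 13, 14, 15, 21, 22, 23])
#   labels = torch.tensor([11, 12, 13, 14, 15, 21, 22, 23])
#   cu_seqlens = torch.tensor([0, 5, 8])
#   ids_p, labels_p, cu_p = pad_thd_sequences_for_cp(
#       input_ids, labels, cu_seqlens, divisibility_factor=4
#   )
#   # ids_p and labels_p have length 12; cu_p == tensor([0, 8, 12])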
def generate_positional_ids_for_cp(
cu_seqlens: torch.Tensor,
divisibility_factor: int,
dtype: torch.dtype = torch.long,
) -> torch.Tensor:
"""Generate positional IDs for sequences padded to be divisible by divisibility_factor.
Args:
cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths
divisibility_factor: Each sequence length must be divisible by this factor
dtype: Data type for the generated positional IDs (default: torch.long)
Returns:
Generated positional_ids tensor where each sequence starts from 0 and continues through padding
"""
# Compute the sequence lengths from cu_seqlens
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
# List: amount of padding needed for each sequence
padding_amounts = [
((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor
- l.item()
for l in seqlens
]
# Generate positional IDs for each padded sequence (each starts from 0)
padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype)
positional_ids = torch.cat(
[torch.arange(0, int(length), dtype=dtype) for length in padded_lengths]
)
return positional_ids
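# Illustrative sketch, continuing the example above: for padded lengths [8, 4] the position ids
# restart at 0 for every sequence and run through the padding.
#
#   pos = generate_positional_ids_for_cp(torch.tensor([0, 5, 8]), divisibility_factor=4)
#   # pos == tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3])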
def get_batch_on_this_cp_rank(
cu_seqlens_padded: torch.Tensor,
input_ids_padded: torch.Tensor,
labels_padded: torch.Tensor,
position_ids_padded: torch.Tensor,
cp_group: torch.distributed.ProcessGroup = None,
qvk_format: str = "thd",
):
"""Slice batch input along sequence dimension into multiple chunks for THD format.
This function is inteded for use in self attention. It will not work for cross attention because
it does not handle the case where the sequence length of the query and key are different.
Which are parallelized across GPUs in a context parallel group.
This version works with variable-length sequences using cumulative sequence lengths.
"""
if qvk_format not in ["thd", "bshd", "sbhd"]:
raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
if qvk_format == "thd":
# Get context parallel size and rank
cp_size = torch.distributed.get_world_size(group=cp_group)
if cp_size > 1:
cp_rank = torch.distributed.get_rank(group=cp_group)
# Calculate the chunk sizes for each sequence
total_slices_of_any_sequence = 2 * cp_size
slice_sizes = (
cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]
) // total_slices_of_any_sequence
# Process each tensor directly instead of using keys_to_change loop
def process_tensor(val):
if val is None:
return val
# Determine which dimension is the sequence dimension
# Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
if isinstance(cu_seqlens_padded[-1], torch.Tensor):
seq_len_val = cu_seqlens_padded[-1].item()
else:
seq_len_val = cu_seqlens_padded[-1]
# Handle 1D tensors (like position_ids that don't have batch dimension)
if val.ndim == 1:
if val.shape[0] == seq_len_val:
current_seq_dim = 0
else:
raise ValueError(
"1D tensor shape doesn't match expected sequence length. Make sure the"
" inputs are in THD format and padded correctly."
)
elif val.ndim >= 2:
if val.shape[1] == seq_len_val:
current_seq_dim = 1
elif val.shape[0] == seq_len_val:
current_seq_dim = 0
else:
raise ValueError(
"Make sure the inputs are in THD format and padded correctly."
)
else:
raise ValueError("Tensor must be at least 1D")
# On this particular rank, for each sequence, get two slices, one from the beginning
# and one from the end.
cp_rank_slices = []
for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]):
# 1st segment
cp_rank_slices.append(
torch.arange(
seq_start + (cp_rank * slice_size),
seq_start + ((cp_rank + 1) * slice_size),
device=val.device,
)
)
# 2nd segment
cp_rank_slices.append(
torch.arange(
seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
device=val.device,
)
)
return val.index_select(current_seq_dim, torch.cat(cp_rank_slices))
# Process each tensor directly
input_ids_padded = process_tensor(input_ids_padded)
labels_padded = process_tensor(labels_padded)
position_ids_padded = process_tensor(position_ids_padded)
else:
raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!")
return input_ids_padded, labels_padded, position_ids_padded
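# Illustrative sketch of the slicing above (hypothetical 2-rank cp_group): a padded sequence of
# length 8 is cut into 2 * cp_size = 4 slices of size 2, and rank r keeps slice r plus slice
# (2 * cp_size - 1 - r), i.e. rank 0 gets tokens [0, 1, 6, 7] and rank 1 gets tokens [2, 3, 4, 5].
#
#   ids_r, labels_r, pos_r = get_batch_on_this_cp_rank(
#       cu_p, ids_p, labels_p, pos, cp_group=cp_group  # cp_group: hypothetical 2-rank process group
#   )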
......@@ -126,10 +126,10 @@ class FlashAttentionUtils:
# Please follow these instructions to install FA3
v3_installation_steps = """\
(1) git clone https://github.com/Dao-AILab/flash-attention.git
(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
(2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install
(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(4) mkdir -p $python_path/flash_attn_3
(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py"""
(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py"""
v3_warning_printed = False
@staticmethod
......@@ -438,8 +438,10 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None:
if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0):
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12")
# Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version
# until the cuDNN bug is resolved
if device_compute_capability == (8, 9):
logger.debug("Disabling FusedAttention for KV caching for sm89")
use_fused_attention = False
if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism")
......@@ -482,11 +484,10 @@ def get_attention_backend(
# Filter: Head dimension
if not IS_HIP_EXTENSION:
if head_dim_qk != head_dim_v:
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
logger.debug("Disabling FlashAttention as it does not support MLA.")
use_flash_attention = False
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
use_flash_attention_2 = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and qkv_layout_group != "hd_hd_hd":
logger.debug(
......@@ -518,10 +519,41 @@ def get_attention_backend(
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention_2 = False
if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for head_dim > 128")
use_flash_attention_3 = False
if use_flash_attention_3:
def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
if head_dim_qk > 256 or num_heads % num_gqa_groups != 0:
return False
if head_dim_qk != head_dim_v:
cond1 = 128 < head_dim_qk <= 192
cond2 = 96 < head_dim_v <= 128
cond3 = head_dim_qk <= 64 and head_dim_v <= 512
if not ((cond1 and cond2) or cond3):
return False
if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16):
return False
return True
if not _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype):
if FlashAttentionUtils.v3_is_installed:
logger.debug(
"Disabling FlashAttention 3 due to unsupported num_heads, num_gqa_groups, "
"head_dim_qk, head_dim_v or qkv_dtype. "
"Supported: head_dim_qk <= 256, and num_heads %% num_gqa_groups = 0, and "
"if head_dim_qk is different from head_dim_v, then "
"(head_dim_qk must in (128, 192] and head_dim_v in (96, 128]) or "
"(head_dim_qk <= 64 and head_dim_v <= 512), and "
"if head_dim_qk is different from head_dim_v and head_dim_v > 256, then "
"qkv_dtype requires fp16 and bf16 data type. "
"Found: num_heads = %s, num_gqa_groups = %s, "
"head_dim_qk = %s, head_dim_v = %s and qkv_dtype = %s.",
num_heads,
num_gqa_groups,
head_dim_qk,
head_dim_v,
qkv_dtype,
)
use_flash_attention_3 = False
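        # For example (illustrative): MLA-style head dims head_dim_qk=192, head_dim_v=128 pass the
        # (128, 192] x (96, 128] branch of _is_fa3_supported above, while equal head dims only need
        # head_dim_qk <= 256 and num_heads divisible by num_gqa_groups.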
# Filter: QKV layout
if qkv_format == "thd":
......@@ -838,7 +870,7 @@ def get_attention_backend(
# flash-attn >=2.4.1 | yes
# FusedAttention |
# sub-backend 0 | yes
# sub-backend 1 | workspace optimization path and sm90+: yes;
# sub-backend 1 | workspace optimization path and sm90: yes;
# | otherwise: no
# sub-backend 2 | no
# UnfusedDotProductAttention | yes
......@@ -854,8 +886,9 @@ def get_attention_backend(
use_flash_attention_2 = False
if use_fused_attention and deterministic:
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons")
logger.debug("Disabling FusedAttention for determinism reasons with FP8")
use_fused_attention = False
fused_attention_backend = None
if (
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and is_training
......@@ -865,8 +898,13 @@ def get_attention_backend(
or cudnn_version < (8, 9, 5)
)
):
logger.debug("Disabling FusedAttention for determinism reasons")
logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
use_fused_attention = False
fused_attention_backend = None
if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0):
logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
use_fused_attention = False
fused_attention_backend = None
# use_flash_attention may have been set above
use_flash_attention_2 = use_flash_attention and use_flash_attention_2
......
......@@ -215,6 +215,17 @@ class InferenceParams:
device=torch.cuda.current_device(),
)
        # This internal buffer holds the running length of each
        # unfinished sequence in the batch and is updated in the `pre_step()`
        # method. One use of this buffer is applying RoPE to the q and k tensors
        # during inference by slicing the RoPE embeddings according to the
        # current sequence-length window.
self.pre_step_seqlens = torch.zeros(
self.max_batch_size,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
def reset(self):
"""Reset InferenceParams state"""
self.sequences = OrderedDict()
......@@ -266,6 +277,15 @@ class InferenceParams:
for k, v in self.sequences.items():
self.sequences_pre_step[k] = v - step_dict[k]
pre_step_seqlens_temp = torch.Tensor(list(self.sequences_pre_step.values())).to(
dtype=torch.int32, device="cpu"
)
# Copy the pre-step seqlens to the device in CUDA Graphs safe manner.
self.pre_step_seqlens[: len(pre_step_seqlens_temp)].copy_(
pre_step_seqlens_temp, non_blocking=False
)
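        # Updating the preallocated `pre_step_seqlens` buffer in place keeps its device pointer
        # stable, which is what makes this step safe to capture and replay with CUDA Graphs;
        # allocating a fresh device tensor on every step would not be.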
seqlens_q = list(step_dict.values())
cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)]
cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size)
......@@ -280,9 +300,7 @@ class InferenceParams:
def get_seqlens_pre_step(self):
"""Get cached sequence lengths before the stepping"""
return torch.Tensor(list(self.sequences_pre_step.values())).to(
dtype=torch.int32, device="cpu"
)
return self.pre_step_seqlens
def convert_paged_to_nonpaged(self, layer_number: int):
"""
......@@ -458,14 +476,14 @@ class NonPagedKVCacheManager(KVCacheManager):
finished_seqs = self.sequences.keys() - unfinished_seqs
unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs]
self.batch_indices.copy_(
self.batch_indices.data[:].copy_(
torch.Tensor(
(
unfinished_indices
+ finished_indices
+ list(range(prev_batch_size, self.max_batch_size))
)
).to(dtype=torch.int32, device="cpu")
)
)
# Advance unfinished sequences
......
......@@ -889,23 +889,11 @@ class MultiheadAttention(torch.nn.Module):
q_pos_emb, k_pos_emb = rotary_pos_emb
# adjust key and value for inference
if inference_params is not None:
if self.qkv_format == "sbhd":
sequence_length = key_layer.size(0)
elif self.qkv_format == "bshd":
sequence_length = key_layer.size(1)
else:
raise ValueError(
f"qkv_format={self.qkv_format} not supported for KV caching and RoPE."
)
sequence_start = inference_params.get_seqlens_pre_step()
# sequence_start = inference_params.seqlens[0]
sequence_end = sequence_start + sequence_length
q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
            # Applying RoPE during inference requires the start position of each
            # sequence at the current iteration.
sequence_start_positions = (
inference_params.get_seqlens_pre_step() if inference_params is not None else None
)
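            # For example (illustrative): if two sequences have already generated 5 and 3 tokens,
            # get_seqlens_pre_step() yields tensor([5, 3]) and RoPE for the newly generated tokens
            # is applied starting at positions 5 and 3 respectively.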
if pad_between_seqs:
rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded
......@@ -922,6 +910,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=rotary_pos_cu_seq_lens_q,
cp_size=self.cp_size,
cp_rank=self.cp_rank,
start_positions=sequence_start_positions,
interleaved=self.rotary_pos_interleaved,
)
key_layer = apply_rotary_pos_emb(
......@@ -932,6 +921,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=rotary_pos_cu_seq_lens_kv,
cp_size=self.cp_size,
cp_rank=self.cp_rank,
start_positions=sequence_start_positions,
interleaved=self.rotary_pos_interleaved,
)
......
......@@ -5,14 +5,14 @@
"""
Rotary Position Embedding implementation of different types along with helper functions
"""
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, List
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"]
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]
class RotaryPositionEmbedding(torch.nn.Module):
......@@ -170,6 +170,86 @@ class FusedRoPEFunc(torch.autograd.Function):
return grad_input, None, None, None, None, None, None, None
class FusedQKVRoPEFunc(torch.autograd.Function):
"""
Function for FusedQKVRoPE
    This implementation accepts a combined QKV tensor in `bshd` or `sbhd` format, along with
    separate Q and K RoPE frequency tensors of shape (s, 1, 1, d). It produces 3 outputs:
    Q and K with RoPE applied, and V passed through unchanged.
"""
@staticmethod
def forward(
ctx,
qkv: torch.Tensor,
q_freqs: torch.Tensor,
k_freqs: torch.Tensor,
qkv_split_arg_list: List[int],
start_positions: Union[torch.Tensor, None] = None,
tensor_format: str = "sbhd",
interleaved: bool = False,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""Fused RoPE forward."""
if q_freqs.dtype != torch.float32:
q_freqs = q_freqs.float()
if k_freqs.dtype != torch.float32:
k_freqs = k_freqs.float()
assert tensor_format in (
"sbhd",
"bshd",
), f"Unsupported tensor_format: {tensor_format}."
assert qkv.is_contiguous(), "QKV Tensor should be contiguous."
assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous."
assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous."
output = tex.fused_qkv_rope_forward(
qkv,
q_freqs,
k_freqs,
start_positions,
qkv_split_arg_list,
QKVFormat[tensor_format],
interleaved,
cp_size,
cp_rank,
)
ctx.save_for_backward(q_freqs, k_freqs)
ctx.tensor_format = tensor_format
ctx.qkv_split_arg_list = qkv_split_arg_list
ctx.cp_size = cp_size
ctx.cp_rank = cp_rank
ctx.interleaved = interleaved
return output
@staticmethod
def backward(
ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused RoPE backward."""
q_freqs, k_freqs = ctx.saved_tensors
grad_output_q = grad_output_q.contiguous()
grad_output_k = grad_output_k.contiguous()
grad_output_v = grad_output_v.contiguous()
grad_input = tex.fused_qkv_rope_backward(
grad_output_q,
grad_output_k,
grad_output_v,
q_freqs,
k_freqs,
ctx.qkv_split_arg_list,
QKVFormat[ctx.tensor_format],
ctx.interleaved,
ctx.cp_size,
ctx.cp_rank,
)
return grad_input, None, None, None, None, None, None, None, None
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
"""Change sign so the last dimension becomes [-odd, +even]
......@@ -393,3 +473,82 @@ def apply_rotary_pos_emb(
tensor_format,
interleaved=interleaved,
)
def apply_fused_qkv_rotary_pos_emb(
qkv: torch.Tensor,
q_freqs: torch.Tensor,
k_freqs: torch.Tensor,
qkv_split_arg_list: List[int],
tensor_format: str = "sbhd",
start_positions: Union[torch.Tensor, None] = None,
interleaved: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None, # pylint: disable=unused-argument
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""
Apply rotary positional embedding tensor to the input qkv tensor.
Support matrix:
Fused:
Training:
qkv_formats: "bshd", "sbhd"
context parallel: yes
start_positions: no
interleaving: yes
Inference:
qkv_formats: "bshd", "sbhd"
context parallelism: no
start_positions: yes
interleaving: yes
Parameters
----------
qkv: torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
rotary positional embedding will be applied. This tensor has q, k, v concatenated
along the last dimension.
q_freqs: torch.Tensor
Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
k_freqs: torch.Tensor
Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
    qkv_split_arg_list: List[int]
        List of integers specifying how the qkv tensor is split along its last dimension.
        The list must have 3 elements: the sizes of the q, k, and v portions, respectively.
        The sum of the three elements must equal the last dimension of the qkv tensor.
start_positions: torch.Tensor, default = None.
Tokens in a sequence `i` should be applied with position encoding offset by
`start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
of shape `[seq, bs, ...]`.
interleaved: bool, default = False
Whether to use interleaved rotary position embedding.
cp_size: int, default = 1.
Context parallel world size.
cp_rank: int, default = 0.
Context parallel rank.
"""
# `start_positions` is only supported for `cp_size=1` and inference.
assert not (
cp_size > 1 and start_positions is not None
), """start_positions != None with CP SIZE > 1 is not supported!"""
assert tensor_format != "thd", "'thd' tensor_format not supported currently."
return FusedQKVRoPEFunc.apply(
qkv,
q_freqs,
k_freqs,
qkv_split_arg_list,
start_positions,
tensor_format,
interleaved,
cp_size,
cp_rank,
)
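# Minimal usage sketch (tensor names and shapes are assumptions for illustration, beyond what the
# docstring above requires): an MHA projection where q, k and v all use head_dim = 128, packed
# along the last dimension of a single sbhd tensor.
#
#   # qkv: [s, b, h, 384], q_freqs / k_freqs: [s, 1, 1, 128] in float32
#   q, k, v = apply_fused_qkv_rotary_pos_emb(
#       qkv,
#       q_freqs,
#       k_freqs,
#       qkv_split_arg_list=[128, 128, 128],
#       tensor_format="sbhd",
#   )
#   # q, k: [s, b, h, 128] with RoPE applied; v: [s, b, h, 128] unchanged.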
......@@ -559,17 +559,23 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
buffer_idx = 0
double_buffer_idx = group_to_reload % 2
main_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.h2d_stream):
# move back tensors
for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label
if group_id == group_to_reload:
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
else:
reload_buffer = None
if isinstance(state, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state[1], device=torch.cuda.current_device()
)
recovered_tensor = SynchronizedGroupOffloadHandler.reload(
state, True, reload_buffer
)
......@@ -578,14 +584,18 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
reload_buffer = None
if isinstance(state_tuple, tuple):
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
with torch.cuda.stream(main_stream):
reload_buffer = torch.empty_like(
state_tuple[1], device=torch.cuda.current_device()
)
tensor_list.append(
SynchronizedGroupOffloadHandler.reload(
state_tuple,
......
......@@ -190,38 +190,49 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = st
* Activations
**************************************************************************************************/
/* GELU and variants*/
py::object gelu(const at::Tensor &input, py::handle quantizer);
py::object relu(const at::Tensor &input, py::handle quantizer);
py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object geglu(const at::Tensor &input, py::handle quantizer);
py::object qgeglu(const at::Tensor &input, py::handle quantizer);
py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object reglu(const at::Tensor &input, py::handle quantizer);
py::object qgelu(const at::Tensor &input, py::handle quantizer);
py::object swiglu(const at::Tensor &input, py::handle quantizer);
py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object qgelu(const at::Tensor &input, py::handle quantizer);
py::object qgeglu(const at::Tensor &input, py::handle quantizer);
py::object srelu(const at::Tensor &input, py::handle quantizer);
py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
/* ReLU and variants*/
py::object relu(const at::Tensor &input, py::handle quantizer);
py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object reglu(const at::Tensor &input, py::handle quantizer);
py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object srelu(const at::Tensor &input, py::handle quantizer);
py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object sreglu(const at::Tensor &input, py::handle quantizer);
py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
/* Silu and variants*/
py::object silu(const at::Tensor &input, py::handle quantizer);
py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object swiglu(const at::Tensor &input, py::handle quantizer);
py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
......@@ -244,6 +255,11 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &rsigma, const at::Tensor &gamma,
const int sm_margin, const bool zero_centered_gamma);
std::vector<py::object> rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &add, const at::Tensor &rsigma,
const at::Tensor &gamma, const int sm_margin,
const bool zero_centered_gamma);
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
py::object ln_out, py::handle quantizer, DType otype,
const int sm_margin, const bool zero_centered_gamma);
......@@ -285,6 +301,17 @@ std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Te
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer);
/***************************************************************************************************
* Dropout
**************************************************************************************************/
std::vector<py::object> dropout_fwd(const py::handle &input, const float dropout_probability,
std::optional<at::Tensor> out = std::nullopt);
py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
const float dropout_probability,
std::optional<at::Tensor> grad_input = std::nullopt);
/***************************************************************************************************
* Softmax
**************************************************************************************************/
......@@ -349,6 +376,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank);
at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
const at::Tensor &k_freqs,
const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank);
/***************************************************************************************************
* Miscellaneous
**************************************************************************************************/
......
......@@ -101,6 +101,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i
return grad_input_py;
}
/* GELU and variants*/
py::object gelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_gelu>(input, quantizer);
}
......@@ -109,30 +110,39 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua
return dactivation_helper<nvte_dgelu>(grad, input, quantizer);
}
py::object relu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_relu>(input, quantizer);
py::object geglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_geglu>(input, quantizer, 2);
}
py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_drelu>(grad, input, quantizer);
py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
}
py::object geglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_geglu>(input, quantizer, 2);
py::object qgelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgelu>(input, quantizer);
}
py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgeglu>(input, quantizer, 2);
py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
}
py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgeglu>(input, quantizer, 2);
}
py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dqgeglu>(grad, input, quantizer);
}
/* ReLU and variants*/
py::object relu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_relu>(input, quantizer);
}
py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_drelu>(grad, input, quantizer);
}
py::object reglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_reglu>(input, quantizer, 2);
}
......@@ -141,28 +151,36 @@ py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle qu
return dactivation_helper<nvte_dreglu>(grad, input, quantizer);
}
py::object swiglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_swiglu>(input, quantizer, 2);
py::object srelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_srelu>(input, quantizer);
}
py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
}
py::object qgelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_qgelu>(input, quantizer);
py::object sreglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_sreglu>(input, quantizer, 2);
}
py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsreglu>(grad, input, quantizer);
}
py::object srelu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_srelu>(input, quantizer);
/* Silu and variants*/
py::object silu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_silu>(input, quantizer);
}
py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dsilu>(grad, input, quantizer);
}
py::object swiglu(const at::Tensor& input, py::handle quantizer) {
return activation_helper<nvte_swiglu>(input, quantizer, 2);
}
py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
}
} // namespace transformer_engine::pytorch
......@@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor
auto start_positions_cu = TensorWrapper(); // empty start_positions tensor
if (start_positions) {
start_positions_cu = makeTransformerEngineTensor(start_positions.value());
TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor");
}
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
......@@ -102,6 +103,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
return output;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
const int cp_rank) {
TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output
auto act_options = at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device());
auto q_out_size = qkv_input.sizes().vec();
q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1];
q_out_size[3] = qkv_split_arg_list[1];
auto q_out = at::empty(q_out_size, act_options);
auto k_out_size = qkv_input.sizes().vec();
k_out_size[3] = qkv_split_arg_list[1];
auto k_out = at::empty(k_out_size, act_options);
auto v_out_size = qkv_input.sizes().vec();
v_out_size[3] = qkv_split_arg_list[2];
auto v_out = at::empty(v_out_size, act_options);
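  // Illustrative shape example (hypothetical values): for a GQA packing with
  // qkv_split_arg_list = {512, 128, 128} and input [s, b, h, 768], q_out becomes
  // [s, b, h * 512 / 128, 128] = [s, b, 4 * h, 128], while k_out and v_out keep
  // h heads with last dimensions 128 and 128.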
auto qkv_cu = makeTransformerEngineTensor(qkv_input);
auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
auto q_out_cu = makeTransformerEngineTensor(q_out);
auto k_out_cu = makeTransformerEngineTensor(k_out);
auto v_out_cu = makeTransformerEngineTensor(v_out);
  auto start_positions_cu = TensorWrapper();  // empty start_positions tensor
if (start_positions) {
start_positions_cu = makeTransformerEngineTensor(start_positions.value());
}
TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor");
TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous");
const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1);
const int b = is_sbhd ? qkv_input.size(1) : qkv_input.size(0);
const int h = qkv_input.size(2);
const int d = qkv_split_arg_list[2];
const int d2 = q_freqs.size(3);
nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(),
start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(),
v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h,
d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1],
qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream());
return std::make_tuple(q_out, k_out, v_out);
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
......@@ -193,4 +253,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
return input_grads;
}
at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
const at::Tensor &k_freqs,
const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank) {
auto act_options =
at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device());
auto qkv_grad_size = q_grad_out.sizes().vec();
auto total_hd =
(q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3);
auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2];
qkv_grad_size[2] = total_hd / total_d;
qkv_grad_size[3] = total_d;
auto qkv_grad_input = at::empty(qkv_grad_size, act_options);
const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
const int s = is_sbhd ? q_grad_out.size(0) : q_grad_out.size(1);
const int b = is_sbhd ? q_grad_out.size(1) : q_grad_out.size(0);
const int h = qkv_grad_input.size(2);
const int d = qkv_split_arg_list[2];
const int d2 = q_freqs.size(3);
auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out);
auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out);
auto v_grad_out_cu = makeTransformerEngineTensor(v_grad_out);
auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input);
nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(),
q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0],
qkv_split_arg_list[1], qkv_split_arg_list[2],
at::cuda::getCurrentCUDAStream());
return qkv_grad_input;
}
} // namespace transformer_engine::pytorch
......@@ -205,11 +205,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
size_t offset, at::ScalarType dtype) -> at::Tensor {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
// in the case where full buffer is empty because local rank receives no tokens for all the experts
// then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
// but in the case where some experts receive tokens, some not, we want to leverage from_blob
// as much as possible to avoid CPU overhead
if (buffer->data_ptr<uint8_t>() == nullptr) {
bool is_empty_shape = product(shape) == 0;
if (buffer->data_ptr<uint8_t>() == nullptr || is_empty_shape) {
return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
}
return at::from_blob(
......@@ -359,11 +356,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
size_t offset, at::ScalarType dtype) -> at::Tensor {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
// in the case where full buffer is empty because local rank receives no tokens for all the experts
// then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob
// but in the case where some experts receive tokens, some not, we want to leverage from_blob
// as much as possible to avoid CPU overhead
if (buffer->data_ptr<uint8_t>() == nullptr) {
bool is_empty_shape = product(shape) == 0;
if (buffer->data_ptr<uint8_t>() == nullptr || is_empty_shape) {
return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
}
return at::from_blob(
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/dropout.h"
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <pybind.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include "../common.h"
#include "../extensions.h"
#include "../pybind.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace pytorch {
std::vector<py::object> dropout_fwd(const py::handle &input, float dropout_probability,
std::optional<at::Tensor> out) {
using namespace transformer_engine::pytorch::detail;
// Input tensor
const TensorWrapper input_nvte = makeTransformerEngineTensor(input, py::none());
// Allocate output tensor if needed
if (!out) {
at::ScalarType dtype = GetATenDType(input_nvte.dtype());
if (dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2) {
dtype = input.attr("dtype").cast<at::ScalarType>();
}
const auto shape_uint64 = convertShape(input_nvte.shape());
const std::vector<int64_t> shape_int64(shape_uint64.begin(), shape_uint64.end());
const auto opts = at::TensorOptions().dtype(dtype).device(torch::kCUDA);
out = at::empty(shape_int64, opts);
}
TensorWrapper out_nvte = makeTransformerEngineTensor(*out);
// Mask tensor
auto mask_pyt = allocateTorchTensor(input_nvte.numel() / 8, DType::kByte);
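  // numel/8 bytes, i.e. the mask appears to pack one keep/drop bit per input element.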
auto mask_nvte = makeTransformerEngineTensor(mask_pyt);
// RNG state tensor
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
constexpr int64_t rng_elts_per_thread = 4;
philox_args = gen->philox_cuda_state(rng_elts_per_thread);
}
auto rng_state_pyt = allocateTorchTensor(2, DType::kInt64);
NVTE_SCOPED_GIL_RELEASE({
nvte_extract_seed_and_offset(
reinterpret_cast<int64_t *>(rng_state_pyt.data_ptr()), philox_args.captured_,
philox_args.seed_.ptr, philox_args.seed_.val, philox_args.offset_.ptr,
philox_args.offset_.val, philox_args.offset_intragraph_, at::cuda::getCurrentCUDAStream());
});
auto rng_state_nvte = makeTransformerEngineTensor(rng_state_pyt);
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_dropout_fwd(input_nvte.data(), out_nvte.data(), mask_nvte.data(), rng_state_nvte.data(),
dropout_probability, at::cuda::getCurrentCUDAStream());
});
return {py::cast(std::move(*out)), py::cast(mask_pyt)};
}
py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
const float dropout_probability, std::optional<at::Tensor> grad_input) {
const auto grad_output_nvte = makeTransformerEngineTensor(grad_output);
const auto mask_nvte = makeTransformerEngineTensor(mask);
if (!grad_input) {
grad_input = at::empty_like(grad_output);
}
auto grad_input_nvte = makeTransformerEngineTensor(*grad_input);
NVTE_SCOPED_GIL_RELEASE({
nvte_dropout_bwd(grad_output_nvte.data(), mask_nvte.data(), grad_input_nvte.data(),
dropout_probability, at::cuda::getCurrentCUDAStream());
});
return py::cast(std::move(*grad_input));
}
} // namespace pytorch
} // namespace transformer_engine
......@@ -95,6 +95,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
bool use_split_accumulator, CommOverlapCore* comm_overlap,
std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
bool bulk_overlap, float alpha, std::optional<float> beta) {
using namespace transformer_engine::pytorch::detail;
// Input tensors
NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
......@@ -125,10 +127,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
"into D tensor. Beta has nothing to be applied to.");
}
DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
// Output tensor
TensorWrapper D_tensor;
if (D.is_none()) {
DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer);
} else {
D_tensor = makeTransformerEngineTensor(D, quantizer);
......@@ -141,12 +143,35 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
}
}
// maintain unquantized tensor in case we need unfused quantization support.
TensorWrapper unquantized_D_tensor;
py::object unquantized_out;
// Unfused quantization is needed in the following cases
// 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that)
// 2. Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling,
// GEMM Output needs to be in BF16, to allow for unfused quantization)
bool unfused_quantization_needed = !quantizer.is_none();
if (low_precision) {
// At the moment, the only use case for the fused GEMM path is a
// delayed-scaling quantizer with per-tensor scaling inputs.
bool is_per_tensor_scaling_input = IsFloat8Tensor(A.ptr()) || IsFloat8Tensor(B.ptr());
if (IsFloat8Quantizers(quantizer.ptr()) && is_per_tensor_scaling_input)
unfused_quantization_needed = false;
}
if (unfused_quantization_needed) {
NoneQuantizer q{none};
std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype);
}
TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor;
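// out_tensor is the GEMM destination: the high-precision staging tensor when unfused
// quantization is needed (it is quantized into D_tensor after the GEMM), otherwise D_tensor itself.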
// Bias tensor
TensorWrapper bias_tensor;
MaybeTensor bias_grad = std::nullopt;
if (bias.has_value()) {
if (grad) {
auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
auto opts =
torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA);
bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
bias_tensor = makeTransformerEngineTensor(*bias_grad);
} else {
......@@ -159,7 +184,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
// Activation input tensor
MaybeTensor pre_gelu_out = std::nullopt;
DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
DType gelu_type = low_precision ? bias_type : out_tensor.dtype();
if (gelu) {
if (!grad) {
auto dtype = GetATenDType(gelu_type);
......@@ -212,7 +237,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
// Direct GEMM call to the correct overlap
if (bulk_overlap) {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
te_pre_gelu_out, te_workspace, grad, accumulate,
use_split_accumulator, comm_type.value(), extra_output_tensor,
main_stream);
......@@ -220,14 +245,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
} else if (comm_type.value() == CommOverlapType::AG) {
if (comm_overlap->is_atomic_gemm()) {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator,
extra_output_tensor, main_stream);
});
} else {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator, extra_output_tensor,
main_stream);
......@@ -236,14 +261,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
} else {
if (comm_overlap->is_atomic_gemm()) {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator,
extra_output_tensor, main_stream);
});
} else {
NVTE_SCOPED_GIL_RELEASE({
comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
bias_tensor, te_pre_gelu_out, te_workspace, grad,
accumulate, use_split_accumulator, extra_output_tensor,
main_stream);
......@@ -253,15 +278,15 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(),
nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(),
bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
te_workspace.data(), alpha, *beta, use_split_accumulator,
num_math_sms, main_stream);
});
}
} else {
if (D_tensor.numel() != 0 && !accumulate) {
D_tensor.zero_(main_stream);
if (out_tensor.numel() != 0 && !accumulate) {
out_tensor.zero_(main_stream);
}
if (bias.has_value()) {
if (bias->numel() != 0 && grad) {
......@@ -269,7 +294,11 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
}
}
}
if (unfused_quantization_needed) {
// Quantize the output
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
my_quantizer->quantize(unquantized_D_tensor, D_tensor);
}
// Pack outputs
std::vector<py::object> out;
out.emplace_back(std::move(D));
......@@ -449,24 +478,12 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
}
// For now, we only have multi-stream cublas backend.
const char *NVTE_USE_HIPBLASLT_GROUPEDGEMM = std::getenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM");
if(NVTE_USE_HIPBLASLT_GROUPEDGEMM != nullptr && NVTE_USE_HIPBLASLT_GROUPEDGEMM[0] == '1'){
NVTE_SCOPED_GIL_RELEASE({
nvte_grouped_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
});
} else {
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
});
}
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(),
transa, transb, grad, te_workspace_vector.data(), accumulate,
use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream());
});
return bias;
}
......
......@@ -110,7 +110,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
TensorWrapper unquantized_out_cu;
py::object unquantized_out;
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
std::tie(unquantized_out_cu, unquantized_out) =
my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
......@@ -145,7 +146,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
} else {
......@@ -199,6 +201,52 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
return {py::cast(dx), py::cast(dgamma)};
}
std::vector<py::object> rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x,
const at::Tensor &add, const at::Tensor &rsigma,
const at::Tensor &gamma, const int sm_margin,
const bool zero_centered_gamma) {
const auto &dz_ = dz.contiguous();
const auto &x_ = x.contiguous();
const auto &add_ = add.contiguous();
const auto &rsigma_ = rsigma.contiguous();
const auto &gamma_ = gamma.contiguous();
auto dx = at::empty_like(x_);
auto dgamma = at::empty_like(gamma_);
TensorWrapper workspace;
auto dz_cu = makeTransformerEngineTensor(dz_);
auto x_cu = makeTransformerEngineTensor(x_);
auto add_cu = makeTransformerEngineTensor(add_);
auto rsigma_cu = makeTransformerEngineTensor(rsigma_);
auto gamma_cu = makeTransformerEngineTensor(gamma_);
auto dx_cu = makeTransformerEngineTensor(dx);
auto dgamma_cu = makeTransformerEngineTensor(dgamma);
// The first call only queries the required workspace configuration (shape and dtype);
// no computation is performed yet.
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(),
gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
// Allocate space for the workspace tensor.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Actual call to bwd kernel.
NVTE_SCOPED_GIL_RELEASE({
nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(),
gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream());
});
return {py::cast(dx), py::cast(dgamma)};
}
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
py::object out, py::handle quantizer, DType out_dtype,
const int sm_margin, const bool zero_centered_gamma) {
......@@ -244,7 +292,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
TensorWrapper unquantized_out_cu;
py::object unquantized_out;
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
std::tie(unquantized_out_cu, unquantized_out) =
my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
......@@ -279,7 +328,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
} else {
......
......@@ -113,38 +113,53 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
/* GELU and variants*/
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
py::arg("quantizer"));
m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"),
py::arg("quantizer"));
/* ReLU and variants */
m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
/* SwiGLU and variants */
m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
py::arg("quantizer"));
/* Backward of GELU and variants */
m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
/* Backward of ReLU and variants */
m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
m.def("dsreglu", transformer_engine::pytorch::dsreglu, "Backward of Squared ReGLU",
py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
/* Backward of SiLU and variants */
m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
/* DBias + DAct fusions*/
m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize",
......@@ -202,6 +217,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add,
"Fused backward of RMSNorm + add");
m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize,
"Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
......@@ -281,6 +298,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
"Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward,
"Fused Apply QKV RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward,
"Fused Apply QKV RoPE BWD", py::call_guard<py::gil_scoped_release>());
// fused router
m.def("fused_topk_with_score_function_fwd",
......@@ -308,6 +329,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"),
py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
// Dropout
m.def("dropout_fwd", transformer_engine::pytorch::dropout_fwd, "Dropout forward with 8-bit RNG",
py::arg("input"), py::arg("dropout_probability"), py::arg("out") = std::nullopt);
m.def("dropout_bwd", transformer_engine::pytorch::dropout_bwd, "Dropout backward with 8-bit RNG",
py::arg("grad_output"), py::arg("mask"), py::arg("dropout_probability"),
py::arg("grad_input") = std::nullopt);
// Misc
m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
"Get cublasLt version", py::call_guard<py::gil_scoped_release>());
......
......@@ -96,16 +96,6 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
......@@ -318,17 +308,6 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
// quantize output and its transpose
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
......@@ -518,7 +497,8 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te
// Compute amax
if (compute_amax) {
NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); });
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
}
// Perform amax reduction if needed
......@@ -561,20 +541,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
}
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
// Change the rowwise and columnwise_data to the configured dtype.
// May be a switch between E5M2 and E4M3.
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype) const {
......@@ -916,18 +883,7 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize
this->dtype = quantizer.attr("dtype").cast<DType>();
}
void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {}
std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
......
......@@ -4,6 +4,8 @@
"""Functions for CUDA Graphs support in FP8"""
from collections.abc import Iterable
import contextlib
import gc
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch
......@@ -58,6 +60,25 @@ def graph_pool_handle():
return _graph_pool_handle()
@contextlib.contextmanager
def _graph_context_wrapper(*args, **kwargs):
"""Wrapper around `torch.cuda.graph`.
This wrapper is a temporary workaround for a PyTorch bug:
automatic garbage collection can destroy a graph while another
graph is being captured, resulting in a CUDA error. See
https://github.com/pytorch/pytorch/pull/161037.
"""
gc_is_enabled = gc.isenabled()
if gc_is_enabled:
gc.disable()
try:
with torch.cuda.graph(*args, **kwargs):
yield
finally:
# Re-enable garbage collection even if graph capture raises.
if gc_is_enabled:
gc.enable()
def _make_graphed_callables(
callables: SingleOrTuple[Callable],
sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
......@@ -445,7 +466,7 @@ def _make_graphed_callables(
args = sample_args[per_callable_fwd_idx]
kwargs = sample_kwargs[per_callable_fwd_idx]
fwd_graph = fwd_graphs[per_callable_fwd_idx]
with torch.cuda.graph(fwd_graph, pool=mempool):
with _graph_context_wrapper(fwd_graph, pool=mempool):
outputs = func(*args, **kwargs)
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
......@@ -483,7 +504,7 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with torch.cuda.graph(bwd_graph, pool=mempool):
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
......@@ -548,7 +569,7 @@ def _make_graphed_callables(
per_callable_output_unflatten_spec = []
graph_id = 0
for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
with torch.cuda.graph(fwd_graph, pool=mempool):
with _graph_context_wrapper(fwd_graph, pool=mempool):
outputs = func(*args, **kwargs)
graph_callables[graph_id] = func
graph_id += 1
......@@ -570,7 +591,7 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with torch.cuda.graph(bwd_graph, pool=mempool):
with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
......@@ -829,7 +850,7 @@ def make_graphed_callables(
num_warmup_iters: int = 3,
allow_unused_input: bool = False,
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
fp8_enabled: bool = False,
fp8_enabled: SingleOrTuple[bool] = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
......@@ -875,8 +896,9 @@ def make_graphed_callables(
FP8-related parameters
----------------------
fp8_enabled: bool, default = `True`
whether or not to enable fp8
fp8_enabled: (tuple of) bool, default = `False`
whether or not to enable fp8.
If tuple, the length must match the number of modules.
fp8_calibrating: bool, default = `False`
calibration mode allows collecting statistics such as amax and scale
data of fp8 tensors even when executing without fp8 enabled. This is
......@@ -898,17 +920,25 @@ def make_graphed_callables(
"""
set_capture_start()
if fp8_enabled and fp8_recipe is None:
fp8_recipe = get_default_fp8_recipe()
elif not fp8_enabled:
fp8_recipe = None
# Handle single module.
just_one_callable = False
if not isinstance(modules, tuple):
just_one_callable = True
modules = (modules,)
if not isinstance(fp8_enabled, tuple):
assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools"
fp8_enabled = (fp8_enabled,) * len(modules)
else:
assert len(fp8_enabled) == len(
modules
), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})"
if any(fp8_enabled) and fp8_recipe is None:
fp8_recipe = get_default_fp8_recipe()
elif not any(fp8_enabled):
fp8_recipe = None
module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled))
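# Illustrative example (module and argument names are placeholders):
#     make_graphed_callables((mlp1, mlp2), (args1, args2), fp8_enabled=(True, False))
# captures mlp1 under FP8 autocast and mlp2 without it.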
# Store FP8 tensors to reset later.
saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)
......@@ -923,15 +953,15 @@ def make_graphed_callables(
old_call_funcs[block_cls] = block_cls.__call__
# Wrap the original call function of the module class.
def call_func(*args, **kwargs):
def call_func(self, *args, **kwargs):
with fp8_autocast(
enabled=fp8_enabled,
enabled=module_uses_fp8.get(id(self), False),
calibrating=fp8_calibrating,
fp8_recipe=fp8_recipe,
fp8_group=fp8_group,
_graph=True,
):
outputs = old_call_funcs[block_cls](*args, **kwargs)
outputs = old_call_funcs[block_cls](self, *args, **kwargs)
return outputs
block_cls.__call__ = call_func
......
......@@ -12,4 +12,4 @@ from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
from .fp8_padding import Fp8Padding
from .fp8_unpadding import Fp8Unpadding
from .base import initialize_ub, destroy_ub
from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode
......@@ -8,6 +8,7 @@ import math
import os
import pickle
import warnings
from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
......@@ -50,7 +51,7 @@ from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTe
from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["initialize_ub", "destroy_ub"]
__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
......@@ -66,6 +67,15 @@ _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []
class UserBufferQuantizationMode(Enum):
"""
UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer.
"""
NONE = "none"
FP8 = "fp8"
def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB if using Hopper, 4 MiB for all other architectures."""
# Env var to control the padding for blasLt
......@@ -134,8 +144,9 @@ def initialize_ub(
shape: list,
tp_size: int,
use_fp8: bool = False,
quantization_modes: List[UserBufferQuantizationMode] = None,
dtype: torch.dtype = torch.bfloat16,
ub_cfgs: Optional[dict] = None,
ub_cfgs: Optional[Union[dict, List[dict]]] = None,
bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
r"""
......@@ -151,7 +162,11 @@ def initialize_ub(
tp_size : int
number of GPUs in the tensor-parallel process group
use_fp8 : bool = False
allocate the communication buffer for FP8 GEMM inputs/outputs
allocate the communication buffer for FP8 GEMM inputs/outputs.
DEPRECATED: Please use `quantization_modes` instead.
quantization_modes : List[UserBufferQuantizationMode] = None
if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list.
falls back to the legacy `use_fp8` parameter if `None` is provided.
dtype : torch.dtype = torch.bfloat16
non-FP8 data type of the communication buffer when `use_fp8 = False`
ub_cfgs: dict = None
......@@ -175,6 +190,7 @@ def initialize_ub(
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_wgrad"]`.
a list may be provided to specify different overlap configurations for the different quantization settings in `quantization_modes`.
bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are
......@@ -191,6 +207,28 @@ def initialize_ub(
+ "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
)
if not quantization_modes:
warnings.warn(
"Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes"
" instead.",
DeprecationWarning,
)
quantization_modes = [
UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
]
else:
assert isinstance(quantization_modes, list), "quantization_modes must be a list"
assert all(
isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes
), "quantization_modes must be a list of UserBufferQuantizationMode"
if isinstance(ub_cfgs, dict) or ub_cfgs is None:
ub_cfgs = [ub_cfgs] * len(quantization_modes)
else:
assert len(ub_cfgs) == len(
quantization_modes
), "Number of ub_cfgs settings must match number of quantization configurations"
global _ub_communicators
assert _ub_communicators is None, "UB communicators are already initialized."
_ub_communicators = {}
......@@ -349,6 +387,7 @@ def initialize_ub(
def add_ub(
name: str,
quantization_mode: UserBufferQuantizationMode,
method: str,
is_reduce_scatter: bool,
num_sm: int = 16,
......@@ -367,7 +406,9 @@ def initialize_ub(
warnings.warn(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
)
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
assert (
quantization_mode == UserBufferQuantizationMode.FP8
), "Atomic GEMM overlap supported only for FP8 GEMM."
if method in ("bulk", "external"):
warnings.warn(
f"At {name}, atoimic GEMM not is supported for a bulk overlap."
......@@ -407,7 +448,11 @@ def initialize_ub(
f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
)
buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
buffer_dtype = (
torch.uint8
if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
else dtype
)
if method == "ring_exchange":
ub_obj = tex.CommOverlapP2P(
shape, # Communication buffer shape
......@@ -441,42 +486,52 @@ def initialize_ub(
comm_priority=comm_priority,
rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
)
_ub_communicators[name] = ub_obj
if ub_cfgs is not None:
for name in dgrad_reduce_scatter_overlap:
if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
wgrad_name = name.replace("dgrad", "wgrad")
assert wgrad_name not in ub_cfgs
layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name)
layers_reduce_scatter_overlap.append(name)
methods["bulk"].remove(name)
new_method = ub_cfgs[name]["method"]
methods[new_method].append(name)
for name in (
methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
):
if name in remove_ag_gemm_dgrad:
continue
ub_cfg = get_default_config(name)
if ub_cfgs is not None and name in ub_cfgs:
fp8_buf = (name in layers_all_gather_overlap) or (
ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
)
ub_cfg.update(ub_cfgs[name])
ub_cfg["fp8_buf"] = fp8_buf
add_ub(name, **ub_cfg)
_ub_communicators[(name, quantization_mode)] = ub_obj
for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs):
if user_ub_cfg is not None:
for name in dgrad_reduce_scatter_overlap:
if (
name in user_ub_cfg
and "method" in user_ub_cfg[name]
and user_ub_cfg[name]["method"] != "bulk"
):
wgrad_name = name.replace("dgrad", "wgrad")
assert wgrad_name not in user_ub_cfg
layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name)
layers_reduce_scatter_overlap.append(name)
methods["bulk"].remove(name)
new_method = user_ub_cfg[name]["method"]
methods[new_method].append(name)
for name in (
methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
):
if name in remove_ag_gemm_dgrad:
continue
ub_cfg = get_default_config(name)
if user_ub_cfg is not None and name in user_ub_cfg:
fp8_buf = (name in layers_all_gather_overlap) or (
user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
)
ub_cfg.update(user_ub_cfg[name])
ub_cfg["fp8_buf"] = fp8_buf
add_ub(name, quantization_mode, **ub_cfg)
def get_ub(name: str):
def get_ub(name: str, use_fp8: bool):
"""Get userbuffer communicator corresponding to give key."""
# For now, use the `use_fp8` boolean input since it matches the current design of the modules,
# favouring simplicity until the correct design becomes clear.
# This is mainly an internal API, so we don't need to worry about future changes.
key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE)
assert _ub_communicators is not None, "UB manager is not initialized."
assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered."
# assert name in _ub_communicators, f"UB for {name} is not registered."
if name in remove_ag_gemm_dgrad:
return None
return _ub_communicators[name]
return _ub_communicators[key]
def destroy_ub():
......@@ -1472,8 +1527,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
(wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation:
weight_tensor = noop_cat(self._get_weight_tensors())
if weight_tensor.grad is None:
weight_tensor.grad = wgrad.to(weight_tensor.dtype)
weight_tensor.grad = wgrad.to(weight_tensor.dtype)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
if bias_tensor.grad is None:
......
......@@ -859,8 +859,7 @@ class GroupedLinear(TransformerEngineBaseModule):
bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fuse_wgrad_accumulation:
for i in range(self.num_gemms):
if weight_params[i].grad is None:
weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
if self.use_bias:
for i in range(self.num_gemms):
if bias_params[i].grad is None:
......@@ -917,7 +916,7 @@ class GroupedLinear(TransformerEngineBaseModule):
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
if not self.fp8 and not self.fp8_calibration:
return [None] * self.num_gemms
weight_quantizers = [
self.quantizers["scaling_fwd"][
......
......@@ -181,10 +181,10 @@ class _LayerNormLinear(torch.autograd.Function):
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
)
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_obj = get_ub(ub_name + "_fprop", fp8)
ub_type = tex.CommOverlapType.AG
# Configure quantizer for norm output
......@@ -361,8 +361,11 @@ class _LayerNormLinear(torch.autograd.Function):
# Deallocate GEMM input tensor if no longer needed
if not weight.requires_grad and not return_layernorm_output:
ln_out = ln_out_total = None
clear_tensor_data(ln_out, ln_out_total)
ln_out = ln_out_total = None
elif with_input_all_gather and not return_layernorm_output_gathered:
clear_tensor_data(ln_out_total)
ln_out_total = None
# ------------------------------------------------------
# Prepare output tensor
......@@ -608,23 +611,23 @@ class _LayerNormLinear(torch.autograd.Function):
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS
else:
if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8)
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ub_type_wgrad = tex.CommOverlapType.RS
# --------------------------------------------------
......@@ -802,7 +805,7 @@ class _LayerNormLinear(torch.autograd.Function):
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8)
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -927,9 +930,19 @@ class _LayerNormLinear(torch.autograd.Function):
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor if permitted
if not ctx.return_layernorm_output:
# Deallocate input tensors if permitted
if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
# Input tensors have not been exposed externally
clear_tensor_data(ln_out)
elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered:
# Non-gathered input has not been exposed externally
clear_tensor_data(ln_out)
if ctx.ln_out_needs_gather:
# Gathered input is internal
clear_tensor_data(ln_out_total)
if ctx.parallel_mode == "row" and ctx.sequence_parallel:
# Gathered grad output tensor is internal
clear_tensor_data(grad_output)
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
......@@ -1209,7 +1222,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_bias = return_bias
self.apply_bias = self.use_bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.return_layernorm_output_gathered = (
return_layernorm_output_gathered if return_layernorm_output else False
)
self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type
......@@ -1532,10 +1547,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
is_first_microbatch = False
if self.ub_overlap_rs_fprop:
if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
if get_ub(
self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
if get_ub(
self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled()
).is_fp8_ubuf():
fp8_grad = True
with torch.cuda.device(
......@@ -1803,7 +1822,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
if not self.fp8 and not self.fp8_calibration:
return [None]
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
......