Commit ab3e5a92 authored by yuguo

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
@@ -101,7 +101,7 @@ if __name__ == "__main__":
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        install_requires=["jax", "flax>=0.7.1"],
-       tests_require=["numpy", "praxis"],
+       tests_require=["numpy"],
    )
    if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
        shutil.rmtree(common_headers_dir)
@@ -89,7 +89,11 @@ def generate_pspec(logical_axis_names):
    Convert logical axes to PartitionSpec
    """
    rules = get_sharding_map_logic_axis_to_mesh_axis()
-   mesh_axis_names = [rules[name] for name in logical_axis_names]
+   # mesh_axis_names = [rules[name] for name in logical_axis_names]
+   mesh_axis_names = []
+   for name in logical_axis_names:
+       axis_name = rules[name] if name in rules else None
+       mesh_axis_names.append(axis_name)
    pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
    return pspec
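The practical effect of the new loop is a graceful fallback: a logical axis with no entry in the sharding rules now maps to `None` (left unsharded) instead of raising a `KeyError`. A minimal sketch, with an illustrative rule map standing in for the real output of `get_sharding_map_logic_axis_to_mesh_axis()`:

```python
import jax

# Hypothetical rule map; axis names here are placeholders, not the library's real ones.
rules = {"nvte_batch": "data", "nvte_hidden_tp": "model"}

logical_axis_names = ("nvte_batch", "nvte_seqlen", "nvte_hidden_tp")
mesh_axis_names = [rules[name] if name in rules else None for name in logical_axis_names]
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
# -> PartitionSpec('data', None, 'model'); the unknown axis is simply replicated.
```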
@@ -112,7 +116,7 @@ def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: t
    """
    A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
    """
-   if logical_axis_names is None:
+   if not logical_axis_names:
        return x

    assert len(x.shape) == len(logical_axis_names)

@@ -315,3 +319,25 @@ class ShardingType(Enum):
    TP_ROW = (MajorShardingType.TP, "tp_row")
    DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
    DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
+def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
+    """Get logical axes for non-contracting dimensions.
+
+    Args:
+        ndim: Number of dimensions in the tensor.
+        logical_axes: Tuple of logical axes for each dimension.
+        contracting_dims: Set of dimensions that are being contracted.
+
+    Returns:
+        Tuple of logical axes for non-contracting dimensions.
+    """
+    if not logical_axes:
+        logical_axes = (None,) * ndim
+    elif len(logical_axes) < ndim:
+        logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
+    assert len(logical_axes) == ndim
+
+    non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
+    non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
+    return non_contracting_logical_axes
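A quick, self-contained illustration of the helper added above (axis names are placeholders, and the import path assumes it is exported from `transformer_engine.jax.sharding`, where this hunk lives):

```python
from transformer_engine.jax.sharding import get_non_contracting_logical_axes

# Keep the logical axes of the dimensions that survive a contraction.
axes = get_non_contracting_logical_axes(
    ndim=3,
    logical_axes=("batch", "seqlen", "hidden"),
    contracting_dims={2},
)
assert axes == ("batch", "seqlen")

# Short tuples are padded with None before filtering.
assert get_non_contracting_logical_axes(3, ("batch",), {0}) == (None, None)
```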
@@ -20,6 +20,7 @@ import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

import transformer_engine_torch as tex
+from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.utils import (
    get_cudnn_version,
    nvtx_range_pop,

@@ -81,6 +82,7 @@ import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
+from .cpu_offload import mark_activation_offload

# Setup Attention Logging

@@ -618,7 +620,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
        rank = get_distributed_rank(cp_group)
        send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
-       batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+       batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))

        causal = "causal" in attn_mask_type
        padding = "padding" in attn_mask_type
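With this change, batched point-to-point communication in the context-parallel attention path is controlled only by the `NVTE_BATCH_MHA_P2P_COMM` environment variable; the previous implicit enablement for `cp_size == 2` is dropped. The flag is read exactly as in the line above:

```python
import os

# Opt in explicitly; unset or "0" now means unbatched P2P even when cp_size == 2.
batch_p2p_comm = bool(int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")))
```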
@@ -1566,7 +1568,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
        rank = get_distributed_rank(ctx.cp_group)
        send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
        recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
-       batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
+       batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))

        q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
            restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)

@@ -4323,10 +4325,9 @@ class FlashAttention(torch.nn.Module):
        from .cpu_offload import CPUOffloadEnabled

        if CPUOffloadEnabled:
-           tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
-           for tensor in tensor_list:
-               if tensor is not None:
-                   tensor.activation_offloading = True
+           mark_activation_offload(
+               query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
+           )

        with self.attention_dropout_ctx():
            # | API | use cases

@@ -4728,12 +4729,9 @@ class FusedAttnFunc(torch.autograd.Function):
            else:
                tensor_list = [q, k, v, out_save]
-           tensor_list.extend(aux_ctx_tensors)

            qkv_layout = "sbhd_sbhd_sbhd"
-           for tensor in tensor_list:
-               if tensor is not None:
-                   tensor.activation_offloading = True
+           mark_activation_offload(*tensor_list)
+           mark_activation_offload(*aux_ctx_tensors)

        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
@@ -6482,6 +6480,8 @@ class MultiheadAttention(torch.nn.Module):
               equal length. Please note that these formats do not reflect how
               tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
               For that, please use `get_qkv_layout` to gain the layout information.
+    name: str, default = `None`
+         name of the module, currently used for debugging purposes.

    Parallelism parameters
    ----------------------

@@ -6560,6 +6560,7 @@ class MultiheadAttention(torch.nn.Module):
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        qkv_format: str = "sbhd",
+       name: str = None,
    ) -> None:
        super().__init__()

@@ -6611,6 +6612,8 @@ class MultiheadAttention(torch.nn.Module):
        self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

+       self.name = name
+
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
            "tp_group": tp_group,

@@ -6651,6 +6654,7 @@ class MultiheadAttention(torch.nn.Module):
                ub_overlap_ag=ub_overlap_ag,
                normalization=normalization,
                ub_name="qkv",
+               name=name + ".layernorm_linear_qkv" if name is not None else None,
                **common_gemm_kwargs,
            )
        else:

@@ -6662,6 +6666,7 @@ class MultiheadAttention(torch.nn.Module):
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=parameters_split,
+               name=name + ".linear_qkv" if name is not None else None,
                **common_gemm_kwargs,
            )
        elif self.attention_type == "cross":

@@ -6683,6 +6688,7 @@ class MultiheadAttention(torch.nn.Module):
                ub_overlap_ag=ub_overlap_ag,
                normalization=normalization,
                ub_name="qkv",
+               name=name + ".layernorm_linear_q" if name is not None else None,
                **common_gemm_kwargs,
            )
        else:

@@ -6693,6 +6699,7 @@ class MultiheadAttention(torch.nn.Module):
                bias=bias,
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
+               name=name + ".linear_q" if name is not None else None,
                **common_gemm_kwargs,
            )
            self.key_value = Linear(

@@ -6703,6 +6710,7 @@ class MultiheadAttention(torch.nn.Module):
                return_bias=False,
                parallel_mode=qkv_parallel_mode,
                parameters_split=("key", "value") if not fuse_qkv_params else None,
+               name=name + ".linear_kv" if name is not None else None,
                **common_gemm_kwargs,
            )

@@ -6732,6 +6740,7 @@ class MultiheadAttention(torch.nn.Module):
            ub_overlap_rs=ub_overlap_rs,
            ub_overlap_ag=ub_overlap_ag,
            ub_name="proj",
+           name=name + ".proj" if name is not None else None,
            **common_gemm_kwargs,
        )

@@ -6922,6 +6931,9 @@ class MultiheadAttention(torch.nn.Module):
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

+       if TEDebugState.debug_enabled:
+           TransformerEngineBaseModule._validate_name(self)
+
        # =================================================
        # Pre-allocate memory for key-value cache for inference
        # =================================================
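A hedged sketch of how the new `name` argument might be used from caller code; the constructor arguments other than `name` are illustrative, and the derived child names simply follow the suffixing shown in the hunks above:

```python
import transformer_engine.pytorch as te

mha = te.MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    name="layer_0.self_attention",  # used only for debug identification
)
# Submodules receive derived names such as "layer_0.self_attention.proj"
# when a name is provided; with name=None nothing changes.
```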
@@ -24,6 +24,12 @@ TE_DType = {
    torch.bfloat16: tex.DType.kBFloat16,
}

+"""
+This is a map: int -> torch.dtype
+Used for resolving cuda extension types to torch.
+Has one to one mapping with enum in
+transformer_engine.h
+"""
TE_DType_To_Torch = {
    tex.DType.kByte: torch.uint8,
    tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
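For context, the two dtype maps are intended as inverses of each other. A small hedged round-trip check, assuming the new reverse map covers the same dtypes as `TE_DType` (only a few entries are visible in this hunk):

```python
import torch
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch

te_dtype = TE_DType[torch.bfloat16]                   # torch dtype -> extension enum
assert TE_DType_To_Torch[te_dtype] is torch.bfloat16  # extension enum -> torch dtype
```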
@@ -9,11 +9,11 @@ import os
import torch
import transformer_engine_torch as tex

from ..constants import TE_DType
-from ..utils import assert_dim_for_fp8_exec, get_sm_count
+from ..utils import get_sm_count
from ..tensor.quantized_tensor import Quantizer
-from ..tensor._internal.float8_tensor_base import Float8TensorBase
-from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
+from ...debug.pytorch.debug_quantization import DebugQuantizer

__all__ = [
    "general_gemm",

@@ -28,46 +28,6 @@ def _empty_tensor() -> torch.Tensor:
    return torch.Tensor().cuda()

-def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str):
-    """Swizzle gemm inputs and return original scaling factor inverses."""
-    if not isinstance(A, MXFP8TensorBase) or not isinstance(B, MXFP8TensorBase):
-        return None
-
-    original_scale_inverses = (
-        A._rowwise_scale_inv,
-        A._columnwise_scale_inv,
-        B._rowwise_scale_inv,
-        B._columnwise_scale_inv,
-    )
-
-    if layout[0] == "T":
-        A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv)
-    else:
-        A._columnwise_scale_inv = tex.columnwise_swizzle(
-            A._columnwise_data, A._columnwise_scale_inv
-        )
-
-    if layout[1] == "N":
-        B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv)
-    else:
-        B._columnwise_scale_inv = tex.columnwise_swizzle(
-            B._columnwise_data, B._columnwise_scale_inv
-        )
-
-    return original_scale_inverses
-
-
-def reset_swizzled_inputs(A, B, scale_inverses):
-    """Reset the swizzled scale inverses after GEMM."""
-    if scale_inverses is not None:
-        (
-            A._rowwise_scale_inv,
-            A._columnwise_scale_inv,
-            B._rowwise_scale_inv,
-            B._columnwise_scale_inv,
-        ) = scale_inverses
-
-
def general_gemm(
    A: torch.Tensor,
    B: torch.Tensor,

@@ -110,9 +70,20 @@ def general_gemm(
        if not out.is_contiguous():
            raise ValueError("Output tensor is not contiguous.")

+   debug_quantizer = None
+   if isinstance(quantization_params, DebugQuantizer):
+       debug_quantizer = quantization_params
+       quantization_params = quantization_params.parent_quantizer
+       A = A.get_tensor(not transa)
+       B = B.get_tensor(transb)
+
    # Use bfloat16 as default bias_dtype
    bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]

+   if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
+       # There is not use_split_accumulator == False
+       # implementation for Float8BlockwiseQTensorBase GEMM
+       use_split_accumulator = True
+
    args = (
        A,
        transa,  # transa

@@ -138,9 +109,10 @@ def general_gemm(
        "bulk_overlap": bulk_overlap,
    }

-   original_scale_inverses = swizzle_inputs(A, B, layout)
    out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
-   reset_swizzled_inputs(A, B, original_scale_inverses)
+
+   if debug_quantizer is not None:
+       out = debug_quantizer.process_gemm_output(out)

    return out, bias_grad, gelu_input, extra_output
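Shown in isolation, the new `DebugQuantizer` handling is an unwrap/post-process pattern: the GEMM runs with the wrapped (parent) quantizer, then the debug wrapper inspects the output. A hedged sketch of that control flow; the class, attribute, and method names come from the hunk above, while the helper itself is purely illustrative:

```python
from transformer_engine.debug.pytorch.debug_quantization import DebugQuantizer

def gemm_with_optional_debug(quantization_params, run_gemm):
    """Unwrap a DebugQuantizer, run the GEMM, then post-process its output."""
    debug_quantizer = None
    if isinstance(quantization_params, DebugQuantizer):
        debug_quantizer = quantization_params
        quantization_params = quantization_params.parent_quantizer
    out = run_gemm(quantization_params)
    if debug_quantizer is not None:
        out = debug_quantizer.process_gemm_output(out)
    return out
```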
@@ -170,14 +142,6 @@ def general_grouped_gemm(
    transa = layout[0] == "T"
    transb = layout[1] == "T"

-   # assert [a.is_contiguous() for a in A]
-   # assert [b.is_contiguous() for b in B]
-
-   if isinstance(A[0], Float8TensorBase):
-       for a, b in zip(A, B):
-           assert_dim_for_fp8_exec(a._data)
-           assert_dim_for_fp8_exec(b._data)
-
    empty_tensor = _empty_tensor()
    empty_tensors = [empty_tensor] * num_gemms
@@ -16,18 +16,22 @@ __all__ = ["get_cpu_offload_context"]

CPUOffloadEnabled = False


-def set_offloading_param(tensor, param_name, value):
-    """Set the type of the offloading needed for a tensor."""
-    assert param_name in ["weight_offloading", "activation_offloading"]
-    if tensor is None:
-        return
-    if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
-        setattr(tensor, param_name, value)
-    else:
-        data_tensors = tensor.get_data_tensors()
-        for tensor in data_tensors:
-            if tensor is not None:
-                setattr(tensor, param_name, value)
+def mark_activation_offload(*tensors):
+    """Set the type of the offloading needed for a tensor."""
+    for tensor in tensors:
+        if tensor is None:
+            continue
+        if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
+            tensor.activation_offloading = True
+        else:
+            data_tensors = tensor.get_data_tensors()
+            for tensor in data_tensors:
+                if tensor is not None:
+                    tensor.activation_offloading = True
+                    # This is a hack to force clear the tensor after it is offloaded.
+                    # It is needed, because .*TensorBase classes are saved in the ctx,
+                    # and they contain the reference to their data tensors.
+                    tensor.needs_force_clear = True


def is_cpu_offload_enabled() -> bool:

@@ -459,8 +463,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
        torch.cuda.current_stream().wait_stream(self.d2h_stream)

        # Time to free the activation memory after usage
-       for tensor_tag, _ in self.tensor_tag_to_buf.items():
+       for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
            if tensor_tag[0] == self.offloaded_group_count:
+               if hasattr(tensor_buf, "needs_force_clear"):
+                   # Need to clear activation tensor - sometimes references persist in the code.
+                   # This is the case for example with the Float8TensorBase class,
+                   # which is saved directly inside the ctx while its internal tensors are
+                   # saved inside save_for_backward.
+                   tensor_buf.data = torch.Tensor()
+               # Release the pointer to the tensor
                self.tensor_tag_to_buf[tensor_tag] = None

        # Time to offload the next group
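Stated on its own, the release step above is a reference-dropping pattern. A minimal sketch using only names that appear in the hunk; the surrounding handler is merely implied:

```python
import torch

def release_offloaded_buffer(handler, tensor_tag, tensor_buf):
    """Free an activation buffer once its offload group has been copied out."""
    if hasattr(tensor_buf, "needs_force_clear"):
        # Quantized tensor wrappers saved in the autograd ctx keep their data alive,
        # so the storage is replaced with an empty tensor explicitly.
        tensor_buf.data = torch.Tensor()
    handler.tensor_tag_to_buf[tensor_tag] = None  # drop the handler's own reference
```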
@@ -538,7 +549,7 @@ def get_cpu_offload_context(
    num_layers: int = 1,
    model_layers: int = 1,
    offload_activations: bool = True,
-   offload_weights: bool = True,
+   offload_weights: bool = False,
):
    """
    This function returns the CPU Offload context and the synchronizer function that needs to be

@@ -570,28 +581,30 @@ def get_cpu_offload_context(
    """

-   def tensor_need_offloading_checker_activations(tensor):
-       return hasattr(tensor, "activation_offloading")
-
-   # This includes the Gradient Accumulation Buffer
-   def tensor_need_offloading_checker_weights(tensor):
-       return hasattr(tensor, "weight_offloading")
-
-   def tensor_need_offloading_checker_all(tensor):
-       return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading")
-
-   if offload_activations and offload_weights:
-       tensor_need_offloading_checker = tensor_need_offloading_checker_all
-   elif offload_activations:
-       tensor_need_offloading_checker = tensor_need_offloading_checker_activations
-   elif offload_weights:
-       tensor_need_offloading_checker = tensor_need_offloading_checker_weights
-   else:
-       raise ValueError(
-           "CPU Offloading is enabled while it is not "
-           "mentioned what to offload (weights/activations)"
-       )
+   if not offload_weights and not offload_activations:
+       raise ValueError(
+           "CPU Offloading is enabled while it is not "
+           "mentioned what to offload (weights/activations)"
+       )
+
+   if offload_weights:
+       import warnings
+
+       warnings.warn(
+           "Offloading weights is deprecated. Using offload_weights=True does not have any"
+           " effect.",
+           DeprecationWarning,
+       )
+
+       # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
+       if not offload_activations:
+           return nullcontext(), lambda x: x
+
+   def tensor_need_offloading_checker_activations(tensor):
+       return hasattr(tensor, "activation_offloading")
+
+   tensor_need_offloading_checker = tensor_need_offloading_checker_activations

    cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
        num_offload_group=num_layers,
        num_model_group=model_layers,
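After this change, only activation offloading has any effect; `offload_weights` is still accepted for backward compatibility but merely warns. A hedged usage sketch; the `enabled` flag and exact argument order are assumptions about the public wrapper, since only part of the signature appears in the hunk:

```python
from transformer_engine.pytorch import get_cpu_offload_context

offload_ctx, sync_fn = get_cpu_offload_context(
    enabled=True,
    num_layers=3,             # layers whose activations are offloaded at a time
    model_layers=12,          # total number of layers wrapped by the context
    offload_activations=True,
    offload_weights=False,    # now the default; True only emits a DeprecationWarning
)
```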
@@ -167,6 +167,38 @@ class Float8CurrentScalingQuantizer : public Quantizer {
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};

+class Float8BlockQuantizer : public Quantizer {
+ public:
+  // Which float8 type is used for q data.
+  DType dtype;
+
+  // Options about how to quantize the tensor
+  // Quantization scales are rounded down to powers of 2.
+  bool force_pow_2_scales = false;
+  // Amax within quantization tile has a floor of epsilon.
+  float amax_epsilon = 0.0;
+
+ private:
+  int block_scaling_dim = 2;
+
+ public:
+  // Initializes from a python handle to a Float8BlockQuantizer
+  explicit Float8BlockQuantizer(const py::handle& quantizer);
+
+  NVTEScalingMode get_scaling_mode() const override {
+    return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
+  }
+
+  // Gets rowwise and columnwise_data from tensor and sets them on wrapper
+  void set_quantization_params(TensorWrapper* tensor) const override;
+
+  // Create a python Float8BlockQuantized tensor and C++ wrapper
+  // for the tensor. Should set quantized data, scales for rowwise
+  // and optionally columnwise usage.
+  std::pair<TensorWrapper, py::object> create_tensor(
+      const std::vector<size_t>& shape, DType dtype,
+      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+};
+
class MXFP8Quantizer : public Quantizer {
 public:
  DType dtype;
@@ -50,11 +50,11 @@ std::vector<py::object> fused_attn_fwd(
    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
-   const c10::optional<at::Tensor> cu_seqlens_q_padded,
-   const c10::optional<at::Tensor> cu_seqlens_kv_padded,
-   const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
-   py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
-   const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
+   const std::optional<at::Tensor> cu_seqlens_q_padded,
+   const std::optional<at::Tensor> cu_seqlens_kv_padded,
+   const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
+   py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
+   const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);

std::vector<py::object> fused_attn_bwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,

@@ -63,8 +63,8 @@ std::vector<py::object> fused_attn_bwd(
    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
-   const c10::optional<at::Tensor> cu_seqlens_q_padded,
-   const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
+   const std::optional<at::Tensor> cu_seqlens_q_padded,
+   const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
    py::handle dp_quantizer, py::handle dqkv_quantizer);

at::Tensor fa_prepare_fwd(at::Tensor qkvi);

@@ -121,18 +121,22 @@ std::vector<at::Tensor> te_batchgemm_ts(
    int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator);
#endif

+namespace transformer_engine::pytorch {
+
/***************************************************************************************************
 * Transpose
 **************************************************************************************************/

-std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
-                                             std::optional<std::vector<py::handle>> output_list,
+std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
+                                             std::optional<std::vector<py::object>> output_list,
                                              std::vector<py::handle> quantizer_list,
                                              transformer_engine::DType otype);

at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
                         std::optional<at::Tensor> output = std::nullopt);

+}  // namespace transformer_engine::pytorch
+
namespace transformer_engine::pytorch {

/***************************************************************************************************

@@ -285,16 +289,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
 **************************************************************************************************/

at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
-                              const bool transpose_output_memory);
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                              const int cp_rank);

at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
-                               const bool transpose_output_memory);
-
-at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs, const int cp_size, const int cp_rank);
-
-at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs, const int cp_size, const int cp_rank);
+                               const NVTE_QKV_Format qkv_format, const bool interleaved,
+                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                               const int cp_rank);

/***************************************************************************************************
 * Miscellaneous

@@ -394,10 +396,25 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                             std::vector<size_t> padded_input_row_list);

/***************************************************************************************************
- * swizzle
+ * NVSHMEM APIs
 **************************************************************************************************/

-void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans);
+namespace nvshmem_api {
+void init_nvshmem_backend(c10d::ProcessGroup *process_group);
+
+torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
+
+void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
+                                    torch::Tensor signal);
+
+void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);
+
+void nvshmem_finalize();
+}  // namespace nvshmem_api
+
+/***************************************************************************************************
+ * swizzle
+ **************************************************************************************************/

at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);
@@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
    nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
    // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
    te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
-   nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+   nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
+                    at::cuda::getCurrentCUDAStream());
+ } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+   // sanity check, since activation fusion is not supported for blockwise quantization yet
+   // need to raise an error here instead of silently going into act_func with wrong numerics
+   NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
  } else {
    act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
  }
@@ -7,138 +7,38 @@

#include "extensions.h"

at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
-                              const bool transpose_output_memory) {
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                              const int cp_rank) {
  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(input.size(0) <= freqs.size(0),
-              "expected freqs tensor has a longer sequence length than input");
  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(input.size(3) >= freqs.size(3),
-              "expected the last dim of the input tensor equals or is "
-              "greater than the freqs tensor");
  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");

-  // input sizes: (s, b, h, d)
-  // s: sequence length
-  // b: batch size
-  // h: head num
-  // d: dim of each head
-  const int s = input.size(0);
-  const int b = input.size(1);
-  const int h = input.size(2);
-  const int d = input.size(3);
-  // input strides
-  const int stride_s = input.stride(0);
-  const int stride_b = input.stride(1);
-  const int stride_h = input.stride(2);
-  const int stride_d = input.stride(3);
-  // freqs' shape is always (s, 1, 1, d2), so the strides are same under
-  // different memory formats
-  const int d2 = freqs.size(3);

  // output
-  auto act_options = input.options().requires_grad(false);
-  at::Tensor output;
-  if (transpose_output_memory) {
-    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
-  } else {
-    output = torch::empty({s, b, h, d}, act_options);
-  }
-  // output strides
-  const int o_stride_s = output.stride(0);
-  const int o_stride_b = output.stride(1);
-  const int o_stride_h = output.stride(2);
-  const int o_stride_d = output.stride(3);
+  auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device());
+  auto output = at::empty(input.sizes(), act_options);

  auto input_cu = makeTransformerEngineTensor(input);
  auto freqs_cu = makeTransformerEngineTensor(freqs);
  auto output_cu = makeTransformerEngineTensor(output);

-  nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2,
-                          stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
-                          o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
-  return output;
-}
-
-at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
-                               const bool transpose_output_memory) {
-  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(output_grads.size(0) <= freqs.size(0),
-              "expected freqs tensor has a longer sequence length than output_grads");
-  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
-              "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
-              "expected the last dim of the output_grads tensor equals or is "
-              "greater than the freqs tensor");
-  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
-              "Dtype of the freqs tensor must be float");
-  // output_grads sizes: (s, b, h, d)
-  // s: sequence length
-  // b: batch size
-  // h: head num
-  // d: dim of each head
-  const int s = output_grads.size(0);
-  const int b = output_grads.size(1);
-  const int h = output_grads.size(2);
-  const int d = output_grads.size(3);
-  // output_grads strides
-  const int stride_s = output_grads.stride(0);
-  const int stride_b = output_grads.stride(1);
-  const int stride_h = output_grads.stride(2);
-  const int stride_d = output_grads.stride(3);
-  // freqs' shape is always (s, 1, 1, d2), so the strides are same under
-  // different memory formats
-  const int d2 = freqs.size(3);
-  auto act_options = output_grads.options().requires_grad(false);
-  at::Tensor input_grads;
-  if (transpose_output_memory) {
-    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
-  } else {
-    input_grads = torch::empty({s, b, h, d}, act_options);
-  }
-  const int o_stride_s = input_grads.stride(0);
-  const int o_stride_b = input_grads.stride(1);
-  const int o_stride_h = input_grads.stride(2);
-  const int o_stride_d = input_grads.stride(3);
-  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
-  nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h,
-                           d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
-                           o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
-  return input_grads;
-}
-
-at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs, const int cp_size, const int cp_rank) {
-  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
-  TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
-  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
-              "expected the second and third dims of the freqs tensor equal 1");
+  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+    TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
+    TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
+    TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
    TORCH_CHECK(input.size(2) >= freqs.size(3),
                "expected the last dim of the input tensor equals or is "
                "greater than the freqs tensor");
-  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
-              "Dtype of the freqs tensor must be float");
    // input sizes: (t, h, d)
    // t: cumulative sum of sequence lengths
    // h: head num
    // d: dim of each head
-  const int t = input.size(0);
+    // const int t = input.size(0);
    const int h = input.size(1);
    const int d = input.size(2);
    // input strides

@@ -146,51 +46,86 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_
    const int stride_h = input.stride(1);
    const int stride_d = input.stride(2);
    // batch size
-  const int b = cu_seqlens.size(0) - 1;
+    const int b = cu_seqlens.value().size(0) - 1;
    // freqs' shape is (max_s, 1, 1, d2)
    const int max_s = freqs.size(0);
    const int d2 = freqs.size(3);

-  // output
-  auto act_options = input.options().requires_grad(false);
-  auto output = torch::empty({t, h, d}, act_options);
-  // output strides
-  const int o_stride_t = output.stride(0);
-  const int o_stride_h = output.stride(1);
-  const int o_stride_d = output.stride(2);
-  auto input_cu = makeTransformerEngineTensor(input);
-  auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto output_cu = makeTransformerEngineTensor(output);
-  nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                              output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
-                              stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
-                              at::cuda::getCurrentCUDAStream());
-  return output;
-}
+    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
+
+    nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                            output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b,
+                            h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
+                            at::cuda::getCurrentCUDAStream());
+    return output;
+  }
+
+  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
+  // input sizes: (s, b, h, d) or (b, s, h, d)
+  // s: sequence length
+  // b: batch size
+  // h: head num
+  // d: dim of each head
+  const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1);
+  const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0);
+  const int h = input.size(2);
+  const int d = input.size(3);
+  // input strides
+  const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1);
+  const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(1) : input.stride(0);
+  const int stride_h = input.stride(2);
+  const int stride_d = input.stride(3);
+  // freqs' shape is always (s, 1, 1, d2), so the strides are same under
+  // different memory formats
+  const int d2 = freqs.size(3);
+
+  TORCH_CHECK(s * cp_size <= freqs.size(0),
+              "expected freqs tensor has a longer sequence length than input");
+  TORCH_CHECK(d >= d2,
+              "expected the last dim of the input tensor equals or is "
+              "greater than the freqs tensor");
+
+  auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+  nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
+                          qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s,
+                          stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream());

  return output;
}

-at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs, const int cp_size, const int cp_rank) {
+at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
+                               const NVTE_QKV_Format qkv_format, const bool interleaved,
+                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
+                               const int cp_rank) {
  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
-  TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
+  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
+              "Dtype of the freqs tensor must be float");
+
+  auto act_options =
+      at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device());
+  auto input_grads = at::empty(output_grads.sizes(), act_options);
+
+  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
+  auto freqs_cu = makeTransformerEngineTensor(freqs);
+  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
+
+  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+    TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
+    TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
+    TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
    TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
                "expected the last dim of the output_grads tensor equals or is "
                "greater than the freqs tensor");
-  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
-              "Dtype of the freqs tensor must be float");
    // output_grads sizes: (t, h, d)
    // t: cumulative sum of sequence lengths
    // h: head num
    // d: dim of each head
-  const int t = output_grads.size(0);
+    // const int t = output_grads.size(0);
    const int h = output_grads.size(1);
    const int d = output_grads.size(2);
    // output_grads strides

@@ -198,25 +133,54 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
    const int stride_h = output_grads.stride(1);
    const int stride_d = output_grads.stride(2);
    // batch size
-  const int b = cu_seqlens.size(0) - 1;
+    const int b = cu_seqlens.value().size(0) - 1;
    // freqs' shape is (max_s, 1, 1, d2)
    const int max_s = freqs.size(0);
    const int d2 = freqs.size(3);

-  auto act_options = output_grads.options().requires_grad(false);
-  auto input_grads = torch::empty({t, h, d}, act_options);
-  const int o_stride_t = input_grads.stride(0);
-  const int o_stride_h = input_grads.stride(1);
-  const int o_stride_d = input_grads.stride(2);
-
-  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
-  auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
-  auto freqs_cu = makeTransformerEngineTensor(freqs);
-  auto input_grads_cu = makeTransformerEngineTensor(input_grads);
-
-  nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                               input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2,
-                               stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
-                               at::cuda::getCurrentCUDAStream());
+    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
+
+    nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                             input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
+                             max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
+                             at::cuda::getCurrentCUDAStream());
+    return input_grads;
+  }
+
+  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
+  // output_grads sizes: (s, b, h, d)
+  // s: sequence length
+  // b: batch size
+  // h: head num
+  // d: dim of each head
+  const int s =
+      qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0) : output_grads.size(1);
+  const int b =
+      qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1) : output_grads.size(0);
+  const int h = output_grads.size(2);
+  const int d = output_grads.size(3);
+  // output_grads strides
+  const int stride_s =
+      qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0) : output_grads.stride(1);
+  const int stride_b =
+      qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1) : output_grads.stride(0);
+  const int stride_h = output_grads.stride(2);
+  const int stride_d = output_grads.stride(3);
+  // freqs' shape is always (s, 1, 1, d2), so the strides are same under
+  // different memory formats
+  const int d2 = freqs.size(3);
+
+  TORCH_CHECK(s * cp_size <= freqs.size(0),
+              "expected freqs tensor has a longer sequence length than output_grads");
+  TORCH_CHECK(d >= d2,
+              "expected the last dim of the output_grads tensor equals or is "
+              "greater than the freqs tensor");
+
+  auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+  nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                           input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
+                           h, d, d2, stride_s, stride_b, stride_h, stride_d,
+                           at::cuda::getCurrentCUDAStream());

  return input_grads;
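The Python entry point for these kernels is `apply_rotary_pos_emb` (imported earlier in this commit). A hedged sketch of calling it for the sbhd layout; the keyword names are assumed from the public wrapper, not taken from this C++ signature, and the tensors are arbitrary:

```python
import torch
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb

s, b, h, d = 128, 2, 16, 64
t = torch.randn(s, b, h, d, device="cuda", dtype=torch.bfloat16)
freqs = torch.randn(s, 1, 1, d, device="cuda", dtype=torch.float32)  # (seq, 1, 1, rot_dim)

# "sbhd", "bshd" and "thd" now all route through the single fused kernel pair above.
out = apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True)
```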
@@ -3,9 +3,11 @@
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"
#include "kv_cache.cuh"
#include "thd_utils.cuh"
+#include "transformer_engine/transformer_engine.h"

constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;

@@ -95,11 +97,11 @@ std::vector<py::object> fused_attn_fwd(
    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
-   const c10::optional<at::Tensor> cu_seqlens_q_padded,
-   const c10::optional<at::Tensor> cu_seqlens_kv_padded,
-   const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
-   py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
-   const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
+   const std::optional<at::Tensor> cu_seqlens_q_padded,
+   const std::optional<at::Tensor> cu_seqlens_kv_padded,
+   const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
+   py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
+   const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
#ifdef __HIP_PLATFORM_AMD__
  assert(false);
#else

@@ -289,8 +291,8 @@ std::vector<py::object> fused_attn_bwd(
    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
-   const c10::optional<at::Tensor> cu_seqlens_q_padded,
-   const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
+   const std::optional<at::Tensor> cu_seqlens_q_padded,
+   const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
    py::handle dp_quantizer, py::handle dqkv_quantizer) {
#ifdef __HIP_PLATFORM_AMD__
  assert(false);

@@ -461,13 +463,13 @@ std::vector<py::object> fused_attn_bwd(
  nvte_tensor_pack_create(&nvte_aux_tensor_pack);
  nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
  for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
-    std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
-    auto temp_vec = std::vector<size_t>(tmp.begin(), tmp.end());
-    const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()};
+    const std::vector<int64_t> &signed_shape = Aux_CTX_Tensors[i].sizes().vec();
+    const std::vector<size_t> tmp(signed_shape.begin(), signed_shape.end());
    NVTEBasicTensor temp_data = {
        Aux_CTX_Tensors[i].data_ptr(),
        static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
-        temp_shape};
+        nvte_make_shape(tmp.data(), tmp.size())};
    nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
  }
@@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
  if (te_output.numel() == 0) return out;

+  QuantizationConfigWrapper quant_config;
+  quant_config.set_noop_tensor(te_noop.data());
+
  if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
    // my_quantizer here has to be a Float8CurrentScalingQuantizer
    auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());

@@ -61,14 +64,20 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
      allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
      process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
    }
-    QuantizationConfigWrapper quant_config;
+    // this config is used for cs scaling factor computation
+    // because compute scale is cannot be fused with quantize kernel
+    // so in nvte_quantize_v2 with current scaling, the quant config is not used again
    quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
    quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
    nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
    // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
    te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
+  } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
+    auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
+    quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
+    quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
  }

-  nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(),
-                     at::cuda::getCurrentCUDAStream());
+  nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
+                   at::cuda::getCurrentCUDAStream());

  return out;
@@ -157,15 +157,15 @@ void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool
  char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
  if (local_chunk) {
-    if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() * _tp_size > _ubuf.numel())
      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
      NVTE_ERROR("input data type does not match communication buffer!");
    ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
  } else {
-    if (input_tensor.numel() > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() > _ubuf.numel())
      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
      NVTE_ERROR("input data type does not match communication buffer!");
  }

@@ -189,7 +189,7 @@ py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
  std::vector<int64_t> torch_shape;
  if (shape.has_value()) {
    torch_shape = shape.value();
-    auto requested = product(torch_shape);
+    size_t requested = product(torch_shape);
    auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
    NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
               ") does not match allocated buffer size (", expected, ")!");

@@ -253,18 +253,18 @@ void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bo
  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
  if (local_chunk) {
    // Copy input to the target ubuf chunk by rank offset
-    if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() * _tp_size > _ubuf.numel())
      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
      NVTE_ERROR("input data type does not match communication buffer!");
    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
                                    input_tensor.numel() * input_tensor.element_size(),
                                    cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
  } else {
-    if (input_tensor.numel() > (int64_t)_ubuf.numel())
+    if (input_tensor.numel() > _ubuf.numel())
      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
+    if (input_tensor.element_size() != _ubuf.element_size())
      NVTE_ERROR("input data type does not match communication buffer!");
    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
                                    input_tensor.numel() * input_tensor.element_size(),

@@ -280,7 +280,7 @@ py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
  std::vector<int64_t> torch_shape;
  if (shape.has_value()) {
torch_shape = shape.value(); torch_shape = shape.value();
auto requested = product(torch_shape); size_t requested = product(torch_shape);
auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel(); auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
") does not match allocated buffer size (", expected, ")!"); ") does not match allocated buffer size (", expected, ")!");
......
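For readers tracing the size checks above: each tensor-parallel rank owns a contiguous 1/tp_size slice of the shared communication buffer, and a requested shape must match either that slice or the whole buffer. A small self-contained sketch of the same bookkeeping; all names and sizes here are illustrative.

```cpp
// Sketch of the chunk arithmetic behind copy_into_buffer/get_buffer.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t ubuf_numel = 1024;  // elements in the whole shared buffer
  const size_t element_size = 2;   // e.g. 2 bytes for bf16
  const int tp_size = 4;

  for (int tp_id = 0; tp_id < tp_size; ++tp_id) {
    // Same arithmetic as: ubuf_ptr += (numel / tp_size) * tp_id * element_size
    size_t chunk_elems = ubuf_numel / tp_size;
    size_t byte_offset = chunk_elems * tp_id * element_size;
    std::printf("rank %d: chunk of %zu elems at byte offset %zu\n", tp_id, chunk_elems,
                byte_offset);
  }

  // get_buffer-style check: the requested element count must equal the
  // expected size (local chunk when local_chunk=true, else the full buffer).
  bool local_chunk = true;
  size_t requested = 256;
  size_t expected = local_chunk ? ubuf_numel / tp_size : ubuf_numel;
  std::printf("requested matches expected: %s\n", requested == expected ? "yes" : "no");
}
```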
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "extensions.h" #include "extensions.h"
#include "pybind.h" #include "pybind.h"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
#include "util.h"
namespace { namespace {
...@@ -179,8 +180,15 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans ...@@ -179,8 +180,15 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
const int sm_count = transformer_engine::cuda::sm_count(device_id); const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count); int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto main_stream = at::cuda::getCurrentCUDAStream(); auto main_stream = at::cuda::getCurrentCUDAStream();
if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
if (comm_overlap) { if (comm_overlap) {
// Prepare extra output tensor // Prepare extra output tensor
TensorWrapper extra_output_tensor; TensorWrapper extra_output_tensor;
...@@ -317,17 +325,18 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -317,17 +325,18 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out_vector, te_workspace_vector; te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> wrappers; std::vector<TensorWrapper> wrappers;
std::vector<at::Tensor> D_vectors; std::vector<at::Tensor> D_vectors;
// Keep the swizzled scaling factor tensors alive during the GEMMs.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto none = py::none(); auto none = py::none();
std::vector<size_t> single_output_begins; std::vector<size_t> single_output_begins;
std::vector<size_t> single_output_ends; std::vector<size_t> single_output_ends;
int slicing_dim;
if (single_output && D == std::nullopt) { if (single_output && D == std::nullopt) {
NVTE_ERROR("not implemented, D should be allocated for single output case."); NVTE_ERROR("not implemented, D should be allocated for single output case.");
} }
void* output_data_ptr; void* output_data_ptr = nullptr;
if (single_output) { if (single_output) {
output_data_ptr = (*D)[0].data_ptr(); output_data_ptr = (*D)[0].data_ptr();
} }
...@@ -384,6 +393,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -384,6 +393,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
continue; continue;
} }
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
auto te_D = makeTransformerEngineTensor(out_tensor); auto te_D = makeTransformerEngineTensor(out_tensor);
auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_bias = makeTransformerEngineTensor(bias[i]);
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
......
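The swizzled_scale_inverses_list additions are purely lifetime management: swizzle_scaling_factors may allocate fresh storage and repoint the wrapped scale-inverse at it, so the owning tensors must stay alive until the asynchronous GEMM launches have consumed them. A minimal sketch of the same pattern, with hypothetical maybe_repack/launch helpers standing in for the TE calls:

```cpp
// Keep-alive pattern sketch: the returned owner must be parked somewhere
// that outlives every use of the repointed view.
#include <cstddef>
#include <cstdio>
#include <optional>
#include <vector>

struct View { const float* data; size_t size; };

// May replace the view's backing storage; returns the new owner (or nothing).
std::optional<std::vector<float>> maybe_repack(View& view, bool repack) {
  if (!repack) return std::nullopt;
  std::vector<float> owned(view.data, view.data + view.size);
  for (float& v : owned) v *= 2.f;  // stand-in for the swizzle
  view.data = owned.data();
  return owned;  // caller must keep this alive while the view is read
}

void launch(const View& v) { std::printf("first element: %f\n", v.data[0]); }

int main() {
  std::vector<float> scale_inv = {0.5f, 0.25f};
  View view{scale_inv.data(), scale_inv.size()};

  // Same idea as swizzled_scale_inverses_list: the owners live until after
  // every launch that reads through the repointed view has been issued.
  std::vector<std::optional<std::vector<float>>> keep_alive;
  keep_alive.emplace_back(maybe_repack(view, /*repack=*/true));

  launch(view);  // storage held in keep_alive is still valid here
}
```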
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// #include <torch/all.h> // #include <torch/all.h>
#include <assert.h> #include <assert.h>
#include <limits>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype. // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream> #include <sstream>
...@@ -47,8 +49,8 @@ struct ComputeScaleAndScaleInvFunctor { ...@@ -47,8 +49,8 @@ struct ComputeScaleAndScaleInvFunctor {
n -= chunk_idx * chunk_size; n -= chunk_idx * chunk_size;
for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) { for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8, float scale_val = transformer_engine::compute_scale_from_amax(
force_pow_2_scales, epsilon); amax[i_start], max_fp8, force_pow_2_scales, epsilon, std::numeric_limits<float>::max());
scale[i_start] = scale_val; scale[i_start] = scale_val;
transformer_engine::reciprocal(scale_inv + i_start, scale_val); transformer_engine::reciprocal(scale_inv + i_start, scale_val);
} }
......
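The new std::numeric_limits<float>::max() argument caps the computed scale. Below is a self-contained sketch of what a scale-from-amax helper with these parameters might look like; the epsilon floor, power-of-two rounding, and the order of rounding and clamping are assumptions inferred from the parameter list, and the authoritative version is transformer_engine::compute_scale_from_amax.

```cpp
// Illustrative sketch; not the TE implementation.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

float sketch_compute_scale_from_amax(float amax, float max_fp8, bool force_pow_2_scales,
                                      float epsilon, float max_scale) {
  // Floor tiny/zero amax with epsilon so the scale stays finite.
  float clamped_amax = std::max(amax, epsilon);
  float scale = (clamped_amax > 0.f) ? max_fp8 / clamped_amax : max_scale;
  if (force_pow_2_scales) {
    // Round down to a power of two so scale and scale_inv stay exactly representable.
    scale = std::exp2(std::floor(std::log2(scale)));
  }
  return std::min(scale, max_scale);
}

int main() {
  float s = sketch_compute_scale_from_amax(3.7f, 448.f, /*force_pow_2_scales=*/true,
                                           /*epsilon=*/0.f, std::numeric_limits<float>::max());
  std::printf("scale = %f, scale_inv = %f\n", s, 1.f / s);
}
```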
...@@ -150,6 +150,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe ...@@ -150,6 +150,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Quantize output if using unfused kernel // Quantize output if using unfused kernel
if (force_unfused_kernel) { if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer // my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get()); auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
...@@ -166,14 +167,17 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe ...@@ -166,14 +167,17 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
allreduce_opts.reduceOp = c10d::ReduceOp::MAX; allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
} }
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
} }
nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
} }
...@@ -293,6 +297,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w ...@@ -293,6 +297,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Quantize output if using unfused kernel // Quantize output if using unfused kernel
if (force_unfused_kernel) { if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer // my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get()); auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
...@@ -309,14 +314,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w ...@@ -309,14 +314,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
allreduce_opts.reduceOp = c10d::ReduceOp::MAX; allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
} }
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
} }
nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
} }
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#ifdef NVTE_ENABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmem_api/nvshmem_waitkernel.h>
#include <nvshmemx.h>
#endif
#include <cuda.h>
#include <cuda_fp8.h>
#include <torch/cuda.h>
#include <torch/extension.h>
namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
#ifdef NVTE_ENABLE_NVSHMEM
nvshmemx_init_attr_t attr = {};
nvshmemx_uniqueid_t id = {};
int my_rank = process_group->getRank();
int num_ranks = process_group->getSize();
if (my_rank == 0) {
nvshmemx_get_uniqueid(&id);
}
auto backend_is_nccl = (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL);
NVTE_CHECK(backend_is_nccl, "Currently only NCCL bootstrap is supported for NVSHMEM");
auto datatensor =
torch::from_blob(reinterpret_cast<void *>(&id),
{static_cast<int64_t>(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto datatmp = (backend_is_nccl) ? datatensor.cuda() : datatensor;
c10d::BroadcastOptions bcast_opts;
bcast_opts.rootRank = 0;
std::vector<torch::Tensor> datachunk = {datatmp};
auto work = process_group->broadcast(datachunk, bcast_opts);
work->wait();
if (backend_is_nccl) {
datatensor.copy_(datatmp.cpu());
datatmp = torch::Tensor();
}
nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
NVTE_CHECK(my_rank == nvshmem_my_pe(), "my_rank: ", my_rank,
" != nvshmem_my_pe(): ", nvshmem_my_pe());
NVTE_CHECK(num_ranks == nvshmem_n_pes(), "num_ranks: ", num_ranks,
" != nvshmem_n_pes(): ", nvshmem_n_pes());
#else
NVTE_ERROR("Internal TE error: init_nvshmem_backend cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) {
#ifdef NVTE_ENABLE_NVSHMEM
uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream();
WaitKind wait_kind_enum = WaitKind::STREAM_WAIT;
if (wait_kind == "kernel") {
wait_kind_enum = WaitKind::KERNEL_WAIT;
} else if (wait_kind == "nvshmem") {
wait_kind_enum = WaitKind::NVSHMEM_WAIT;
} else if (wait_kind == "stream") {
wait_kind_enum = WaitKind::STREAM_WAIT;
} else {
NVTE_ERROR("Invalid wait kind: ", wait_kind);
}
nvshmem_wait_on_stream(sig_addr, wait_kind_enum, cur_stream);
#else
NVTE_ERROR(
"Internal TE error: nvshmem_wait_on_current_stream cannot be used with valid PyTorch ",
"distributed process groups unless TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
#ifdef NVTE_ENABLE_NVSHMEM
auto option_gpu =
at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device());
auto size = torch::elementSize(dtype) *
std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<>());
return at::from_blob(
nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
#else
NVTE_ERROR("Internal TE error: create_nvshmem_tensor cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
torch::Tensor signal) {
#ifdef NVTE_ENABLE_NVSHMEM
void *src_ptr = reinterpret_cast<void *>(src.data_ptr());
void *dst_ptr = reinterpret_cast<void *>(dst.data_ptr());
uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
auto nelement = src.numel() * src.element_size();
uint64_t sigval = 1;
at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();
nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nelement, sig_addr, sigval, NVSHMEM_SIGNAL_SET,
peer, (cudaStream_t)cur_stream);
#else
NVTE_ERROR(
"Internal TE error: nvshmem_send_on_current_stream cannot be used with valid PyTorch ",
"distributed process groups unless TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
void nvshmem_finalize() {
#ifdef NVTE_ENABLE_NVSHMEM
::nvshmem_finalize();  // qualified call into the NVSHMEM library, not a recursive call into this wrapper
#else
NVTE_ERROR("Internal TE error: nvshmem_finalize cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
} // namespace nvshmem_api
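create_nvshmem_tensor above ties the nvshmem_malloc allocation to the tensor's lifetime by passing a deleter to from_blob. A sketch of the same pattern with plain malloc/free standing in for the NVSHMEM allocator so the snippet builds without NVSHMEM; the helper name is hypothetical.

```cpp
// Wrapping externally allocated memory in a tensor with a custom deleter.
#include <cstdlib>
#include <vector>
#include <torch/torch.h>

torch::Tensor wrap_external_buffer(const std::vector<int64_t>& shape) {
  int64_t numel = 1;
  for (auto d : shape) numel *= d;
  void* raw = std::malloc(numel * sizeof(float));
  // The deleter runs when the last reference to the storage is dropped, so
  // the external allocation's lifetime follows the tensor's lifetime.
  return torch::from_blob(
      raw, shape, [](void* ptr) { std::free(ptr); },
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));
}

int main() {
  auto t = wrap_external_buffer({2, 3});
  t.zero_();  // the buffer is freed automatically when t is destroyed
}
```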
...@@ -17,7 +17,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, ...@@ -17,7 +17,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2."); NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2."); NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2.");
const int num_tensors = input_row_list.size(); const auto num_tensors = input_row_list.size();
// Extract properties from PyTorch tensors // Extract properties from PyTorch tensors
std::vector<void*> input_dptr_list, output_dptr_list; std::vector<void*> input_dptr_list, output_dptr_list;
std::vector<std::vector<size_t>> input_shape_list, output_shape_list; std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
......
...@@ -52,18 +52,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( ...@@ -52,18 +52,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
num_tokens * topK); num_tokens * topK);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc // Output buffer alloc
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
at::Tensor permuted_output = torch::empty( at::Tensor permuted_output =
{num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); torch::empty({num_out_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor row_id_map = torch::empty( at::Tensor row_id_map = torch::empty(
{num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
...@@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d ...@@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
using namespace transformer_engine::pytorch; using namespace transformer_engine::pytorch;
int num_cols = input.size(1); int num_cols = input.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc // Output buffer alloc
at::Tensor unpermuted_output = torch::empty( at::Tensor unpermuted_output =
{num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); torch::empty({num_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
...@@ -136,17 +122,10 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T ...@@ -136,17 +122,10 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
int num_cols = input_bwd.size(1); int num_cols = input_bwd.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input_bwd.scalar_type();
// Output buffer alloc // Output buffer alloc
at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, at::Tensor act_grad =
torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); torch::empty({input_fwd.size(0), num_cols},
torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor prob_grad = torch::empty( at::Tensor prob_grad = torch::empty(
{num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
......
...@@ -28,6 +28,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; ...@@ -28,6 +28,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *MXFP8TensorBasePythonClass = nullptr; PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
PyTypeObject *MXFP8QuantizerClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
void init_float8_extension() { void init_float8_extension() {
if (Float8TensorPythonClass) return; if (Float8TensorPythonClass) return;
...@@ -61,9 +64,31 @@ void init_mxfp8_extension() { ...@@ -61,9 +64,31 @@ void init_mxfp8_extension() {
"Internal error: could not initialize pyTorch MXFP8 extension."); "Internal error: could not initialize pyTorch MXFP8 extension.");
} }
void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorBasePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
"transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base");
Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
Float8BlockwiseQTensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase"));
Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
}
void init_extension() { void init_extension() {
init_float8_extension(); init_float8_extension();
init_mxfp8_extension(); init_mxfp8_extension();
init_float8blockwise_extension();
} }
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -76,6 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -76,6 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("output") = py::none(), py::arg("noop") = py::none()); py::arg("output") = py::none(), py::arg("noop") = py::none());
m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
py::arg("otype")); py::arg("otype"));
m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
"Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)",
...@@ -170,15 +196,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -170,15 +196,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
py::arg("zero_centered_gamma")); py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm"); m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose", m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
py::arg("quantizer_list"), py::arg("otype"));
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
#ifdef USE_ROCM #ifdef USE_ROCM
m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM"); /// rocblas m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM"); /// rocblas
#endif #endif
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>()); py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend", m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax")); m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
...@@ -206,10 +234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -206,10 +234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD", m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format",
py::call_guard<py::gil_scoped_release>());
// Misc // Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version", m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
...@@ -240,6 +264,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -240,6 +264,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Generate partitioned indices for inputs in THD format", "Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
// nvshmem functions
m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
"Initialize nvshmem backend with Pytorch distributed process groups",
py::call_guard<py::gil_scoped_release>());
m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
"Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
"Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
"Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
"stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
"Clean up and finalize the NVSHMEM communication backend and free associated resources",
py::call_guard<py::gil_scoped_release>());
// multi-tensor functions // multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda, m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors", "Fused overflow check + scale for a list of contiguous tensors",
......