Commit 87e3e56e authored by yuguo


Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
@@ -15,10 +15,10 @@ from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import warnings
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec
from jax.interpreters import pxla
from jax.sharding import PartitionSpec, get_abstract_mesh
import numpy as np
_PXLA_THREAD_RESOURCES = pxla.thread_resources
@@ -86,24 +86,29 @@ def get_sharding_map_logic_axis_to_mesh_axis():
return te_logical_axis_to_mesh_axis
def generate_pspec(logical_axis_names):
def _generate_pspec(logical_axis_names):
"""
Convert logical axes to PartitionSpec
Convert TransformerEngine logical axes (e.g. BATCH_AXES) to a JAX PartitionSpec.
Note, this method does not support Flax logical axes.
Args:
logical_axis_names: TransformerEngine logical axes to convert to a JAX PartitionSpec.
Returns:
A JAX PartitionSpec with the mesh axes corresponding to the given TransformerEngine logical axis names
"""
rules = get_sharding_map_logic_axis_to_mesh_axis()
# mesh_axis_names = [rules[name] for name in logical_axis_names]
mesh_axis_names = []
for name in logical_axis_names:
axis_name = rules[name] if name in rules else None
mesh_axis_names.append(axis_name)
mesh_axis_names = [rules.get(name) for name in logical_axis_names]
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
return pspec
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
"""
A wrapper function to jax.lax.with_sharding_constraint to
support the case that Mesh is empty.
A wrapper function to jax.lax.with_sharding_constraint
1. Does nothing if mesh is empty.
2. If all mesh axes are manual axes, replaces pspec with all Nones.
3. Otherwise, strips only the manual axes.
"""
if pspec is None:
return x
@@ -111,7 +116,14 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return x
return jax.lax.with_sharding_constraint(x, pspec)
# We want to exclude the axes that already used by shard_map and shard_map
# only sets those in the abstract_mesh, not the physical one
manual_axis_names = get_abstract_mesh().manual_axes
cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec)
cleaned_pspec = PartitionSpec(*cleaned_axis_names)
return jax.lax.with_sharding_constraint(x, cleaned_pspec)
def with_sharding_constraint_by_logical_axes(
@@ -159,7 +171,7 @@ def with_sharding_constraint_by_logical_axes(
# If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table
assert len(x.shape) == len(logical_axis_names)
pspec = generate_pspec(logical_axis_names)
pspec = _generate_pspec(logical_axis_names)
return with_sharding_constraint(x, pspec)
@@ -383,24 +395,3 @@ class ShardingType(Enum):
TP_ROW = (MajorShardingType.TP, "tp_row")
DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
def get_non_contracting_logical_axes(
ndim, logical_axes: tuple[Optional[str]], contracting_dims
) -> tuple[Optional[str]]:
"""Get logical axes for non-contracting dimensions.
Args:
ndim: Number of dimensions in the tensor.
logical_axes: Tuple of logical axes for each dimension.
contracting_dims: Set of dimensions that are being contracted.
Returns:
Tuple of logical axes for non-contracting dimensions.
"""
assert logical_axes is not None, "Logical axes must be a tuple and cannot be None."
assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions."
non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
return non_contracting_logical_axes
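The rewritten `_generate_pspec` and `with_sharding_constraint` above boil down to two small pieces of logic: look up each logical axis with `dict.get()` so unknown axes map to `None`, and null out mesh axes that `shard_map` has already claimed as manual. A minimal runnable sketch (not TE code; the rule table and manual-axis set are hypothetical stand-ins):

```python
from jax.sharding import PartitionSpec

rules = {"batch": "data", "hidden_tp": "model"}   # hypothetical logical -> mesh axis map
logical_axes = ("batch", "seq", "hidden_tp")      # "seq" has no rule, so it maps to None

pspec = PartitionSpec(*(rules.get(name) for name in logical_axes))
# PartitionSpec('data', None, 'model')

manual_axes = {"model"}                           # e.g. axes already handled inside shard_map
cleaned = PartitionSpec(*(ax if ax not in manual_axes else None for ax in pspec))
# PartitionSpec('data', None, None)
```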
@@ -630,7 +630,7 @@ class DotProductAttention(TransformerEngineBaseModule):
If true, there are padding tokens between individual sequences in a packed batch.
"""
with self.prepare_forward(
with torch.cuda.device(query_layer.device), self.prepare_forward(
query_layer,
num_gemms=3,
allow_non_contiguous=True,
...
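The only change in this hunk wraps the forward in `torch.cuda.device(query_layer.device)`. A minimal sketch of the pattern outside TE (the function name is illustrative):

```python
import torch

def forward(query_layer: torch.Tensor) -> torch.Tensor:
    # Pin the current CUDA device to the input's device so kernels, workspaces,
    # and streams launched inside the block land on the right GPU.
    with torch.cuda.device(query_layer.device):
        return query_layer * 2.0
```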
@@ -438,8 +438,8 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None:
if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0):
if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0):
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12")
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12")
use_fused_attention = False
if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism")
@@ -625,7 +625,7 @@ def get_attention_backend(
" bias for THD format"
)
use_fused_attention = False
elif fp8 and head_dim_qk != head_dim_v:
elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with FP8"
" MLA attention"
...
@@ -11,7 +11,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
from transformer_engine.pytorch.utils import (
SplitAlongDim,
@@ -175,14 +175,23 @@ class MultiheadAttention(torch.nn.Module):
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatentations/splits and also enables the argument
`fuse_wgrad_accumulation`.
use_qk_norm: bool, default = 'False'
if set to `True`, L2 normalization is applied to query and key tensors
after RoPE (if applicable) but before attention computation.
This follows the Llama4 approach for QK normalization to improve
training stability and model performance.
qk_norm_type: Optional[str], default = None
type of normalization to apply to query and key tensors.
Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied.
When 'L2Normalization', L2 normalization is applied to query and key tensors.
When 'RMSNorm', RMS normalization is applied to query and key tensors.
When 'LayerNorm', layer normalization is applied to query and key tensors.
Normalization is applied after RoPE (if applicable) but before attention computation
when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach
for QK normalization to improve training stability and model performance.
qk_norm_eps: float, default = 1e-6
epsilon value for L2 normalization of query and key tensors.
Only used when `use_qk_norm` is True.
epsilon value for normalization of query and key tensors.
Only used when `qk_norm_type` is not None.
qk_norm_before_rope: bool, default = `False`
if set to `True`, query and key normalization is applied before rotary position
embedding. When `False` (default), normalization is applied after RoPE.
This parameter allows supporting different architectural variants that apply
QK normalization at different points.
seq_length: Optional[int], default = `None`
sequence length of input samples. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are used for
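Hedged usage sketch of the arguments documented above; only `qk_norm_type`, `qk_norm_eps`, and `qk_norm_before_rope` come from this diff, the remaining constructor arguments are illustrative:

```python
import transformer_engine.pytorch as te

mha = te.MultiheadAttention(
    hidden_size=4096,
    num_attention_heads=32,
    qk_norm_type="RMSNorm",        # one of None, "L2Normalization", "RMSNorm", "LayerNorm"
    qk_norm_eps=1e-6,
    qk_norm_before_rope=False,     # False: normalize after RoPE (Llama4-style)
)
```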
@@ -231,8 +240,9 @@ class MultiheadAttention(torch.nn.Module):
device: Union[torch.device, str] = "cuda",
qkv_format: str = "sbhd",
name: str = None,
use_qk_norm: bool = False,
qk_norm_type: Optional[str] = None,
qk_norm_eps: float = 1e-6,
qk_norm_before_rope: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
) -> None:
@@ -264,6 +274,7 @@ class MultiheadAttention(torch.nn.Module):
qkv_weight_interleaved = False
self.qkv_weight_interleaved = qkv_weight_interleaved
self.rotary_pos_interleaved = rotary_pos_interleaved
self.qk_norm_before_rope = qk_norm_before_rope
assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
if layer_number is not None:
@@ -288,7 +299,6 @@ class MultiheadAttention(torch.nn.Module):
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
self.name = name
self.use_qk_norm = use_qk_norm
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
@@ -300,13 +310,9 @@ class MultiheadAttention(torch.nn.Module):
"device": device,
}
# Initialize L2 normalization modules for query and key if enabled
if self.use_qk_norm:
self.qk_norm = L2Normalization(
eps=qk_norm_eps,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
)
self.q_norm, self.k_norm = self._create_qk_norm_modules(
qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size
)
qkv_parallel_mode = "column" if set_parallel_mode else None
@@ -427,6 +433,78 @@ class MultiheadAttention(torch.nn.Module):
**common_gemm_kwargs,
)
def _create_qk_norm_modules(
self,
qk_norm_type: Optional[str],
qk_norm_eps: float,
device: Union[torch.device, str],
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]:
"""
Create query and key normalization modules based on the specified normalization type.
Parameters
----------
qk_norm_type : Optional[str]
Type of normalization to apply. Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'
qk_norm_eps : float
Epsilon value for numerical stability
device : Union[torch.device, str]
Device to place the normalization modules on
seq_length : Optional[int], default = None
Sequence length for L2Normalization optimization
micro_batch_size : Optional[int], default = None
Micro batch size for L2Normalization optimization
Returns
-------
Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]
Query and key normalization modules (q_norm, k_norm)
"""
if qk_norm_type is None:
return None, None
if qk_norm_type == "L2Normalization":
l2_norm = L2Normalization(
eps=qk_norm_eps,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
)
# L2Normalization is parameter-free, so we can share the same instance
return l2_norm, l2_norm
if qk_norm_type == "RMSNorm":
q_norm = RMSNorm(
normalized_shape=self.hidden_size_per_attention_head,
eps=qk_norm_eps,
device=device,
)
k_norm = RMSNorm(
normalized_shape=self.hidden_size_per_attention_head,
eps=qk_norm_eps,
device=device,
)
return q_norm, k_norm
if qk_norm_type == "LayerNorm":
q_norm = LayerNorm(
normalized_shape=self.hidden_size_per_attention_head,
eps=qk_norm_eps,
device=device,
)
k_norm = LayerNorm(
normalized_shape=self.hidden_size_per_attention_head,
eps=qk_norm_eps,
device=device,
)
return q_norm, k_norm
raise ValueError(
f"Unsupported QK norm type: {qk_norm_type}. "
"Supported types: ['L2Normalization', 'RMSNorm', 'LayerNorm']"
)
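For intuition, the modules created above normalize over the trailing head dimension of the query/key tensors. A plain-PyTorch sketch with assumed shapes (`torch.nn.RMSNorm` stands in for TE's `RMSNorm`):

```python
import torch

head_dim = 128
q = torch.randn(2, 16, 8, head_dim)            # [seq, batch, heads, head_dim]
rms = torch.nn.RMSNorm(head_dim, eps=1e-6)     # normalizes over the last dim only
print(rms(q).shape)                            # torch.Size([2, 16, 8, 128])
```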
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""
Set the tensor parallel group for the given
@@ -789,6 +867,14 @@ class MultiheadAttention(torch.nn.Module):
)
query_layer = query_layer.view(*new_tensor_shape)
# ===========================
# Apply normalization to query and key tensors (before RoPE if configured)
# ===========================
if self.q_norm is not None and self.qk_norm_before_rope:
query_layer = self.q_norm(query_layer)
key_layer = self.k_norm(key_layer)
# ======================================================
# Apply relative positional encoding (rotary embedding)
# ======================================================
@@ -821,12 +907,19 @@ class MultiheadAttention(torch.nn.Module):
q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
if pad_between_seqs:
rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded
rotary_pos_cu_seq_lens_kv = cu_seqlens_kv_padded
else:
rotary_pos_cu_seq_lens_q = cu_seqlens_q
rotary_pos_cu_seq_lens_kv = cu_seqlens_kv
query_layer = apply_rotary_pos_emb(
query_layer,
q_pos_emb,
self.qkv_format,
fused=True,
cu_seqlens=cu_seqlens_q,
cu_seqlens=rotary_pos_cu_seq_lens_q,
cp_size=self.cp_size,
cp_rank=self.cp_rank,
interleaved=self.rotary_pos_interleaved,
@@ -836,19 +929,19 @@ class MultiheadAttention(torch.nn.Module):
k_pos_emb,
self.qkv_format,
fused=True,
cu_seqlens=cu_seqlens_kv,
cu_seqlens=rotary_pos_cu_seq_lens_kv,
cp_size=self.cp_size,
cp_rank=self.cp_rank,
interleaved=self.rotary_pos_interleaved,
)
# ===========================
# Apply L2 normalization to query and key tensors
# Apply normalization to query and key tensors (after RoPE if not applied before)
# ===========================
if self.use_qk_norm:
query_layer = self.qk_norm(query_layer)
key_layer = self.qk_norm(key_layer)
if self.q_norm is not None and not self.qk_norm_before_rope:
query_layer = self.q_norm(query_layer)
key_layer = self.k_norm(key_layer)
# ===========================
# Core attention computation
...
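The new `rotary_pos_cu_seq_lens_*` selection matters for packed (THD) batches with padding between sequences: RoPE offsets must follow the padded layout so each sequence's rotary phase starts at zero. An illustration with assumed values:

```python
import torch

seqlens = torch.tensor([5, 3, 7])
pad_to = 8
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int64),
                        torch.cumsum(seqlens, 0)])                        # [0, 5, 8, 15]
cu_seqlens_padded = torch.arange(0, (len(seqlens) + 1) * pad_to, pad_to)  # [0, 8, 16, 24]
print(cu_seqlens.tolist(), cu_seqlens_padded.tolist())
```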
@@ -46,6 +46,15 @@ __all__ = [
]
def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
"""Validate whether a GEMM scaling factor is consistent with its usage"""
if required:
return scale if scale is not None else 1.0
if scale not in (0.0, None):
raise ValueError("scale must be zero")
return 0.0
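In plain terms, `alpha` and `beta` are the usual cuBLAS-style GEMM scaling factors: `validate_gemm_scale` defaults `alpha` to 1.0 and only allows a nonzero `beta` when the caller accumulates into an existing output. A plain-PyTorch sketch of the semantics (not the TE call itself):

```python
import torch

A, B = torch.randn(64, 32), torch.randn(32, 48)
D = torch.randn(64, 48)          # existing output to accumulate into
alpha, beta = 0.5, 1.0           # beta is only meaningful when accumulating
D = alpha * (A @ B) + beta * D
```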
def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
@@ -54,6 +63,8 @@ def general_gemm(
quantization_params: Optional[Quantizer] = None,
gelu: bool = False,
gelu_in: torch.Tensor = None,
alpha: float = 1.0,
beta: Optional[float] = None,
accumulate: bool = False,
layout: str = "TN",
out: Optional[torch.Tensor] = None,
@@ -72,6 +83,9 @@ def general_gemm(
transb = layout[1] == "T"
# assert quantization_params is None, "FP8 output not supported yet"
alpha = validate_gemm_scale(alpha, True)
beta = validate_gemm_scale(beta, accumulate)
# if ub_type is not None:
# assert ub is not None, (
# f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires"
@@ -349,6 +363,8 @@ def general_gemm(
"comm_type": ub_type,
"extra_output": extra_output,
"bulk_overlap": bulk_overlap,
"alpha": alpha,
"beta": beta,
}
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
...
@@ -431,7 +431,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
tensor = self.fp8_tensor_object_map.pop(tensor_tag)
if self.double_buffering:
tensor.do_not_clear = True
tensor._do_not_clear = True
self.tensor_tag_to_buf.pop(tensor_tag, None)
# the tensor should have been copied back in on_group_commit_backward()
@@ -556,21 +556,33 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label
if group_id == group_to_reload:
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx]
else:
reload_buffer = None
if isinstance(state, tuple):
recovered_tensor = SynchronizedGroupOffloadHandler.reload(
state, True, self.reload_double_buffer[double_buffer_idx][buffer_idx]
state, True, reload_buffer
)
buffer_idx = buffer_idx + 1
self.tensor_tag_to_state[tensor_label] = recovered_tensor
elif isinstance(state, list):
tensor_list = []
for state_tuple in state:
if self.double_buffering:
reload_buffer = self.reload_double_buffer[double_buffer_idx][
buffer_idx
]
else:
reload_buffer = None
if isinstance(state_tuple, tuple):
tensor_list.append(
SynchronizedGroupOffloadHandler.reload(
state_tuple,
True,
self.reload_double_buffer[double_buffer_idx][buffer_idx],
reload_buffer,
)
)
buffer_idx = buffer_idx + 1
...
@@ -29,6 +29,7 @@ class CrossEntropyFunction(torch.autograd.Function):
reduce_loss=False,
dist_process_group=None,
ignore_idx=-100,
is_cg_capturable=False,
):
"""
The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each
@@ -47,10 +48,16 @@ class CrossEntropyFunction(torch.autograd.Function):
tensor: The computed loss.
"""
loss, _input = triton_cross_entropy.cross_entropy_forward(
_input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx
_input,
target,
label_smoothing,
reduce_loss,
dist_process_group,
ignore_idx,
)
ctx.save_for_backward(_input.detach())
ctx.is_cg_capturable = is_cg_capturable
return loss
@staticmethod
@@ -66,13 +73,17 @@ class CrossEntropyFunction(torch.autograd.Function):
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
(_input,) = ctx.saved_tensors
_input = triton_cross_entropy.cross_entropy_backward(_input, grad_output)
_input = triton_cross_entropy.cross_entropy_backward(
_input, grad_output, ctx.is_cg_capturable
)
return (
_input,
None,
None,
None,
None,
None,
None,
)
...
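The two extra `None`s in `backward` mirror the new forward arguments: `torch.autograd.Function.backward` must return one gradient (or `None`) per input to `forward`. A minimal runnable sketch of that contract:

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, is_cg_capturable=False):
        ctx.factor = factor
        ctx.is_cg_capturable = is_cg_capturable   # stashed for backward, as above
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one entry per forward argument: grad for x, None for factor and the flag
        return grad_out * ctx.factor, None, None

y = Scale.apply(torch.ones(3, requires_grad=True), 2.0, True)
y.sum().backward()
```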
@@ -12,7 +12,7 @@
namespace transformer_engine::pytorch {
std::vector<size_t> getTensorShape(at::Tensor t) {
std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape;
for (auto s : t.sizes()) {
shape.push_back(s);
@@ -286,7 +286,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape) {
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
int roundup(const int value, const int multiple) {
size_t roundup(const size_t value, const size_t multiple) {
assert(multiple > 0);
return ((value + multiple - 1) / multiple) * multiple;
}
...
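The `roundup` helper rounds a value up to the nearest multiple; widening it from `int` to `size_t` avoids overflow for large element counts. The same arithmetic in Python:

```python
def roundup(value: int, multiple: int) -> int:
    assert multiple > 0
    return ((value + multiple - 1) // multiple) * multiple

assert roundup(13, 8) == 16 and roundup(16, 8) == 16
```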
@@ -116,9 +116,21 @@ class Quantizer {
virtual void set_quantization_params(TensorWrapper* tensor) const = 0;
virtual std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;
/*! @brief Construct a tensor with uninitialized data */
virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const = 0;
/*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
*
* The PyTorch tensor's attributes are modified to match the
* quantizer's configuration.
*/
virtual std::pair<TensorWrapper, py::object> convert_and_update_tensor(
py::object tensor) const = 0;
/*! @brief Convert to a quantized data format */
virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) = 0;
virtual ~Quantizer() = default;
@@ -139,9 +151,17 @@ class NoneQuantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override {}
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Tensor data) const;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};
class Float8Quantizer : public Quantizer {
@@ -157,9 +177,19 @@ class Float8Quantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> data,
std::optional<at::Tensor> transpose,
std::optional<at::Tensor> scale_inv) const;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};
class Float8CurrentScalingQuantizer : public Quantizer {
@@ -179,9 +209,29 @@ class Float8CurrentScalingQuantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
/*! @brief Construct a high precision tensor giving it this quantizer's amax
Note: this member function also zeros out the amax, as it is meant to be used in conjunction with
a kernel computing the amax, which might expect the amax to be initialized to zero
*/
std::pair<TensorWrapper, py::object> create_hp_tensor_with_amax(const std::vector<size_t>& shape,
DType dtype);
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
/*! @brief Convert to a quantized data format avoiding amax computation */
void quantize_with_amax(TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt);
private:
void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag, bool compute_amax);
};
class Float8BlockQuantizer : public Quantizer {
@@ -213,9 +263,13 @@ class Float8BlockQuantizer : public Quantizer {
// Create a python Float8BlockQuantized tensor and C++ wrapper
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
@@ -230,16 +284,20 @@ class MXFP8Quantizer : public Quantizer {
void set_quantization_params(TensorWrapper* tensor) const override;
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
std::vector<size_t> getTensorShape(at::Tensor t);
std::vector<size_t> getTensorShape(const at::Tensor& t);
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe);
@@ -382,7 +440,7 @@ void* getDataPtr(at::Tensor tensor, int offset = 0);
std::vector<size_t> convertShape(const NVTEShape& shape);
int roundup(const int value, const int multiple);
size_t roundup(const size_t value, const size_t multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
} // namespace transformer_engine::pytorch
...
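The header changes above split tensor construction (`create_tensor`), conversion of an existing Python tensor (`convert_and_update_tensor`), and the cast itself (`quantize`) into separate virtual methods. A toy Python rendering of that interface shape (not the real C++ API; the rounding "quantizer" is a stand-in for the FP8 cast):

```python
import numpy as np

class ToyQuantizer:
    def create_tensor(self, shape, dtype):
        # uninitialized output the caller will fill or quantize into
        return np.empty(shape, dtype=dtype)

    def quantize(self, src, out):
        # stand-in for the FP8 conversion, for illustration only
        np.copyto(out, np.round(src))

q = ToyQuantizer()
hp = np.random.randn(4, 4).astype(np.float32)   # high-precision temporary
out = q.create_tensor((4, 4), np.float32)
q.quantize(hp, out)
```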
@@ -11,6 +11,10 @@
#include "common.h"
class CommOverlapHelper;
class CommOverlap;
class CommOverlapP2P;
namespace transformer_engine::pytorch {
/***************************************************************************************************
@@ -118,7 +122,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
std::optional<CommOverlapType> comm_type = std::nullopt,
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false,
float alpha = 1.0f, std::optional<float> beta = std::nullopt);
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
@@ -179,6 +184,8 @@ std::vector<at::Tensor> te_batchgemm_ts(
at::Tensor fp8_transpose(at::Tensor input, DType otype,
std::optional<at::Tensor> output = std::nullopt);
at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = std::nullopt);
/***************************************************************************************************
* Activations
**************************************************************************************************/
@@ -455,6 +462,13 @@ void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_k
void nvshmem_finalize();
/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/
void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream,
at::Stream recv_stream);
} // namespace transformer_engine::pytorch
/***************************************************************************************************
@@ -504,7 +518,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt);
at::Stream get_communication_stream();
std::pair<at::Stream, at::Stream> get_communication_stream();
}; // CommOverlap
@@ -525,7 +539,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt);
at::Stream get_communication_stream();
std::pair<at::Stream, at::Stream> get_communication_stream();
}; // CommOverlapP2P
...
@@ -13,87 +13,92 @@ namespace transformer_engine::pytorch {
template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto input_tensor = input.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const auto& te_input_shape = te_input.shape();
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
input_shape[input_shape.size() - 1] /= shape_divisor;
auto fake_tensor_type = input.scalar_type();
auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
// for current scaling, we need to compute amax first and then quantize
// because cache cannot fit in the entire tensor to compute amax and quantize
// the quantizer should not need amax reduction, no process group needed here
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// activation function might change the input data range, we need to first call the activation function
// and then find the amax and scale of that and then do the quantization
// get a NoneQuantizer to calculate amax of activation output
auto my_quantizer_none = std::make_unique<NoneQuantizer>(py::none());
auto [te_output_act, out_act] =
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
NVTE_SCOPED_GIL_RELEASE({
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
// use te_output_act as input to the compute amax and find the amax of activated tensor
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
if (my_quantizer_cs->with_amax_reduction) {
NVTE_ERROR(
"per-tensor current scaling amax reduction is not supported in activation functions.");
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
// Input tensor
auto input_tensor = input.contiguous();
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
// Construct output tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape = input_cpp.shape();
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
output_shape.back() /= shape_divisor;
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);
// Compute activation
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation directly
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// Compute activation in high-precision fused together with amax, then quantize.
auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
// sanity check, since activation fusion is not supported for blockwise quantization yet
// need to raise an error here instead of silently going into act_func with wrong numerics
NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
} else {
// Compute activation in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); });
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
quantizer_cpp->quantize(temp_cpp, out_cpp);
}
return out;
return out_py;
}
template <void (*act_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input,
py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
py::handle quantizer) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto input_tensor = input.contiguous();
auto grad_tensor = grad.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor);
const auto& te_input_shape = te_input.shape();
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
auto fake_tensor_type = input.scalar_type();
auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
NVTE_SCOPED_GIL_RELEASE({
act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
// Grad output and input tensors
auto grad_output_tensor = grad_output.contiguous();
auto input_tensor = input.contiguous();
const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);
// Construct grad input tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape_te = input_cpp.shape();
const std::vector<size_t> input_shape(input_shape_te.data,
input_shape_te.data + input_shape_te.ndim);
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);
// Compute activation backward
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation backward directly
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// Compute activation backward in high-precision fused together with amax, then quantize.
auto quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer*>(quantizer_cpp.get());
auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp);
} else {
// Compute activation backward in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
}
return out;
return grad_input_py;
}
py::object gelu(const at::Tensor& input, py::handle quantizer) {
...
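What `create_hp_tensor_with_amax` plus `quantize_with_amax` buy in the current-scaling path above, sketched in plain PyTorch (FP8 is only simulated here; the E4M3 max of 448 and the recipe details are assumptions):

```python
import torch

x = torch.randn(1024, 1024)
hp = torch.nn.functional.gelu(x)                   # high-precision activation output
amax = hp.abs().max()                              # in TE this amax is fused with the activation
scale = 448.0 / torch.clamp(amax, min=1e-12)       # FP8 E4M3 max ~448 (assumed recipe)
fp8_sim = torch.clamp(hp * scale, -448.0, 448.0)   # quantize with the precomputed amax/scale
```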
@@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
auto max_tokens = shape[0];
auto fcd_size = 1;
for (int i = 1; i <= shape.size(); i++) {
for (size_t i = 1; i <= shape.size(); i++) {
fcd_size *= shape[i];
}
@@ -110,8 +110,20 @@ std::vector<py::object> fused_attn_fwd(
auto o_shape = std::vector<size_t>{q_shape.begin(), q_shape.end()};
o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1];
py::object o_python, s_python;
std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te);
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// Initialize FP8 tensor with scale-inverse
auto *O_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(O_quantizer.get());
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt,
std::nullopt, std::nullopt);
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te);
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
}
auto o_shape_int64 = std::vector<int64_t>{o_shape.begin(), o_shape.end()};
// construct NVTE tensors
@@ -295,8 +307,20 @@ std::vector<py::object> fused_attn_bwd(
py::object s_python, dp_python;
std::unique_ptr<Quantizer> S_quantizer = convert_quantizer(s_quantizer);
std::unique_ptr<Quantizer> dP_quantizer = convert_quantizer(dp_quantizer);
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
auto *S_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(S_quantizer.get());
auto *dP_quantizer_fp8 = dynamic_cast<Float8Quantizer *>(dP_quantizer.get());
NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt,
std::nullopt, std::nullopt);
} else {
std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32);
std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32);
}
std::vector<size_t> q_shape = convertShape(te_Q.shape());
std::vector<size_t> k_shape = convertShape(te_K.shape());
@@ -385,9 +409,22 @@ std::vector<py::object> fused_attn_bwd(
default:
NVTE_ERROR("QKV layout not supported!");
}
std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
auto *fp8_quantizer = dynamic_cast<Float8Quantizer *>(dQKV_quantizer.get());
NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8");
std::tie(te_dQ, py_dQ) =
fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt);
std::tie(te_dK, py_dK) =
fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt);
std::tie(te_dV, py_dV) =
fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt);
} else {
auto *none_quantizer = dynamic_cast<NoneQuantizer *>(dQKV_quantizer.get());
NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8");
std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ);
std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK);
std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV);
}
// construct NVTE tensors
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
...
@@ -4,80 +4,223 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <utility>
#include <vector>
#include "common.h" #include "common.h"
#include "extensions.h"
#include "pybind.h" #include "pybind.h"
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
namespace transformer_engine::pytorch { namespace transformer_engine {
namespace pytorch {
std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_quantizer) { std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle quantizer) {
auto quantizer = convert_quantizer(py_quantizer); using namespace transformer_engine::pytorch::detail;
init_extension();
auto input_tensor = makeTransformerEngineTensor(input);
// Grad output tensor
auto grad_output_torch = grad_output.contiguous();
const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
const auto shape = getTensorShape(grad_output_torch);
auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
auto dbias = allocateTorchTensor(input.size(-1), input_tensor.dtype());
// Construct grad bias tensor
const int64_t bias_size = static_cast<int64_t>(shape.back());
auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype);
auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch);
std::vector<size_t> output_shape;
for (auto s : input.sizes()) {
output_shape.emplace_back(static_cast<size_t>(s));
// Unquantized impl only requires computing grad bias
if (quantizer.is_none()) {
if (product(shape) == 0) {
grad_bias_torch.zero_();
} else {
at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
}
return {py::cast(std::move(grad_bias_torch)), py::cast(std::move(grad_output_torch))};
}
auto [out_tensor, out] = quantizer->create_tensor(output_shape, input_tensor.dtype());
// Return immediately if tensors are empty
if (product(output_shape) == 0) {
return {py::cast(dbias.zero_()), out};
// Construct grad input tensor
auto quantizer_cpp = convert_quantizer(quantizer);
auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(shape, grad_output_dtype);
// Trivial impl if tensors are empty
if (product(shape) == 0) {
grad_bias_torch.zero_();
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
// Unfused impl if quantizer is not supported
const bool with_fused_dbias_quantize_kernel =
detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr());
if (!with_fused_dbias_quantize_kernel) {
at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte);
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
} }
auto dbias_tensor = makeTransformerEngineTensor(dbias);
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
// Query workspace size
TensorWrapper workspace_nvte;
at::Tensor workspace_torch;
auto stream = at::cuda::getCurrentCUDAStream();
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(),
workspace_nvte.data(), stream);
});
void* workspace_data_ptr = nullptr;
if (workspace.shape().ndim > 0) {
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace_data_ptr = workspace_data.data_ptr();
}
// Allocate workspace
if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) {
workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype());
workspace_nvte = makeTransformerEngineTensor(workspace_torch.data_ptr(), workspace_nvte.shape(),
workspace_nvte.dtype());
workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype());
// Launch kernel
if (detail::IsFloat8CurrentScalingQuantizers(py_quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
});
// check if we need to do amax reudction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
// construct torch tesnor from NVTEBasicTensor without reallocating memory
at::Tensor& amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(out_tensor.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_tensor.set_amax(nullptr, DType::kFloat32, out_tensor.defaultShape);
} }
// Launch fused kernel
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(),
workspace.data(), at::cuda::getCurrentCUDAStream()); workspace_nvte.data(), stream);
}); });
return {py::cast(dbias), out}; return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
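For context, the unfused and unquantized dbias paths above reduce to a column-wise sum of the gradient flattened to two dimensions, mirroring at::sum_out(grad_bias, grad_output.reshape({-1, bias_size}), {0}). A minimal standalone sketch of that reduction (plain C++, not TE code; reference_dbias is a hypothetical helper name):

// --- illustrative sketch, not part of the diff ---
#include <cstddef>
#include <vector>

// Column-wise sum over a row-major [rows, bias_size] buffer.
std::vector<float> reference_dbias(const std::vector<float>& grad_output, std::size_t bias_size) {
  std::vector<float> grad_bias(bias_size, 0.0f);
  const std::size_t rows = grad_output.size() / bias_size;
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < bias_size; ++c) {
      grad_bias[c] += grad_output[r * bias_size + c];  // accumulate along the flattened leading dims
    }
  }
  return grad_bias;
}
// --- end of sketch ---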
namespace {
std::vector<py::object> dact_dbias(
void (*dact_dbias_func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor,
cudaStream_t),
void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t),
at::Tensor grad_output_torch, at::Tensor act_input_torch, py::handle quantizer_py) {
using namespace transformer_engine::pytorch::detail;
init_extension();
// Grad output and activation input tensors
grad_output_torch = grad_output_torch.contiguous();
const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
const auto output_shape = getTensorShape(grad_output_torch);
auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
act_input_torch = act_input_torch.contiguous();
const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch);
const auto input_shape = getTensorShape(act_input_torch);
// Construct tensors
auto quantizer_cpp = convert_quantizer(quantizer_py);
auto [grad_input_nvte, grad_input_py] =
quantizer_cpp->create_tensor(input_shape, grad_output_dtype);
const int64_t bias_size = static_cast<int64_t>(input_shape.back());
auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype);
auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch);
// Return immediately if tensors are empty
if (product(output_shape) == 0) {
grad_bias_torch.zero_();
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
// Choose implementation
enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX };
Impl impl = Impl::UNFUSED;
if (detail::IsFloat8Quantizers(quantizer_py.ptr()) ||
detail::IsMXFP8Quantizers(quantizer_py.ptr())) {
impl = Impl::FUSED_DACT_DBIAS_QUANTIZE;
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
impl = Impl::FUSED_DACT_AMAX;
}
// Perform compute
auto stream = at::cuda::getCurrentCUDAStream();
switch (impl) {
case Impl::UNFUSED:
// Unfused dact, dbias, quantize
{
auto [temp_nvte, temp_py] =
NoneQuantizer(py::none()).create_tensor(input_shape, grad_output_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
});
const auto temp_torch = temp_py.cast<at::Tensor>();
at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
break;
}
case Impl::FUSED_DACT_DBIAS_QUANTIZE:
// Fused dact-dbias-quantize kernel
{
// Query workspace size
TensorWrapper workspace_nvte;
NVTE_SCOPED_GIL_RELEASE({
dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(),
grad_bias_nvte.data(), workspace_nvte.data(), stream);
});
// Allocate workspace
at::Tensor workspace_torch;
if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) {
workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype());
workspace_nvte = makeTransformerEngineTensor(
workspace_torch.data_ptr(), workspace_nvte.shape(), workspace_nvte.dtype());
}
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(),
grad_bias_nvte.data(), workspace_nvte.data(), stream);
});
break;
}
case Impl::FUSED_DACT_AMAX:
// Fused dact-amax kernel, unfused dbias and quantize
{
auto *quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_CHECK(quantizer_cpp_cs != nullptr,
"Invalid quantizer for fused dact-amax kernel impl");
auto [temp_nvte, temp_py] =
quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
});
const auto temp_torch = temp_py.cast<at::Tensor>();
at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte);
break;
}
default:
NVTE_ERROR("Invalid implementation");
}
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
} // namespace
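For context, the UNFUSED branch of dact_dbias above composes the activation backward with the same bias reduction: the activation gradient is applied to grad_output first, then the bias gradient is the column-wise sum of that intermediate. A rough standalone sketch for the ReLU case (plain C++, not TE code; reference_drelu_dbias is a hypothetical helper, and quantization of the intermediate is omitted):

// --- illustrative sketch, not part of the diff ---
#include <cstddef>
#include <vector>

// grad_input = grad_output * relu'(act_input); grad_bias = column-wise sum of grad_input.
void reference_drelu_dbias(const std::vector<float>& grad_output,
                           const std::vector<float>& act_input, std::size_t bias_size,
                           std::vector<float>& grad_input, std::vector<float>& grad_bias) {
  const std::size_t rows = grad_output.size() / bias_size;
  grad_input.assign(grad_output.size(), 0.0f);
  grad_bias.assign(bias_size, 0.0f);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < bias_size; ++c) {
      const std::size_t idx = r * bias_size + c;
      const float d = act_input[idx] > 0.0f ? grad_output[idx] : 0.0f;  // relu'(x) * dy
      grad_input[idx] = d;
      grad_bias[c] += d;
    }
  }
}
// --- end of sketch ---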
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_dgelu, nvte_dgelu, grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_dsilu, nvte_dsilu, grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_drelu, nvte_drelu, grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_dqgelu, nvte_dqgelu, grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_dsrelu, nvte_dsrelu, grad_output, act_input, quantizer);
}
-} // namespace transformer_engine::pytorch
+} // namespace pytorch
+} // namespace transformer_engine
...@@ -28,60 +28,6 @@ std::vector<size_t> get_tensor_shape(const TensorWrapper &tensor) {
  return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py,
std::unique_ptr<Quantizer> &quantizer_cpp, TensorWrapper &output,
TensorWrapper &noop_flag) {
// Check tensor dims
NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output),
"Input tensor (shape=", get_tensor_shape(input),
") and output tensor (shape=", get_tensor_shape(output), ") do not match");
if (input.numel() == 0) {
return;
}
// Recipe-specific configuration
QuantizationConfigWrapper quant_config;
quant_config.set_noop_tensor(noop_flag.data());
if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); });
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
// construct torch tensor from NVTEBasicTensor without reallocating memory
at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
// This config is used for current-scaling scale factor computation.
// The scale computation cannot be fused with the quantize kernel,
// so nvte_quantize_v2 with current scaling does not use the quant config again.
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel
output.set_amax(nullptr, DType::kFloat32, output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(quantizer_cpp.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
}
// Perform quantization
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
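For context, the current-scaling path in the removed quantize_impl computes an amax, optionally max-allreduces it across the amax reduction group, and then derives a scaling factor from it before quantizing. A hedged sketch of one plausible scale-from-amax rule (plain C++, not the TE kernel; scale_from_amax, fp8_max, and amax_epsilon are illustrative names, and the exact formula used by nvte_compute_scale_from_amax may differ):

// --- illustrative sketch, not part of the diff ---
#include <algorithm>
#include <cmath>

float scale_from_amax(float amax, float fp8_max, bool force_pow_2, float amax_epsilon) {
  const float clamped_amax = std::max(amax, amax_epsilon);  // avoid dividing by a near-zero amax
  float scale = fp8_max / clamped_amax;
  if (force_pow_2) {
    scale = std::exp2(std::floor(std::log2(scale)));  // snap down to a power of two
  }
  return scale;
}
// --- end of sketch ---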
} // namespace
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
...@@ -101,18 +47,17 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
    const auto fake_dtype = input_cpp.dtype();
    std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype);
  } else {
-    output_py = output;
-    output_cpp = makeTransformerEngineTensor(output_py, quantizer);
+    std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output);
  }
  // Initialize no-op flag
-  TensorWrapper noop_flag_cpp;
+  std::optional<TensorWrapper> noop_flag_cpp;
  if (noop_flag.has_value()) {
    noop_flag_cpp = makeTransformerEngineTensor(*noop_flag);
  }
  // Perform quantization
-  quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp);
+  quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp);
  return output_py;
}
...@@ -182,10 +127,8 @@ void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
    });
  } else {
    // Quantize kernels individually
-    TensorWrapper dummy_noop_flag;
    for (size_t i = 0; i < num_tensors; ++i) {
-      quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i],
-                    dummy_noop_flag);
+      quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]);
    }
  }
}
...@@ -455,11 +398,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
  }
  // Allocate full buffer
-  // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
  auto buffer = std::make_shared<at::Tensor>(
-      at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-  // auto buffer = std::make_shared<at::Tensor>(
-  //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+      at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
  // Construct tensor views
  for (size_t i = 0; i < num_tensors; ++i) {
...@@ -498,11 +438,8 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
  }
  // Allocate full buffer
-  // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
  auto buffer = std::make_shared<at::Tensor>(
-      at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
-  // auto buffer = std::make_shared<at::Tensor>(
-  //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
+      at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
  // Construct tensor views
  for (size_t i = 0; i < num_tensors; ++i) {
...@@ -650,66 +587,5 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
  return output_py_list;
}
template <void (*func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor,
cudaStream_t)>
std::vector<py::object> dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto grad_tensor = makeTransformerEngineTensor(grad_output);
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype());
auto act_input_tensor = makeTransformerEngineTensor(act_input);
const auto &shape = convertShape(grad_tensor.shape());
auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype());
auto dbias_tensor = makeTransformerEngineTensor(grad_bias);
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
return {py::cast(grad_bias), dact};
}
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dgelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dsilu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_drelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dqgelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dsrelu>(grad_output, act_input, quantizer);
}
} // namespace pytorch
} // namespace transformer_engine
...@@ -216,8 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i
  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}
-at::Stream CommOverlap::get_communication_stream() {
-  return at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device());
+std::pair<at::Stream, at::Stream> CommOverlap::get_communication_stream() {
+  // Return the same stream for both send and recv
+  return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()),
+          at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())};
}
/***************************************************************************************************
...@@ -305,6 +307,14 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}
-at::Stream CommOverlapP2P::get_communication_stream() {
-  return at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device());
+std::pair<at::Stream, at::Stream> CommOverlapP2P::get_communication_stream() {
+  return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()),
+          at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())};
+}
+
+void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm(
+    CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) {
+  auto main_stream = at::cuda::getCurrentCUDAStream();
+  allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream),
+                                                  at::cuda::CUDAStream(recv_stream), main_stream);
}
...@@ -94,7 +94,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                             at::Tensor workspace, size_t workspaceSize, bool accumulate,
                             bool use_split_accumulator, CommOverlapCore* comm_overlap,
                             std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
-                            bool bulk_overlap) {
+                            bool bulk_overlap, float alpha, std::optional<float> beta) {
  // Input tensors
  NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
  NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
...@@ -112,6 +112,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
  NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
  NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");
// Check scaling factors
if (accumulate) {
if (!beta) {
beta = 1.0f;
}
} else {
if (!beta) {
beta = 0.0f;
}
NVTE_CHECK(beta == 0.0, "Trying to use non-zero beta while not accumulating ",
"into D tensor. Beta has nothing to be applied to.");
}
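The alpha/beta handling above follows the usual scaled-GEMM convention D = alpha * op(A) op(B) + beta * D, so a non-zero beta only makes sense when accumulating into an existing D. A small illustrative sketch of that semantic on row-major buffers (plain C++, not the cuBLAS/TE call; scaled_gemm_reference is a hypothetical name):

// --- illustrative sketch, not part of the diff ---
#include <cstddef>
#include <vector>

void scaled_gemm_reference(const std::vector<float>& A, const std::vector<float>& B,
                           std::vector<float>& D, std::size_t M, std::size_t K, std::size_t N,
                           float alpha, float beta) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      // beta == 0 overwrites D; beta == 1 accumulates into the existing D
      D[m * N + n] = alpha * acc + beta * D[m * N + n];
    }
  }
}
// --- end of sketch ---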
  // Output tensor
  TensorWrapper D_tensor;
  if (D.is_none()) {
...@@ -240,9 +253,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
    } else {
      // Launch GEMM
      NVTE_SCOPED_GIL_RELEASE({
-        nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
-                         te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
-                         accumulate, use_split_accumulator, num_math_sms, main_stream);
+        nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(),
+                                bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
+                                te_workspace.data(), alpha, *beta, use_split_accumulator,
+                                num_math_sms, main_stream);
      });
    }
  } else {
...@@ -328,10 +342,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
  std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
      te_pre_gelu_out_vector, te_workspace_vector;
-  std::vector<TensorWrapper> wrappers;
+  std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, wrappers;
  std::vector<at::Tensor> D_vectors;
-  // Keep the swizzled scaling factor tensors alive during the GEMMs.
-  std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
  auto none = py::none();
...@@ -398,10 +410,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
      continue;
    }
-    // Optionally swizzle the scaling factors
-    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
-    swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
    auto te_D = makeTransformerEngineTensor(out_tensor);
    auto te_bias = makeTransformerEngineTensor(bias[i]);
    auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
...@@ -421,18 +429,25 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    te_bias_vector.emplace_back(te_bias.data());
    te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
-    wrappers.emplace_back(std::move(te_A));
-    wrappers.emplace_back(std::move(te_B));
+    te_A_wrappers.emplace_back(std::move(te_A));
+    te_B_wrappers.emplace_back(std::move(te_B));
    wrappers.emplace_back(std::move(te_D));
    wrappers.emplace_back(std::move(te_bias));
    wrappers.emplace_back(std::move(te_pre_gelu_out));
  }
+  // Optionally swizzle the scaling factors
+  // Keep the swizzled scaling factor tensors alive during the GEMMs.
+  auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa);
+  auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb);
  for (size_t i = 0; i < workspace.size(); i++) {
    auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
                                           std::vector<size_t>{workspaceSize}, DType::kByte);
    te_workspace_vector.emplace_back(wsp.data());
    wrappers.emplace_back(std::move(wsp));
  }
  // For now, we only have multi-stream cublas backend.
  const char *NVTE_USE_HIPBLASLT_GROUPEDGEMM = std::getenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM");
  if (NVTE_USE_HIPBLASLT_GROUPEDGEMM != nullptr && NVTE_USE_HIPBLASLT_GROUPEDGEMM[0] == '1') {
...
...@@ -16,11 +16,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();
  nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
                              num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction,
-                             weight_decay, device_id, at::cuda::getCurrentCUDAStream());
+                             weight_decay, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
...@@ -31,12 +30,10 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();
  nvte_multi_tensor_adam_param_remainder_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1,
-      beta2, epsilon, step, mode, bias_correction, weight_decay, device_id,
-      at::cuda::getCurrentCUDAStream());
+      beta2, epsilon, step, mode, bias_correction, weight_decay, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
...@@ -47,12 +44,11 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();
  nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
                                  num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode,
                                  bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype),
-                                 device_id, at::cuda::getCurrentCUDAStream());
+                                 at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
...@@ -67,12 +63,11 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
  auto lr_cu = makeTransformerEngineTensor(lr);
  auto step_cu = makeTransformerEngineTensor(step);
  auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
-  int device_id = tensor_lists[0][0].device().index();
  nvte_multi_tensor_adam_capturable_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
      lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
-      inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
+      inv_scale_cu.data(), at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
...@@ -87,12 +82,11 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
  auto lr_cu = makeTransformerEngineTensor(lr);
  auto step_cu = makeTransformerEngineTensor(step);
  auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
-  int device_id = tensor_lists[0][0].device().index();
  nvte_multi_tensor_adam_capturable_master_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
      lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
-      inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
+      inv_scale_cu.data(), at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
...@@ -14,11 +14,10 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();
  nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
-      force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
+      force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
...@@ -43,12 +43,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
  auto ret_cu = makeTransformerEngineTensor(ret);
  auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
-  int device_id = tensor_lists[0][0].device().index();
  nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
                                num_tensors, output_cu.data(), output_per_tensor_cu.data(),
                                ret_cu.data(), ret_per_tensor_cu.data(), per_tensor,
-                               max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream());
+                               max_chunks_per_tensor, at::cuda::getCurrentCUDAStream());
  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
...@@ -91,13 +90,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
  auto ret_cu = makeTransformerEngineTensor(ret);
  auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
  auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
-  int device_id = tensor_lists[0][0].device().index();
  nvte_multi_tensor_unscale_l2norm_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
      output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(),
-      inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id,
-      at::cuda::getCurrentCUDAStream());
+      inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, at::cuda::getCurrentCUDAStream());
  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
...
...@@ -13,10 +13,9 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
-  int device_id = tensor_lists[0][0].device().index();
  nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
-                              num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream());
+                              num_tensors, scale, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
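For context, the multi-tensor L2-norm helpers above fuse, over chunks of many tensors, what the following standalone sketch computes naively: a global L2 norm plus optional per-tensor norms (plain C++, not the fused CUDA kernel; multi_tensor_l2norm_reference is a hypothetical name):

// --- illustrative sketch, not part of the diff ---
#include <cmath>
#include <vector>

float multi_tensor_l2norm_reference(const std::vector<std::vector<float>>& tensors,
                                    std::vector<float>* per_tensor_norms) {
  float total_sq = 0.0f;
  if (per_tensor_norms) per_tensor_norms->clear();
  for (const auto& t : tensors) {
    float sq = 0.0f;
    for (float v : t) sq += v * v;  // sum of squares for this tensor
    if (per_tensor_norms) per_tensor_norms->push_back(std::sqrt(sq));
    total_sq += sq;
  }
  return std::sqrt(total_sq);  // norm over the concatenation of all tensors
}
// --- end of sketch ---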