Commit 87e3e56e authored by yuguo's avatar yuguo


Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
......@@ -5,3 +5,58 @@
"""
Utils for the debug features.
"""
import torch
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup):
"""
Returns the statistics reduction parameters for the tensor.
"""
skip_reduction = False
reduction_group = debug_api.get_tensor_reduction_group()
reduce_within_microbatch = tensor_name != "weight"
if tensor_name == "weight":
if TEDebugState.weight_tensor_tp_group_reduce:
reduction_group = tp_group
else:
skip_reduction = True
return skip_reduction, reduction_group, reduce_within_microbatch
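# Illustration of the two reduction paths above (hypothetical calls; assumes
# TEDebugState.weight_tensor_tp_group_reduce is True):
#   get_reduction_params("weight", tp_group)     -> (False, tp_group, False)
#   get_reduction_params("activation", tp_group) -> (False, debug_api.get_tensor_reduction_group(), True)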
def next_enabled_iter(start_step, end_step, start_end_list, freq, iteration):
"""
Determines whether the feature should be enabled at the current iteration,
and computes the next iteration at which the feature will be enabled.
Returns
-------
run_current : bool
True if the feature should be enabled at the current iteration.
next_iter : int
The next iteration index at which the feature will be enabled.
"""
run_current = False
if start_end_list:
intervals = sorted(start_end_list)
else:
start_step = 0 if start_step is None else start_step
end = float("inf") if end_step is None else end_step
intervals = [(start_step, end)]
for s, e in intervals:
if iteration % freq == 0 and s <= iteration <= e:
run_current = True
first = max(iteration + 1, s)
offset = first % freq
candidate = first if offset == 0 else first + (freq - offset)
if candidate <= e:
return run_current, candidate
return run_current, None # No next iteration found
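# Sanity-check examples for the schedule logic above (hypothetical values):
#   next_enabled_iter(None, None, None, 5, 10)     -> (True, 15)   # 10 % 5 == 0; next multiple is 15
#   next_enabled_iter(None, None, None, 5, 12)     -> (False, 15)  # 12 not enabled; next multiple is 15
#   next_enabled_iter(None, None, [(0, 9)], 5, 12) -> (False, None)  # no enabled iteration remains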
......@@ -10,6 +10,7 @@ When log() is called, they gather stats from all nodes, compute combined final s
from collections import defaultdict
from typing import Dict
import torch
from nvdlfw_inspect.utils import gather_along_first_dim
......@@ -20,6 +21,7 @@ from transformer_engine.debug.features.utils.stats_computation import (
DEPENDENCIES,
stats_to_num,
)
from transformer_engine.debug.pytorch.debug_state import TEDebugState
class _Buffer:
......@@ -65,14 +67,17 @@ class _Buffer:
gathered_buffer, _ = gather_along_first_dim(
self._buffer.unsqueeze(0), process_group=self.reduction_group
)
return gathered_buffer[mask.to(bool)]
return gathered_buffer[mask.to(torch.bool)]
def feed(self, tensor, iteration):
def feed(self, tensor, iteration, aux_dict=None):
"""
feed() is used to add a tensor for computing the statistics.
Because of microbatching, feed() can be called multiple
times for one log().
The aux_dict is used to share common computation between different stats.
For example, for LogFp8TensorStats it can contain quantized tensors in different precisions.
The main reason for this design is the need to combine results for already-processed
tensors with the result of the new tensor.
"""
......@@ -95,7 +100,7 @@ class _Buffer:
# save stats for tensor to tmp buffer
for stat_name in self.stats_to_compute:
fn, _ = STATS[stat_name]
self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor)
self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict)
# [num_buffers, num_stats]
buffers = torch.cat((self._buffer.unsqueeze(0), self._tmp_buffer.unsqueeze(0)), dim=0)
......@@ -106,7 +111,7 @@ class _Buffer:
self._new_buffer[stats_to_num[stat_name]] = combinator(buffers)
else:
fn = STATS[stat_name][0]
self._new_buffer[stats_to_num[stat_name]] = fn(tensor)
self._new_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict)
self._buffer.copy_(self._new_buffer)
......@@ -125,7 +130,6 @@ class _Buffer:
for stat_name in self.stats_to_log:
combiner = STATS[stat_name][1]
stat_value = combiner(gathered_helper_stats)
MetricLogger.log_scalar(
f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration
)
......@@ -146,10 +150,41 @@ class StatsBuffers:
self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list)
# Logging stats involves synchronization between nodes
# and non-trivial CPU overhead.
# It should only be done if absolutely necessary.
# These variables help determine whether the reduction can be skipped.
self.at_least_one_layer_fed = False
self.layers_to_next_iter: Dict[str, int] = {}
def _if_run_reduction(self) -> bool:
"""
Returns True if the reduction should be run.
This is the case if at least one layer logged stats on this node.
If not, some layer that did not run on this node may still log on another node.
If we know that no such layer logs at this iteration on any node,
we can skip the reduction; otherwise we must join it.
To ensure correctness, we assume that every layer is invoked in the first forward pass.
If this is not the case, a hang might happen.
"""
if self.at_least_one_layer_fed:
return True
iteration = TEDebugState.get_iteration()
for _, next_iter in self.layers_to_next_iter.items():
# Note that a layer may not run for many iterations;
# in this case we synchronize every step until we get any information from it.
if iteration >= next_iter:
return True
return False
def reset(self):
"""Resets all buffers."""
self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list)
self.at_least_one_layer_fed = False
self.layers_to_next_iter: Dict[str, int] = {}
def try_add_buffer(
self, layer_name, tensor_name, stats, options, reduction_group, reduce_within_microbatch
......@@ -161,14 +196,25 @@ class StatsBuffers:
self.buffers[(layer_name, tensor_name, options)] = buffer
self.reduction_group_to_buffer[reduction_group].append(buffer)
def feed(self, layer_name, tensor_name, options, tensor, iteration, skip_reduction):
"""Feeds the tensor into the respective buffer."""
def feed(
self, layer_name, tensor_name, options, tensor, iteration, skip_reduction, aux_dict=None
):
"""
Feeds the tensor into the respective buffer.
The aux_dict is used to share common computation between different stats.
For example, for LogFp8TensorStats it can contain quantized tensors in different precisions.
"""
self.at_least_one_layer_fed = True
buffer = self.buffers[(layer_name, tensor_name, options)]
buffer.feed(tensor, iteration)
buffer.feed(tensor, iteration, aux_dict)
buffer.skip_reduction = skip_reduction
def log_stats(self):
"""Logs the stats from all the buffers."""
if not self._if_run_reduction():
return {}
output = {}
for reduction_group, buffers in self.reduction_group_to_buffer.items():
changed_buffers = [
......@@ -181,7 +227,7 @@ class StatsBuffers:
for _, buffer in changed_buffers:
stats = buffer.log()
output.update(stats)
self.at_least_one_layer_fed = False
return output
......
......@@ -8,8 +8,9 @@ Mathematical functions used to tensor statistics computation.
import math
import torch
MAX_FP8_VALUE_INT8 = 126
import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Format
@torch.compile
......@@ -49,6 +50,29 @@ def compute_std(variances, numels, sums):
return torch.sqrt(compute_variance(variances, numels, sums))
def compute_fp8_delayed_scaling_overflows_num(tensor, quantized_tensor):
"""Computes the overflows of the tensor."""
scale_inv = quantized_tensor._scale_inv
dtype = quantized_tensor._fp8_dtype
# Map each supported FP8 dtype to its corresponding max forward value.
dtype_to_max = {
tex.DType.kFloat8E4M3: Format.E4M3.value.max_fwd,
tex.DType.kFloat8E5M2: Format.E5M2.value.max_fwd,
}
if dtype not in dtype_to_max:
raise ValueError(
f"Unsupported FP8 dtype {dtype} passed to compute_fp8_delayed_scaling_overflows_num()."
)
fp8_max = dtype_to_max[dtype]
fp8_min = -fp8_max
overflows = (tensor > fp8_max * scale_inv) | (tensor < fp8_min * scale_inv)
return overflows.sum()
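# Worked instance of the overflow check above, as a standalone sketch
# (E4M3 max_fwd is 448; the tensor and scale values are hypothetical):
import torch

fp8_max, scale_inv = 448.0, 0.5  # representable range is +/- fp8_max * scale_inv = +/- 224
example = torch.tensor([100.0, 300.0, -500.0])
example_overflows = (example > fp8_max * scale_inv) | (example < -fp8_max * scale_inv)
assert example_overflows.sum().item() == 2  # 300 and -500 fall outside +/- 224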
# buffers is tensor of shape [nr_buffers, nr_stats]
def _get(buffers, stat_name):
stat_nr = stats_to_num[stat_name]
......@@ -68,10 +92,12 @@ stats_to_num = {
"cur_amax": 9,
"dynamic_range_top": 10,
"dynamic_range_bottom": 11,
"underflows_num": 12,
"std": 13,
"dynamic_range": 14,
"underflows%": 15,
"std": 12,
"dynamic_range": 13,
"fp8_delayed_scaling_overflows_num": 14,
"fp8_delayed_scaling_overflows%": 15,
"overflows_num": 16,
"overflows%": 17,
}
DEPENDENCIES = {
......@@ -87,62 +113,207 @@ DEPENDENCIES = {
"cur_amax": {"cur_amax"},
"dynamic_range_top": {"dynamic_range_top"},
"dynamic_range_bottom": {"dynamic_range_bottom"},
"underflows_num": {"underflows_num"},
"std": {"variance", "numel", "sum"},
"dynamic_range": {"dynamic_range_top", "dynamic_range_bottom"},
"underflows%": {"underflows_num", "numel"},
"fp8_delayed_scaling_overflows_num": {"fp8_delayed_scaling_overflows_num"},
"fp8_delayed_scaling_overflows%": {"fp8_delayed_scaling_overflows_num", "numel"},
"overflows_num": {"overflows_num"},
"overflows%": {"overflows_num", "numel"},
}
STATS = {
"min": (torch.min, lambda buffers: min(_get(buffers, "min"))),
"max": (torch.max, lambda buffers: max(_get(buffers, "max"))),
"sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))),
"mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))),
"min": (lambda x, aux_dict: torch.min(x), lambda buffers: min(_get(buffers, "min"))),
"max": (lambda x, aux_dict: torch.max(x), lambda buffers: max(_get(buffers, "max"))),
"sum": (lambda x, aux_dict: torch.sum(x), lambda buffers: sum(_get(buffers, "sum"))),
"mean": (
lambda x, aux_dict: torch.mean(x),
lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel")),
),
"numel": (
lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(),
lambda x, aux_dict: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(),
lambda buffers: sum(_get(buffers, "numel")),
),
"l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))),
"l1_norm": (
lambda x, aux_dict: torch.norm(x, p=1),
lambda buffers: sum(_get(buffers, "l1_norm")),
),
"l2_norm_square": (
lambda x: torch.sum(x**2),
lambda x, aux_dict: torch.sum(x**2),
lambda buffers: sum(_get(buffers, "l2_norm_square")),
),
"l2_norm": (
lambda x: torch.norm(x, p=2),
lambda x, aux_dict: torch.norm(x, p=2),
lambda buffers: math.sqrt(sum(_get(buffers, "l2_norm_square"))),
),
"variance": (
torch.var,
lambda x, aux_dict: torch.var(x),
lambda buffers: compute_variance(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
),
),
"cur_amax": (lambda x: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))),
"cur_amax": (lambda x, aux_dict: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))),
"dynamic_range_top": (
_compute_dynamic_range_top,
lambda x, aux_dict: _compute_dynamic_range_top(x),
lambda buffers: max(_get(buffers, "dynamic_range_top")),
),
"dynamic_range_bottom": (
_compute_dynamic_range_bottom,
lambda x, aux_dict: _compute_dynamic_range_bottom(x),
lambda buffers: min(_get(buffers, "dynamic_range_bottom")),
),
"underflows_num": (
lambda x: (x.get_data_tensors()[0] == 0).sum(),
lambda buffers: sum(_get(buffers, "underflows_num")),
),
"std": (
torch.std,
lambda x, aux_dict: torch.std(x),
lambda buffers: compute_std(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
),
),
"dynamic_range": (
lambda x: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x),
lambda x, aux_dict: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x),
lambda buffers: max(_get(buffers, "dynamic_range_top"))
- min(_get(buffers, "dynamic_range_bottom")),
),
"underflows%": (
lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100,
lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")),
"fp8_delayed_scaling_overflows_num": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(
x, aux_dict["fp8_delayed_scaling"]
),
lambda buffers: sum(_get(buffers, "fp8_delayed_scaling_overflows_num")),
),
"fp8_delayed_scaling_overflows%": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(
x, aux_dict["fp8_delayed_scaling"]
)
/ x.numel()
* 100,
lambda buffers: 100
* sum(_get(buffers, "fp8_delayed_scaling_overflows_num"))
/ sum(_get(buffers, "numel")),
),
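# Note: the empty-string key used by the two stats below refers to the default
# recipe entry in aux_dict (see the "" recipe in the registration loop at the
# bottom of this file).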
"overflows_num": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""]),
lambda buffers: sum(_get(buffers, "overflows_num")),
),
"overflows%": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""])
/ x.numel()
* 100,
lambda buffers: 100 * sum(_get(buffers, "overflows_num")) / sum(_get(buffers, "numel")),
),
}
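# Standalone sketch of the (per-tensor fn, cross-buffer combiner) contract above,
# shown for "mean" over two hypothetical microbatches:
import torch

mb1, mb2 = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])
total_sum = torch.sum(mb1) + torch.sum(mb2)  # combined per-microbatch "sum" stats
total_numel = mb1.numel() + mb2.numel()      # combined per-microbatch "numel" stats
assert (total_sum / total_numel).item() == 3.0  # == torch.cat((mb1, mb2)).mean()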
def add_underflows_stats(recipe_name: str, columnwise: bool = False):
"""Register *both* underflow stats (num and %) for the given recipe."""
columnwise_suffix = "_columnwise" if columnwise else ""
# Stat names
stat_num = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows_num{columnwise_suffix}"
stat_pct = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows%{columnwise_suffix}"
stats_to_num[stat_num] = len(stats_to_num)
stats_to_num[stat_pct] = len(stats_to_num)
STATS[stat_num] = (
lambda x, aux_dict: (
aux_dict[recipe_name].get_data_tensors(
rowwise_data=not columnwise, columnwise_data=columnwise
)
== 0
).sum()
- (x == 0).sum(),
lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
)
STATS[stat_pct] = (
lambda x, aux_dict: (
aux_dict[recipe_name].get_data_tensors(
rowwise_data=not columnwise, columnwise_data=columnwise
)
== 0
).sum()
/ aux_dict[recipe_name].numel()
* 100,
lambda buffers, _sn_num=stat_num: 100
* sum(_get(buffers, _sn_num))
/ sum(_get(buffers, "numel")),
)
DEPENDENCIES[stat_num] = {stat_num}
DEPENDENCIES[stat_pct] = {stat_num, "numel"}
def add_scale_inv_stats(recipe_name: str, columnwise: bool = False):
"""Register *both* scale-inv min and max stats for a given recipe.
This replaces the earlier separate helpers and avoids duplicated boilerplate.
"""
# Determine which attribute holds the scale-inverse tensor.
def get_scale_inv(quantized_tensor, columnwise):
if hasattr(quantized_tensor, "_scale_inv"):
return getattr(quantized_tensor, "_scale_inv")
if columnwise:
return getattr(quantized_tensor, "_columnwise_scale_inv")
return getattr(quantized_tensor, "_rowwise_scale_inv")
columnwise_suffix = "_columnwise" if columnwise else ""
# Prepare stat names.
stat_name_min = (
f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_min{columnwise_suffix}"
)
stat_name_max = (
f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_max{columnwise_suffix}"
)
# Assign indices in `stats_to_num` (order matters — keep insertion order deterministic).
stats_to_num[stat_name_min] = len(stats_to_num)
stats_to_num[stat_name_max] = len(stats_to_num)
# Capture loop-dependent values inside lambdas via default args to avoid late binding.
STATS[stat_name_min] = (
lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).min(),
lambda buffers, _sn=stat_name_min: min(_get(buffers, _sn)),
)
STATS[stat_name_max] = (
lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).max(),
lambda buffers, _sn=stat_name_max: max(_get(buffers, _sn)),
)
DEPENDENCIES[stat_name_min] = {stat_name_min}
DEPENDENCIES[stat_name_max] = {stat_name_max}
def add_mse_stats(recipe_name: str, columnwise: bool = False):
"""Register mse and total_square_error stats for the recipe."""
columnwise_suffix = "_columnwise" if columnwise else ""
stat_mse = f"{recipe_name}{'_' if recipe_name != '' else ''}mse{columnwise_suffix}"
stat_err = (
f"{recipe_name}{'_' if recipe_name != '' else ''}total_square_error{columnwise_suffix}"
)
stats_to_num[stat_mse] = len(stats_to_num)
stats_to_num[stat_err] = len(stats_to_num)
STATS[stat_mse] = (
lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="mean"),
lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err))
/ sum(_get(buffers, "numel")),
)
STATS[stat_err] = (
lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="sum"),
lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)),
)
DEPENDENCIES[stat_err] = {stat_err}
DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"}
for _columnwise in [True, False]:
for _recipe_name in [
"", # default recipe
"fp8_delayed_scaling",
"mxfp8",
"fp8_current_scaling",
"fp8_block_scaling",
]:
add_underflows_stats(_recipe_name, _columnwise)
add_scale_inv_stats(_recipe_name, _columnwise)
add_mse_stats(_recipe_name, _columnwise)
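# Examples of the stat names registered by the loop above (derived from the name
# templates; the default recipe "" drops the prefix and separator):
#   "fp8_delayed_scaling_underflows_num", "mxfp8_underflows%_columnwise",
#   "fp8_current_scaling_scale_inv_min", "fp8_block_scaling_mse_columnwise",
#   "underflows_num", "scale_inv_max_columnwise", "total_square_error"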
......@@ -22,6 +22,7 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
from transformer_engine.debug.pytorch.debug_state import TEDebugState
aten = torch.ops.aten
......@@ -53,14 +54,13 @@ class DebugQuantizer(Quantizer):
parent_quantizer: Optional[Quantizer],
tp_group: torch.distributed.ProcessGroup,
):
import nvdlfw_inspect.api as debug_api
super().__init__(rowwise=True, columnwise=True)
self.layer_name = layer_name
self.tensor_name = tensor_name
self.parent_quantizer = parent_quantizer
self.tp_group = tp_group # used in inspect_tensor calls
self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count
self.iteration = TEDebugState.get_iteration()
# .internal = True is slightly faster, but results
# in errors when caching the weights.
......@@ -70,6 +70,12 @@ class DebugQuantizer(Quantizer):
self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]
# Next iteration at which this quantizer will call any API.
# It is None at init and is computed after the *_enabled API calls.
# None at the beginning means that, if nothing is done,
# this quantizer will never call any API.
self.next_debug_iter = None
# The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
# rowwise_tensor_plan, and columnwise_tensor_plan are computed.
# These fields indicate the path where API calls will be inserted.
......@@ -102,15 +108,21 @@ class DebugQuantizer(Quantizer):
"""
import nvdlfw_inspect.api as debug_api
inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
inspect_tensor_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
)
)
modify_enabled = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
modify_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
)
plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION
return inspect_tensor_enabled, plan
......@@ -121,10 +133,13 @@ class DebugQuantizer(Quantizer):
"""
import nvdlfw_inspect.api as debug_api
inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
inspect_tensor_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
)
)
inspect_tensor_postquantize_enabled_rowwise = (
inspect_tensor_postquantize_enabled_rowwise = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
......@@ -132,7 +147,8 @@ class DebugQuantizer(Quantizer):
gemm=self.rowwise_gemm_name,
)
)
inspect_tensor_postquantize_enabled_columnwise = (
inspect_tensor_postquantize_enabled_columnwise = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
......@@ -140,7 +156,6 @@ class DebugQuantizer(Quantizer):
gemm=self.columnwise_gemm_name,
)
)
return (
inspect_tensor_enabled,
inspect_tensor_postquantize_enabled_rowwise,
......@@ -158,42 +173,54 @@ class DebugQuantizer(Quantizer):
rowwise_plan = None
columnwise_plan = None
modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
modify_rowwise = self.process_enabled_api_call(
debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
)
if modify_rowwise:
rowwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
iteration=self.iteration,
fp8_quantize = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
iteration=self.iteration,
)
)
if fp8_quantize:
rowwise_plan = STANDARD_FP8_QUANTIZE
if rowwise_plan is None:
rowwise_plan = HIGH_PRECISION
if self.columnwise_gemm_name is not None:
modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
modify_columnwise = self.process_enabled_api_call(
debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
)
if modify_columnwise:
columnwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
iteration=self.iteration,
fp8_quantize = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
iteration=self.iteration,
)
)
if fp8_quantize:
columnwise_plan = STANDARD_FP8_QUANTIZE
if columnwise_plan is None:
......@@ -229,8 +256,11 @@ class DebugQuantizer(Quantizer):
"layer_name": self.layer_name,
"tensor": tensor,
"tensor_name": self.tensor_name,
"iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count,
"iteration": TEDebugState.get_iteration(),
"tp_group": self.tp_group,
"columnwise_quantized_tensor": columnwise_gemm_tensor,
"rowwise_quantized_tensor": rowwise_gemm_tensor,
"quantizer": self.parent_quantizer,
}
if tensor is not None and self.inspect_tensor_enabled:
debug_api.transformer_engine.inspect_tensor(**args)
......@@ -238,6 +268,10 @@ class DebugQuantizer(Quantizer):
if self.output_tensor:
return
del args["columnwise_quantized_tensor"]
del args["rowwise_quantized_tensor"]
del args["quantizer"]
if (
self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_rowwise
......@@ -245,6 +279,7 @@ class DebugQuantizer(Quantizer):
args["tensor"] = rowwise_gemm_tensor
args["rowwise"] = True
debug_api.transformer_engine.inspect_tensor_postquantize(**args)
if (
self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_columnwise
......@@ -270,22 +305,14 @@ class DebugQuantizer(Quantizer):
# 1. If there is fp8 quantization in at least one of the gemms,
# the quantization using the self.parent_quantizer is performed.
# rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise
rowwise_gemm_quantize = (
self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
columnwise_gemm_quantize = (
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
if columnwise_gemm_quantize and not rowwise_gemm_quantize:
rowwise_gemm_quantize = True # only columnwise quantization not implemented
self._update_parent_quantizer_usage()
# Columnwise-only quantization is not supported.
if self.parent_quantizer is not None:
if not self.parent_quantizer.rowwise_usage and self.parent_quantizer.columnwise_usage:
self.parent_quantizer.set_usage(rowwise=True)
rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
self.parent_quantizer.set_usage(
rowwise=True,
columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported
)
quantized_tensor = self.parent_quantizer(tensor)
# if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
# one tensor with columnwise=True and rowwise=True is computed
......@@ -341,7 +368,6 @@ class DebugQuantizer(Quantizer):
quantizer=self,
layer_name=self.layer_name,
tensor_name=self.tensor_name,
original_tensor=tensor,
)
def process_gemm_output(self, tensor: torch.Tensor):
......@@ -375,6 +401,26 @@ class DebugQuantizer(Quantizer):
return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device)
return torch.empty(shape, dtype=dtype, device=device)
def any_feature_enabled(self) -> bool:
"""Returns bool if there is at least one API call enabled."""
if self.output_tensor:
return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY
# pylint: disable=too-many-boolean-expressions
if (
self.inspect_tensor_enabled
or self.inspect_tensor_postquantize_enabled_rowwise
or self.inspect_tensor_postquantize_enabled_columnwise
or self.rowwise_tensor_plan == API_CALL_MODIFY
or self.columnwise_tensor_plan == API_CALL_MODIFY
):
return True
if self.parent_quantizer is not None:
if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
return True
if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
return True
return False
def calibrate(self, tensor: torch.Tensor):
"""Calibration override, should not be invoked."""
raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported")
......@@ -446,29 +492,70 @@ class DebugQuantizer(Quantizer):
self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor)
def get_next_debug_iter(self) -> Optional[int]:
"""
Returns the next iteration at which debug is enabled for this tensor.
If the returned value is None, debug will never be enabled for this tensor.
"""
return self.next_debug_iter
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Probably not needed for debug quantizer"""
return None
def process_enabled_api_call(
self, enabled_call_output: bool | Tuple[bool, Optional[int]]
) -> bool:
"""
Processes the output of an *_enabled API call.
Updates the self.next_debug_iter field accordingly.
Returns a bool indicating whether the API call is enabled.
"""
if isinstance(enabled_call_output, tuple):
assert len(enabled_call_output) == 2, "Expected a tuple of length 2"
enabled_bool, next_iter = enabled_call_output
else:
enabled_bool = enabled_call_output
next_iter = self.iteration + 1
if self.next_debug_iter is None:
self.next_debug_iter = next_iter
elif next_iter is not None:
# If next_iter is None, it means that the call will never be enabled.
self.next_debug_iter = min(self.next_debug_iter, next_iter)
return enabled_bool
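# The two accepted shapes, spelled out (hypothetical values):
#   process_enabled_api_call(True)          -> True;  next_debug_iter capped at iteration + 1
#   process_enabled_api_call((False, 42))   -> False; next_debug_iter capped at 42
#   process_enabled_api_call((False, None)) -> False; next_debug_iter unchanged (never enabled)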
def supports_only_rowwise_all_gather(self) -> bool:
if self.parent_quantizer is not None:
return self.parent_quantizer.supports_only_rowwise_all_gather()
return False
def _update_parent_quantizer_usage(self):
"""
Updates the usage of the parent quantizer.
"""
rowwise_gemm_quantize = (
self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
columnwise_gemm_quantize = (
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
self.parent_quantizer.set_usage(
rowwise=rowwise_gemm_quantize,
columnwise=columnwise_gemm_quantize,
)
def set_usage(self, rowwise: bool = None, columnwise: bool = None):
"""
Sets the usage of the quantizer.
"""
super().set_usage(rowwise=rowwise, columnwise=columnwise)
if not self.output_tensor:
self._update_parent_quantizer_usage()
class DebugQuantizedTensor(QuantizedTensorBase):
"""
......@@ -484,7 +571,6 @@ class DebugQuantizedTensor(QuantizedTensorBase):
quantizer,
layer_name=None,
tensor_name=None,
original_tensor=None,
):
self.rowwise_gemm_tensor = rowwise_gemm_tensor
......@@ -492,7 +578,6 @@ class DebugQuantizedTensor(QuantizedTensorBase):
self.quantizer = quantizer
self._layer_name = layer_name
self._tensor_name = tensor_name
self._original_tensor = original_tensor
def prepare_for_saving(self):
""" " Prepare for saving method override"""
......@@ -501,6 +586,7 @@ class DebugQuantizedTensor(QuantizedTensorBase):
if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor
else [self.rowwise_gemm_tensor]
)
tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save)
self.tensors_to_save = tensor_objects_list
# pylint: disable=unbalanced-tuple-unpacking
......@@ -519,6 +605,7 @@ class DebugQuantizedTensor(QuantizedTensorBase):
else:
self.rowwise_gemm_tensor = tensor_objects_list[0]
self.columnwise_gemm_tensor = self.rowwise_gemm_tensor
return saved_tensors
def quantize_(self, tensor, *, noop_flag=None):
......@@ -542,3 +629,27 @@ class DebugQuantizedTensor(QuantizedTensorBase):
def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None):
"""Update usage of the tensor."""
if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor:
# If the same object is used for both rowwise and columnwise gemms,
# there is no benefit in erasing the usage of one of them.
# There are also scenarios where keeping both usages is needed,
# for example when we want to recreate the columnwise tensor from the rowwise one.
if rowwise_usage is False:
self.rowwise_gemm_tensor = None
if columnwise_usage is False:
self.columnwise_gemm_tensor = None
if isinstance(self.rowwise_gemm_tensor, QuantizedTensor):
self.rowwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage)
if isinstance(self.columnwise_gemm_tensor, QuantizedTensor):
self.columnwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage)
if rowwise_usage and self.rowwise_gemm_tensor is None:
raise RuntimeError(
"Cannot recreate rowwise tensor from columnwise tensor in debug mode."
)
if columnwise_usage and self.columnwise_gemm_tensor is None:
raise RuntimeError(
"Cannot recreate columnwise tensor from rowwise tensor is debug mode."
)
......@@ -62,6 +62,13 @@ class TEDebugState:
"""Sets weight tensor reduction mode."""
cls.weight_tensor_tp_group_reduce = enabled
@classmethod
def get_iteration(cls):
"""Returns the current iteration."""
import nvdlfw_inspect.api as debug_api
return debug_api.DEBUG_MANAGER._trainer_iteration_count
def set_weight_tensor_tp_group_reduce(enabled):
"""Sets weight tensor reduction mode."""
......
......@@ -4,6 +4,25 @@
"""Utils functions for the debug module."""
from typing import Optional
def next_iter_when_debug_should_be_run(quantizers) -> Optional[int]:
"""
Returns the next iteration at which debug should be run.
If debug will never be run for this layer, returns None.
"""
out = None
for q in quantizers:
if q.get_next_debug_iter() is not None:
if out is None:
out = q.get_next_debug_iter()
else:
out = min(out, q.get_next_debug_iter())
return out
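# E.g. if the quantizers report next debug iterations of None, 10, and 7 (hypothetical),
# next_iter_when_debug_should_be_run() returns 7; if all report None, it returns None.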
def any_feature_enabled(quantizers):
"""Returns True if at least one API call is made from DebugQuantizer."""
......
......@@ -855,7 +855,7 @@ def fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -872,6 +872,7 @@ def _fused_attn(
context_parallel_strategy: CPStrategy,
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
context_checkpoint_name: str = "context",
):
output, _ = _fused_attn_fwd_rule(
qkv,
......@@ -889,6 +890,7 @@ def _fused_attn(
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
)
return output
......@@ -909,6 +911,7 @@ def _fused_attn_fwd_rule(
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
):
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
......@@ -927,9 +930,9 @@ def _fused_attn_fwd_rule(
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
output = checkpoint_name(output, "context")
softmax_aux = checkpoint_name(softmax_aux, "context")
rng_state = checkpoint_name(rng_state, "context")
output = checkpoint_name(output, context_checkpoint_name)
softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
rng_state = checkpoint_name(rng_state, context_checkpoint_name)
return output, (
qkv,
bias,
......@@ -952,9 +955,11 @@ def _fused_attn_bwd_rule(
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
ctx,
dz,
):
del context_checkpoint_name
(
qkv,
bias,
......@@ -1012,6 +1017,7 @@ def fused_attn(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
):
"""
Perform cuDNN fused attention.
......@@ -1044,6 +1050,7 @@ def fused_attn(
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The checkpoint name applied (via checkpoint_name) to the tensors saved by the custom VJP forward pass.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
......@@ -1116,6 +1123,7 @@ def fused_attn(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
)
return output
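# Standalone sketch of the name-based rematerialization mechanism that
# context_checkpoint_name feeds into (the function f and the name "my_ctx"
# are hypothetical; fused_attn applies checkpoint_name internally as shown above):
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint, checkpoint_name

def f(x):
    y = checkpoint_name(jnp.sin(x), "my_ctx")  # analogous to the fused-attention outputs
    return jnp.sum(y * y)

# Save only tensors tagged "my_ctx" for the backward pass; recompute the rest.
g = checkpoint(f, policy=jax.checkpoint_policies.save_only_these_names("my_ctx"))
print(jax.grad(g)(jnp.ones(4)))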
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Checkpoint policies for Transformer Engine in JAX.
This module provides JAX checkpoint policies that are compatible with Transformer Engine's custom primitives.
"""
import jax
from .cpp_extensions.gemm import GemmPrimitive, GroupedGemmPrimitive
__all__ = [
"te_gemms_saveable",
"dots_and_te_gemms_with_no_batch_dims",
"checkpoint_dots_and_te_gemms",
]
def te_gemms_saveable(prim, *_, **__) -> bool:
"""Checkpoint policy for Transformer Engine GEMMs."""
is_te_gemm = prim in {GemmPrimitive.outer_primitive, GroupedGemmPrimitive.outer_primitive}
# Workaround to include JAX's scaled_matmul until JAX checkpoint policies for dots are
# updated to include it.
is_jax_scaled_matmul = prim.name == "scaled_matmul_wrapper"
return is_te_gemm or is_jax_scaled_matmul
dots_and_te_gemms_with_no_batch_dims = jax.checkpoint_policies.save_from_both_policies(
jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
te_gemms_saveable,
)
checkpoint_dots_and_te_gemms = jax.checkpoint_policies.save_from_both_policies(
jax.checkpoint_policies.checkpoint_dots,
te_gemms_saveable,
)
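# Usage sketch (assumed import path): rematerialize a layer while keeping
# TE GEMM outputs resident instead of recomputing them.
#   import jax
#   from transformer_engine.jax.checkpoint_policies import checkpoint_dots_and_te_gemms
#   remat_layer = jax.checkpoint(layer_fn, policy=checkpoint_dots_and_te_gemms)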
......@@ -915,11 +915,11 @@ register_primitive(BaseDActLuDBiasQuantizePrimitive)
class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
"""Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
......
......@@ -4,6 +4,7 @@
"""JAX/TE base custom ops"""
import os
import re
import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
from packaging import version
......@@ -30,19 +31,77 @@ class BasePrimitive(metaclass=ABCMeta):
name = None
_is_enabled = True
# Default list of primitives to disable for all recipes
_default_disable_names = []
@classmethod
def enabled(cls):
"""
A custom call is marked as disabled if the `cls.__name__` does not fully match the
`NVTE_JAX_CUSTOM_CALLS_RE` pattern.
This uses the Python class name of the primitive definitions that inherit from BasePrimitive.
By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`.
Determines if a custom call is enabled based on a state variable and environment variables.
Checks `NVTE_JAX_CUSTOM_CALLS` (key/value format) first, then falls back to the deprecated `NVTE_JAX_CUSTOM_CALLS_RE` (regex pattern),
and finally to the internal state `_is_enabled` if neither is set.
Environment Variables:
1. `NVTE_JAX_CUSTOM_CALLS`: Preferred key/value format to enable/disable specific primitives or a single value 'true' or 'false' to enable/disable all primitives.
- Example 1 (global enable): 'true' enables all primitives.
- Example 2 (global disable): 'false' disables all primitives.
- Example 3 (specific settings): 'DBiasQuantizePrimitive=false,GemmPrimitive=true' disables DBiasQuantizePrimitive and enables GemmPrimitive, leaving others at their default state.
Note that the default state is set at class level based on _default_disable_names.
2. `NVTE_JAX_CUSTOM_CALLS_RE`: Deprecated regex pattern to match primitive names.
- Example: 'DBiasQuantizePrimitive' or '^(?!DBiasQuantizePrimitive$).+$' to enable/disable DBiasQuantizePrimitive.
- A deprecation warning is raised if used; it will be removed in future releases.
Behavior:
1. Checks if `NVTE_JAX_CUSTOM_CALLS` is set and parses key/value pairs or single true/false value.
2. If not set, checks `NVTE_JAX_CUSTOM_CALLS_RE` (with deprecation warning) for regex matching.
3. If neither is set, falls back to the internal state `_is_enabled`.
"""
pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
pattern = re.compile(pattern)
is_enabled = pattern.fullmatch(cls.__name__) is not None
return is_enabled
# Check new key/value environment variable first
custom_calls_str = os.getenv("NVTE_JAX_CUSTOM_CALLS")
if custom_calls_str is not None:
custom_calls_str = custom_calls_str.strip()
if custom_calls_str.lower() == "true":
return True
if custom_calls_str.lower() == "false":
return False
# Parse key=value pairs
settings = {}
for pair in custom_calls_str.split(","):
pair = pair.strip()
if "=" in pair:
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip().lower()
settings[key] = value == "true"
if cls.__name__ in settings:
return settings[cls.__name__]
# Check old regex environment variable (deprecated)
pattern_str = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE")
if pattern_str is not None:
warnings.warn(
"NVTE_JAX_CUSTOM_CALLS_RE is deprecated and will be removed in future releases. Use"
" NVTE_JAX_CUSTOM_CALLS with key=value format instead (e.g.,"
" 'DBiasQuantizePrimitive=false').",
DeprecationWarning,
)
pattern = re.compile(pattern_str)
env_enabled = pattern.fullmatch(cls.__name__) is not None
return env_enabled
# If no environment variable is set, fall back to the internal state
return cls._is_enabled
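# Precedence of the checks above, illustrated (hypothetical primitives and values):
#   NVTE_JAX_CUSTOM_CALLS="false"                        -> all primitives disabled
#   NVTE_JAX_CUSTOM_CALLS="DBiasQuantizePrimitive=false" -> that primitive disabled; names
#                                                           not listed fall through to the
#                                                           regex variable or _is_enabled
#   NVTE_JAX_CUSTOM_CALLS unset, NVTE_JAX_CUSTOM_CALLS_RE set -> regex match (deprecated)
#   neither variable set                                 -> class-level _is_enabled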
@classmethod
def set_enabled(cls, enabled: bool):
"""
Sets the enabled state for this primitive.
"""
cls._is_enabled = enabled
@staticmethod
@abstractmethod
......@@ -109,10 +168,19 @@ class BasePrimitive(metaclass=ABCMeta):
return "... -> ..."
# Registry to store all registered primitive classes
_primitive_registry = {}
def register_primitive(cls):
"""
register jax primitive
Register a JAX primitive and add it to the internal registry.
"""
_primitive_registry[cls.__name__] = cls
# Set default disabled state at class level based on _default_disable_names
if cls.__name__ in BasePrimitive._default_disable_names:
cls.set_enabled(False)
def name_of_wrapper_p():
return cls.name + "_wrapper"
......@@ -145,3 +213,48 @@ def register_primitive(cls):
for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="CUDA")
def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False):
"""
Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the QuantizeConfig.initialize() methods.
Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
disable_names: List of strings, each representing the name of a primitive class to disable. Defaults to None.
disable_all_first: Boolean, if True, disables all primitives before applying enable/disable lists. Defaults to False.
Note:
1. If `disable_all_first` is True, all primitives are disabled first, then `enable_names` is applied.
2. Conflicts (a primitive in both enable and disable lists) are resolved by applying disable last.
"""
enable_set = set(enable_names or [])
disable_set = set(disable_names or [])
if disable_all_first:
for name, cls in _primitive_registry.items():
if (
isinstance(cls, type)
and issubclass(cls, BasePrimitive)
and cls is not BasePrimitive
):
cls.set_enabled(False)
# Apply enables
for name in enable_set:
cls = _primitive_registry.get(name)
if cls and isinstance(cls, type) and issubclass(cls, BasePrimitive):
cls.set_enabled(True)
else:
raise ValueError(f"Primitive not found in registry: {name}")
# Apply disables (overrides enables if there's a conflict)
for name in disable_set:
cls = _primitive_registry.get(name)
if cls and isinstance(cls, type) and issubclass(cls, BasePrimitive):
cls.set_enabled(False)
else:
raise ValueError(f"Primitive not found in registry: {name}")
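# Usage sketch (the names must exist in the primitive registry):
#   manage_primitives(disable_all_first=True,
#                     enable_names=["GemmPrimitive", "GroupedGemmPrimitive"])
#   # Later, turn one primitive off without touching the rest:
#   manage_primitives(disable_names=["GroupedGemmPrimitive"])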
......@@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
impl_static_args = (6, 7, 8, 9, 10, 11, 12)
inner_primitive = None
outer_primitive = None
......@@ -169,16 +169,13 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator
del use_split_accumulator
def _dims_are_consecutive(dims):
if len(dims) <= 1:
......@@ -201,27 +198,6 @@ class GemmPrimitive(BasePrimitive):
f"{rhs_contracting_dims}."
)
(
lhs_batch_dims,
rhs_batch_dims,
) = map(sanitize_dims, operand_ndims, batched_dims)
assert _dims_are_consecutive(lhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got "
f"{lhs_batch_dims}."
)
assert _dims_are_consecutive(rhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got "
f"{rhs_batch_dims}."
)
if len(lhs_batch_dims) == 0:
assert (
len(rhs_batch_dims) == 0
), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched."
elif len(rhs_batch_dims) != 0:
assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all(
bdim in rhs_contracting_dims for bdim in rhs_batch_dims
), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched."
lhs_contracting_size, rhs_contracting_size = map(
lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
(lhs.shape, rhs.shape),
......@@ -335,16 +311,14 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype
del out_dtype
lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
......@@ -385,9 +359,6 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
......@@ -402,14 +373,14 @@ class GemmPrimitive(BasePrimitive):
lhs_scale_inv,
scaling_mode,
lhs.shape,
is_colwise=lhs_quantized_colwise,
is_colwise=lhs_transposed,
flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
)
rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv,
scaling_mode,
rhs.shape,
is_colwise=rhs_quantized_colwise,
is_colwise=not rhs_transposed,
flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
)
......@@ -422,9 +393,6 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
......@@ -436,12 +404,9 @@ class GemmPrimitive(BasePrimitive):
@staticmethod
def batcher(
batched_args,
jax_batch_dims,
batch_dims,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
......@@ -449,24 +414,13 @@ class GemmPrimitive(BasePrimitive):
use_split_accumulator,
):
assert GemmPrimitive.outer_primitive is not None
lhs, _, rhs, *_ = batched_args
lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims
arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
lhs_bdims, _, rhs_bdims, *_ = batch_dims
# Output is batched like the non-contracting batch dimensions of the LHS operand
lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims)
lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims)
out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims
# Batched GEMM is not supported
assert (
lhs_bdims is None and rhs_bdims is None
), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})"
out_bdims = (None,)
# Bias gradient is never batched
bias_bdims = (None,)
......@@ -481,9 +435,6 @@ class GemmPrimitive(BasePrimitive):
*batched_args,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
......@@ -494,168 +445,85 @@ class GemmPrimitive(BasePrimitive):
)
@staticmethod
def _decompose_operand_specs(specs, contracting_dims, batch_dims):
ndims = len(specs)
cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims))
# Batch specs
bspecs = tuple(specs[i] for i in bdims)
# Non-batch leading dimension specs
lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims)
# Non-batch contracting dimension specs
cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims)
return bspecs, lspecs, cspecs
@staticmethod
def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims):
def _parse_operand_output_specs(
arg_infos,
contracting_dims,
):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map(
sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims
)
(
(lhs_bspecs, lhs_lspecs, lhs_cspecs),
(rhs_bspecs, rhs_lspecs, rhs_cspecs),
) = map(
GemmPrimitive._decompose_operand_specs,
(lhs_specs, rhs_specs),
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims)
lhs_non_cdims, rhs_non_cdims = map(
lambda ndim, cdims: tuple(i for i in range(ndim) if i not in cdims),
(lhs_ndim, rhs_ndim),
(lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
)
lhs_non_cspecs, lhs_cspecs, rhs_non_cspecs, rhs_cspecs = map(
lambda specs, dims: tuple(specs[i] for i in dims),
(lhs_specs, lhs_specs, rhs_specs, rhs_specs),
(lhs_non_cdims, lhs_cdims, rhs_non_cdims, rhs_cdims),
)
# Batched dimensions must have the same sharding
if len(lhs_bdims) > 0 and len(rhs_bdims) > 0:
assert all(
lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs)
), (
"cuBLAS GEMM operand batch dimensions must have the same sharding: "
f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}."
reduce_spec = None
for l in lhs_cspecs:
for r in rhs_cspecs:
if l is not None and l == r:
assert reduce_spec is None, "Multiple reduce dimensions detected!"
reduce_spec = l
if reduce_spec is not None:
# Other non-reduce cdims (if exists) need to be unsharded
lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)
# Non-contracting dims of RHS always need to be gathered, e.g. for TP + activation_hidden.
# No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim.
# In `rhs_specs`, the batch dim appears only in Wgrad GEMM under `rhs_cspecs`.
rhs_non_cspecs = tuple(
None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs
)
# Only one each of the non-batched leading dimensions and non-batched contracting
# dimensions can be sharded
lhs_ldims, rhs_ldims = map(
lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude),
(lhs_ndim, rhs_ndim),
(lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims),
)
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map(
lambda specs: tuple(spec for spec in specs if spec is not None),
(lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs),
)
assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched leading dimension: "
f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}."
)
assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: "
f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}."
)
else:
# Otherwise, require contracting dims of both operands to be unsharded
lhs_cspecs = (None,) * len(lhs_cspecs)
rhs_cspecs = (None,) * len(rhs_cspecs)
# Extract single leading and contracting dimension specs
(lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map(
lambda specs: None if len(specs) == 0 else specs[0],
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none),
)
# Non-contracting dims of RHS always needs to be gathered along the FSDP axis
rhs_non_cspecs = tuple(
None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
)
# Non-contracting dims of LHS are gathered along the SP axis.
# Minor note: this causes MaxText TP (= Megatron TP + activation_hidden sharding) to gather x for
# dW1 = x^T * dY1, which is unexpected. This is a known issue and no solution has been found yet.
lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs)
# Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts
# with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands.
# 1. K1 == K2 != None and N == None
# LHS: (B, M, K)
# RHS: (B, None, K)
# OUT: (B, M, None) --(AR)-> (B, M, None)
# 2. K1 == K2 != None and M == N != None
# LHS: (B, M, K)
# RHS: (B, N, K)--(AG)->(B, None, K)
# OUT: (B, M, None) --(RS)--> (B, M, N)
# 3. M == N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, M, K)--(AG)->(B, None, None)
# OUT: (B, M, None)
# 4. M != N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, N, K)--(AG)->(B, N, None)
# OUT: (B, M, N)
reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec
all_reduce_output = reduce_flag and rhs_lspec is None
reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec
all_reduce_spec = reduce_scatter_spec = scatter_dim = None
lhs_non_contracting_specs, rhs_non_contracting_specs = map(
lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims),
(lhs_specs, rhs_specs),
out_specs = lhs_non_cspecs + rhs_non_cspecs
# specs = merge(cspecs, non_cspecs)
lhs_specs, rhs_specs = map(
lambda cdims, cspecs, non_cspecs: (
cspecs + non_cspecs if cdims[0] == 0 else non_cspecs + cspecs
),
(lhs_cdims, rhs_cdims),
(lhs_cspecs, rhs_cspecs),
(lhs_non_cspecs, rhs_non_cspecs),
)
out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs)
if reduce_scatter_output:
# All-gather (if necessary) the non-batch non-contracting dimension of RHS
# (B, N, K) --(AG)-> (B, None, K)
# (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N)
rhs_spec = tuple(
rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim)
)
reduce_scatter_spec = lhs_cspec
scatter_dim = out_specs.index(rhs_lspec)
elif all_reduce_output:
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
all_reduce_spec = lhs_cspec
else:
# All-gather (if necessary) the non-batch contracting dimensions
# (B, M, K) --(AG)-> (B, M, None)
# (B, N, K) --(AG)-> (B, N, None)
# (B, M, None) x (B, N, None)^T = (B, M, N)
lhs_specs = tuple(
None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i]
for i in range(lhs_ndim)
)
rhs_specs = tuple(
None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i]
for i in range(rhs_ndim)
)
# Check if RHS non-contracting spec also appears in the LHS non-contracting specs
if rhs_lspec is not None and rhs_lspec in tuple(
lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims
):
# All-gather (if necessary) the non-batch non-contracting dimensions of RHS
# (B, N, None) --(AG)-> (B, None, None)
# (B, M, None) x (B, None, None)^T = (B, M, None)
rhs_specs = tuple(
None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i]
for i in range(rhs_ndim)
)
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
# Bias and Pre-GeLU sharding is based on GEMM output
bias_specs = out_specs[len(lhs_non_contracting_specs) :]
gelu_specs = out_specs
# Bias and Pre-GeLU sharding is based on GEMM output before any scatter
bias_specs = tuple(list(rhs_non_cspecs).copy())
gelu_specs = tuple(list(out_specs).copy())
return (
(lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs),
all_reduce_spec,
reduce_scatter_spec,
scatter_dim,
reduce_spec,
)
@staticmethod
def infer_sharding_from_operands(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
......@@ -667,15 +535,13 @@ class GemmPrimitive(BasePrimitive):
):
del (
out_dtype,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
grad,
)
del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
(_, (out_specs, dbias_specs, pre_gelu_specs), _) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
......@@ -695,9 +561,6 @@ class GemmPrimitive(BasePrimitive):
def partition(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
......@@ -712,10 +575,8 @@ class GemmPrimitive(BasePrimitive):
(
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs),
all_reduce_spec,
reduce_scatter_spec,
scatter_dim,
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
reduce_spec,
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
# Assemble argument shardings
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
......@@ -762,9 +623,6 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
......@@ -772,19 +630,9 @@ class GemmPrimitive(BasePrimitive):
use_split_accumulator=use_split_accumulator,
)
# All-Reduce/Reduce-Scatter GEMM output
if all_reduce_spec is not None:
outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec)
elif reduce_scatter_spec is not None:
outputs[0] = jax.lax.psum_scatter(
outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum_scatter(
outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
# All-Reduce GEMM output
if reduce_spec is not None:
outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
return outputs
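The psum above is the standard contracting-dimension all-reduce: each device computes a partial GEMM over its local K shard and the partials are summed across the mesh axis. A self-contained sketch of the same pattern under shard_map (assumes at least two local devices; the "tp" axis name is illustrative):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:2], ("tp",))

def _partial_gemm(lhs, rhs):
    partial = jnp.einsum("mk,nk->mn", lhs, rhs)  # partial product over the local K shard
    return jax.lax.psum(partial, "tp")           # combine partials across the "tp" axis

gemm_fn = shard_map(
    _partial_gemm,
    mesh=mesh,
    in_specs=(P(None, "tp"), P(None, "tp")),  # both operands sharded along K
    out_specs=P(None, None),                  # fully replicated output
)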
......@@ -794,9 +642,6 @@ class GemmPrimitive(BasePrimitive):
def shardy_sharding_rule(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
......@@ -806,40 +651,33 @@ class GemmPrimitive(BasePrimitive):
operand_types,
result_types,
):
del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator
del out_dtype, grad, use_split_accumulator
del mesh, result_types
prefix = "GemmPrimitive_"
def _generate_operand_rules(name, ndim, cdims, bdims):
def _generate_operand_rules(name, ndim, cdims):
specs = []
ldims = tuple(i for i in range(ndim) if i not in bdims + cdims)
ldims = tuple(i for i in range(ndim) if i not in cdims)
for i in range(ndim):
dim_name = None
if i in bdims:
dim_idx = bdims.index(i) if len(bdims) > 1 else ""
dim_name = f"b{dim_idx}"
elif i in cdims:
dim_idx = cdims.index(i) if len(cdims) > 1 else ""
if i in cdims:
dim_idx = cdims.index(i)
dim_name = f"k{dim_idx}"
else:
dim_idx = ldims.index(i) if len(ldims) > 1 else ""
dim_idx = ldims.index(i)
dim_name = f"{name}_l{dim_idx}"
specs.append(prefix + dim_name)
return specs
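Illustrative factor names produced by the helper above for a 2D x 2D GEMM contracting on the last dimension of each operand (derived by tracing the code, not captured output):

# _generate_operand_rules("lhs", 2, cdims=(1,)) -> ["GemmPrimitive_lhs_l0", "GemmPrimitive_k0"]
# _generate_operand_rules("rhs", 2, cdims=(1,)) -> ["GemmPrimitive_rhs_l0", "GemmPrimitive_k0"]
# The shared "GemmPrimitive_k0" factor tells Shardy that both contracting
# dimensions must carry the same sharding.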
lhs, _, rhs, *_ = operand_types
operand_ndims = (len(lhs.shape), len(rhs.shape))
(lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map(
lambda dims: map(sanitize_dims, operand_ndims, dims),
(contracting_dims, batched_dims),
)
(lhs_cdims, rhs_cdims) = map(sanitize_dims, operand_ndims, contracting_dims)
lhs_specs, rhs_specs = map(
_generate_operand_rules,
("lhs", "rhs"),
operand_ndims,
(lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
)
lhs_scale_specs = ("…1",)
rhs_scale_specs = ("…2",)
......@@ -891,7 +729,6 @@ def _te_gemm(
lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
fuse_bias: bool = False,
fuse_gelu: bool = False,
grad: bool = False,
......@@ -906,7 +743,6 @@ def _te_gemm(
scaling_mode = ScalingMode.NO_SCALING
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
# Quantize operands (if necessary)
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
......@@ -925,7 +761,6 @@ def _te_gemm(
lhs_scale_inv = lhs_q.scale_inv
if lhs_q.data_layout == "T":
lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis)
if isinstance(rhs_q, ScaledTensor):
assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
......@@ -943,7 +778,6 @@ def _te_gemm(
rhs_scale_inv = rhs_q.scale_inv
if rhs_q.data_layout == "T":
rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis)
# Dummy empties for bias and gelu
out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
......@@ -961,9 +795,6 @@ def _te_gemm(
gelu_input,
out_dtype=out_dtype,
contracting_dims=(lhs_cdims, rhs_cdims),
batched_dims=(lhs_bdims, rhs_bdims),
lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False,
rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
......@@ -1171,10 +1002,8 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T":
lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis)
lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis)
if rhs.data_layout == "T":
rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis)
rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis)
dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
......@@ -1264,7 +1093,7 @@ def _jax_gemm(
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}")
raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
......@@ -1286,7 +1115,6 @@ def gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
**kwargs,
......@@ -1305,11 +1133,6 @@ def gemm(
Object for down-casting the RHS operand for quantized GEMM.
contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, ))
Tuple of sequences representing the contracting dimensions of the operands.
batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()),
Tuple of sequences representing the batched dimensions of the operands. This is *not* used
to perform a batched matrix multiplication, but it is required to avoid a potentially
undesirable reduction in any batched contracting dimensions when invoked with sharded
operands (e.g. when computing weight gradients in a Flax module).
bias: jax.Array, default = None
Optional additive bias term, required for forward GEMM with bias fusion. Only supported
with TE's custom call to cuBLAS GEMM.
......@@ -1327,7 +1150,8 @@ def gemm(
TE's custom call to cuBLAS GEMM.
use_split_accumulator: bool, default = True
Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed.
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only
supported with TE's custom call to cuBLAS GEMM.
Returns
-------
......@@ -1358,12 +1182,12 @@ def gemm(
if not GemmPrimitive.enabled():
assert kwargs.get("bias", None) is None and not fuse_gelu, (
"TE GEMM was invoked with bias fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
assert kwargs.get("gelu_input", None) is None and not fuse_bias, (
"TE GEMM was invoked with GeLU fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
......@@ -1374,7 +1198,6 @@ def gemm(
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
**kwargs,
)
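A hedged usage sketch of the public entry point (import path assumed; quantizers omitted, so this exercises the plain higher-precision path):

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import gemm  # import path assumed

x = jnp.ones((128, 256), jnp.bfloat16)  # (M, K)
w = jnp.ones((256, 512), jnp.bfloat16)  # (K, N)
y = gemm(x, w, contracting_dims=((-1,), (0,)))  # default NN layout -> (128, 512)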
......
......@@ -519,11 +519,11 @@ register_primitive(BaseDBiasQuantizePrimitive)
class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
"""Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
class QuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
def _jax_quantize(
......
......@@ -592,8 +592,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
.Attr<bool>("rhs_is_trans")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("has_bias")
.Attr<bool>("is_grouped_dense_wgrad"),
FFI_CudaGraph_Traits);
.Attr<bool>("is_grouped_dense_wgrad"));
} // namespace jax
} // namespace transformer_engine
......@@ -410,8 +410,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
.Ret<Buffer_Type>() // amax
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout")
.Attr<int64_t>("flatten_axis"),
FFI_CudaGraph_Traits);
.Attr<int64_t>("flatten_axis"));
} // namespace jax
} // namespace transformer_engine
......@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""
import warnings
from typing import Tuple, Sequence
from functools import partial
import jax
......@@ -23,16 +23,6 @@ from .quantize import (
)
DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global DENSE_BATCH_FIRST_WARNING_ISSUED
if not DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
DENSE_BATCH_FIRST_WARNING_ISSUED = True
def dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
......@@ -40,7 +30,6 @@ def dense(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
......@@ -54,7 +43,6 @@ def dense(
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
......@@ -69,13 +57,34 @@ def dense(
output += jnp.reshape(bias, bias_new_shape)
else:
output = _dense(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
quantizer_set,
)
return output
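A minimal sketch of calling dense after this change (no batch_first argument anymore; the default noop quantizer set skips quantization). The import path is assumed:

import jax.numpy as jnp
from transformer_engine.jax.dense import dense  # import path assumed

x = jnp.ones((32, 128, 512))    # (batch, seqlen, hidden)
kernel = jnp.ones((512, 1024))  # (hidden, features)
out = dense(x, kernel, contracting_dims=((2,), (0,)))  # -> (32, 128, 1024)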
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set):
@partial(
jax.custom_vjp,
nondiff_argnums=(
3,
4,
5,
),
)
def _dense(
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
quantizer_set,
):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
......@@ -89,19 +98,30 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
Returns:
Transformed output tensor
"""
output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
quantizer_set,
)
return output
def _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
quantizer_set,
):
"""Forward pass rule for dense layer transformation.
......@@ -119,23 +139,6 @@ def _dense_fwd_rule(
not x_is_transposed and not k_is_transposed
), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."
# Determine X batch dimension
# - If `batch_first=True` -> (batch, leading..., contracting...)
# - Otherwise -> (leading..., batch, contracting...)
# NOTE: Always assume a single batch dimension
x_bdim = None
num_cdims = len(x_contracting_dims)
if x.ndim >= num_cdims + 2:
# Assume X is batched if it has at least +2 dimensions more than the number of contracting
# dimensions.
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `dense()` layer implementation does not officially support sequence-first "
"inputs and may produce incorrect results when `batch_first=False`. Use "
"sequence-first inputs at your own discretion.",
)
x_bdim = 0 if batch_first else x.ndim - num_cdims - 1
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
......@@ -158,7 +161,6 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
......@@ -175,13 +177,12 @@ def _dense_fwd_rule(
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
)
return output, ctx
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad
contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
......@@ -196,7 +197,6 @@ def _dense_bwd_rule(
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
) = ctx
fwd_x_contracting_dims, fwd_k_contracting_dims = map(
......@@ -220,11 +220,11 @@ def _dense_bwd_rule(
k_contracting_dim = tuple(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
)
dgrad = tex.gemm(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......@@ -238,7 +238,6 @@ def _dense_bwd_rule(
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
......@@ -15,12 +15,14 @@ from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
from ..dense import dense, _issue_batch_first_warning as _dense_warning
from transformer_engine.common import recipe
from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
......@@ -31,7 +33,6 @@ from ..cpp_extensions import (
jax_scaled_upper_triang_masked_softmax,
)
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -274,10 +275,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
"""
epsilon: float = 1e-6
......@@ -288,7 +285,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ("embed",)
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self):
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
......@@ -343,21 +339,28 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Base class of transformer engine
"""
def generate_quantizer_set(self, postfix: str = ""):
def generate_quantizer_set(
self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
):
"""
Generate a set of FP8 meta for a GEMM.
"""
def generate_quantize_meta(quantizer_name: str):
collection_name = (
variable_collection
if variable_collection is not None
else QuantizeConfig.COLLECTION_NAME
)
scale = self.variable(
QuantizeConfig.COLLECTION_NAME,
collection_name,
f"{quantizer_name}{postfix}_scale",
jnp.ones,
(1,),
jnp.float32,
).value
amax_history = self.variable(
QuantizeConfig.COLLECTION_NAME,
collection_name,
f"{quantizer_name}{postfix}_amax_history",
jnp.zeros,
(QuantizeConfig.AMAX_HISTORY_LEN,),
......@@ -365,7 +368,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
).value
return QuantizeMeta(scale=scale, amax_history=amax_history)
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(
fp8_recipe, recipe.DelayedScaling
):
x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad")
......@@ -374,7 +379,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
else:
kwargs = {}
quantizer_set = QuantizerFactory.create_set(**kwargs)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
return quantizer_set
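A hedged sketch of a subclass using the new arguments; the module and collection name below are illustrative, not part of the library:

from flax import linen as nn
from transformer_engine.common import recipe

class MyProjection(TransformerEngineBase):
    @nn.compact
    def __call__(self, x):
        quantizer_set = self.generate_quantizer_set(
            postfix="_0",
            variable_collection="fp8_meta_custom",  # assumed collection name
            fp8_recipe=recipe.DelayedScaling(),     # overrides the global scaling mode
        )
        ...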
......@@ -420,10 +425,6 @@ class DenseGeneral(TransformerEngineBase):
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
"""
features: Union[Iterable[int], int]
......@@ -437,16 +438,9 @@ class DenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = ()
def __post_init__(self):
if self.transpose_batch_sequence:
_dense_warning(
"TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
......@@ -628,10 +622,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied.
......@@ -657,18 +647,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_dense_warning(
"TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
"inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
"Use sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0,
......@@ -936,15 +919,16 @@ class LayerNormMLP(TransformerEngineBase):
Indicate the logical axes of sharding constraint to the input of 2nd dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
ffn1_ckpt_name: str = "ffn1"
Checkpoint name for the output of the first fully-connected layer in the MLP block.
ffn2_ckpt_name: str = "ffn2"
Checkpoint name for the output of the second fully-connected layer in the MLP block.
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
"""
intermediate_dim: int = 2048
......@@ -973,18 +957,13 @@ class LayerNormMLP(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None
ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2"
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_mlp_warning(
"TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
......@@ -1146,9 +1125,6 @@ class LayerNormMLP(TransformerEngineBase):
bias_1 = None
bias_2 = None
ffn1_ckpt_name = "ffn1"
ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp:
out = layernorm_mlp(
y,
......@@ -1164,8 +1140,8 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes=self.dot_2_input_axes,
kernel_1_axes=self.kernel_axes_1,
kernel_2_axes=self.kernel_axes_2,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name,
ffn1_ckpt_name=self.ffn1_ckpt_name,
ffn2_ckpt_name=self.ffn2_ckpt_name,
activation_type=normalized_acts,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
)
......@@ -1198,15 +1174,6 @@ class LayerNormMLP(TransformerEngineBase):
quantizer_set=ffn1_quantizer_set,
)
if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(
kernel_1.ndim, self.kernel_axes_1, contract_ind
),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
if self.enable_low_rank_adaptation:
wi_lora_a_kernel_each_shape = (
kernel_1_each_shape[: len(axis)],
......@@ -1247,7 +1214,7 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias:
x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name)
x = checkpoint_name(x, self.ffn1_ckpt_name)
if is_act_implemented:
z = activation(x, normalized_acts)
else:
......@@ -1310,7 +1277,7 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias:
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name)
out = checkpoint_name(out, self.ffn2_ckpt_name)
assert out.dtype == input_dtype
return out, ln_output  # Output, layer_norm_output
......@@ -274,6 +274,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
context_checkpoint_name: str = "context"
@nn.compact
def __call__(
......@@ -322,6 +323,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_kvpacked():
"""kvpacked format, treat
......@@ -348,6 +350,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_separate():
if self.transpose_batch_sequence:
......@@ -369,6 +372,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
)
else:
raise ValueError(f"Unsupported {self.qkv_layout=}.")
......@@ -501,6 +505,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
Optimization parameters
-----------------------
......@@ -524,6 +529,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
context_checkpoint_name: str = "context"
@nn.compact
def __call__(
......@@ -690,6 +696,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
)(
query,
key,
......@@ -1160,7 +1167,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=(3, self.num_attention_heads * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.return_layernorm_output,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
......@@ -1187,7 +1193,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=(self.return_layernorm_output or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
......@@ -1212,7 +1217,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kv_proj = DenseGeneral(
axis=-1,
features=(2, self.num_gqa_groups * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init,
use_bias=self.use_bias,
......@@ -1231,7 +1235,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
DenseGeneral,
axis=-1,
features=self.num_gqa_groups * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias,
bias_init=self.bias_init,
......@@ -1248,7 +1251,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
......@@ -1413,7 +1415,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
out = DenseGeneral(
features=inputs_q.shape[-1],
transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=(W_TP_AXES, W_FSDP_AXES),
......@@ -2015,7 +2016,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations,
......@@ -2070,7 +2070,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
scale_axes=(W_NO_SHARD_AXES,),
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
name="output_layernorm",
)(z)
......
......@@ -9,7 +9,6 @@ architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""
import warnings
from functools import partial
from typing import Tuple
......@@ -26,16 +25,6 @@ from .quantize import (
)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
......@@ -48,7 +37,6 @@ def layernorm_dense(
layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
......@@ -69,7 +57,6 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: Set of quantizers for different tensor types
Returns:
......@@ -93,7 +80,6 @@ def layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
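A minimal usage sketch after this change (no batch_first argument; gamma/beta argument order assumed from the upstream signature, import path assumed):

import jax.numpy as jnp
from transformer_engine.jax.layernorm_dense import layernorm_dense  # path assumed

x = jnp.ones((32, 128, 512))
kernel = jnp.ones((512, 1024))
gamma, beta = jnp.ones((512,)), jnp.zeros((512,))
out = layernorm_dense(x, kernel, gamma, beta)  # -> (32, 128, 1024)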
......@@ -108,7 +94,6 @@ def layernorm_dense(
8,
9,
10,
11,
),
)
def _layernorm_dense(
......@@ -123,7 +108,6 @@ def _layernorm_dense(
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
batch_first: bool,
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
......@@ -143,7 +127,6 @@ def _layernorm_dense(
epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding
batch_first: Assume that X is batched in the first dimension.
quantizer_set: Set of quantizers
Returns:
......@@ -161,7 +144,6 @@ def _layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
......@@ -179,7 +161,6 @@ def _layernorm_dense_fwd_rule(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
......@@ -197,17 +178,6 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_dense()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd(
......@@ -236,7 +206,6 @@ def _layernorm_dense_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
......@@ -260,7 +229,6 @@ def _layernorm_dense_fwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
)
return output, ctx
......@@ -271,9 +239,8 @@ def _layernorm_dense_bwd_rule(
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
dot_input_axes,
kernel_axes,
batch_first, # pylint: disable=unused-argument
ctx,
grad,
):
......@@ -288,6 +255,7 @@ def _layernorm_dense_bwd_rule(
Returns:
Tuple of gradients for all input parameters
"""
del dot_input_axes
(
casted_ln_out,
casted_kernel,
......@@ -303,7 +271,6 @@ def _layernorm_dense_bwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
) = ctx
casted_grad, dbias = tex.quantize_dbias(
......@@ -328,7 +295,6 @@ def _layernorm_dense_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
......@@ -342,7 +308,6 @@ def _layernorm_dense_bwd_rule(
casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
......@@ -13,7 +13,6 @@ The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""
import warnings
from typing import List, Tuple, Sequence, Union, Callable
from functools import partial
......@@ -29,17 +28,6 @@ from .quantize import (
noop_quantizer_set,
TensorUsage,
)
from .sharding import get_non_contracting_logical_axes
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_mlp(
......@@ -59,7 +47,6 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
batch_first: bool = True,
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block.
......@@ -91,7 +78,6 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns:
......@@ -137,13 +123,12 @@ def layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
quantizer_sets,
)
return output
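A hedged usage sketch; the kernels/biases parameter names and the (hidden, n_activations, intermediate) kernel_1 shape are inferred from the diffed internals and may differ in the public API:

import jax.numpy as jnp
from transformer_engine.jax.layernorm_mlp import layernorm_mlp  # path assumed

x = jnp.ones((32, 128, 512))
gamma, beta = jnp.ones((512,)), jnp.zeros((512,))
kernel_1 = jnp.ones((512, 1, 2048))  # one slice per activation in activation_type
kernel_2 = jnp.ones((2048, 512))
out = layernorm_mlp(
    x, gamma, beta,
    kernels=[kernel_1, kernel_2],
    biases=[None, None],
    norm_type="layernorm",
    activation_type=("gelu",),
)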
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -163,7 +148,6 @@ def _layernorm_mlp(
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
batch_first: bool,
quantizer_sets,
):
"""Internal implementation of layernorm_mlp with custom VJP.
......@@ -189,7 +173,6 @@ def _layernorm_mlp(
ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s)
batch_first: Assume that X is batched in the first dimension.
quantizer_sets: Tuple of quantizer sets
Returns:
......@@ -214,7 +197,6 @@ def _layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
quantizer_sets,
)
return output
......@@ -239,7 +221,6 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
quantizer_sets,
):
"""Forward pass rule for layernorm_mlp.
......@@ -256,7 +237,7 @@ def _layernorm_mlp_fwd_rule(
Returns:
Tuple of (output, context) for automatic differentiation
"""
del kernel_2_axes
del kernel_1_axes, kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
......@@ -272,17 +253,6 @@ def _layernorm_mlp_fwd_rule(
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
use_bias_1 = bias_1 is not None
use_bias_2 = bias_2 is not None
......@@ -310,18 +280,10 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
)
if dot_1_input_axes is not None and kernel_1_axes is not None:
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1 and tex.gemm_uses_jax_dot():
bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
......@@ -346,7 +308,6 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
)
......@@ -376,7 +337,6 @@ def _layernorm_mlp_fwd_rule(
use_bias_1,
use_bias_2,
quantizer_sets,
x_bdim,
)
return dot_2_output, ctx
......@@ -394,7 +354,6 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
ctx,
grad,
):
......@@ -411,7 +370,7 @@ def _layernorm_mlp_bwd_rule(
Returns:
Tuple of gradients for all input parameters
"""
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
(
x,
mu,
......@@ -430,7 +389,6 @@ def _layernorm_mlp_bwd_rule(
use_bias_1,
use_bias_2,
quantizer_sets,
x_bdim,
) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
......@@ -457,7 +415,6 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
batched_dims=((x_bdim,), ()),
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
......@@ -472,7 +429,6 @@ def _layernorm_mlp_bwd_rule(
casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
......@@ -500,7 +456,6 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
batched_dims=((x_bdim,), ()),
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
......@@ -511,7 +466,6 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
......@@ -16,12 +16,14 @@ import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import (
QuantizeConfig,
AmaxComputeAlgo,
_get_scaling_mode,
)
from .device_utils import is_fp8_gemm_with_all_layouts_supported
......@@ -878,11 +880,12 @@ class QuantizerFactory:
@staticmethod
def create_set(
n_quantizer_sets: int = 1,
scaling_mode: ScalingMode = None,
scaling_mode: Optional[ScalingMode] = None,
fwd_dtype: jnp.dtype = None,
bwd_dtype: jnp.dtype = None,
is_2x2x: bool = None,
n_groups: int = None,
fp8_recipe: Optional[recipe.Recipe] = None,
**kwargs,
) -> tuple[Union[tuple[Quantizer], None]]:
"""Create one or more sets of quantizers.
......@@ -894,12 +897,25 @@ class QuantizerFactory:
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
n_groups:
fp8_recipe: Recipe to use for quantization. The scaling mode can be specified directly via the scaling_mode parameter or indirectly via the recipe. The recipe is preferred, as it will support additional recipes in the future where the scaling mode differs between x, kernel, and grad in the quantizer set.
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer set or tuple of quantizer sets
"""
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
assert scaling_mode is None or fp8_recipe is None, (
"Cannot specify both scaling_mode and fp8_recipe when creating a quantizer set. Scaling"
" mode can be specified directly via the scaling_mode parameter or indirectly via"
" recipe. Recipe is preferred as it will support additional recipes in future where"
" scaling mode differs between x, kernel, and grad in the quantizer set."
)
if fp8_recipe is not None:
# TODO(jberchtold): once recipe and scaling mode are decoupled update this logic
scaling_mode = _get_scaling_mode(fp8_recipe)
else:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
if is_2x2x is None:
......
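A hedged usage sketch of the new fp8_recipe path (exactly one of scaling_mode / fp8_recipe may be supplied, per the assertion above; import path assumed):

from transformer_engine.common import recipe
from transformer_engine.jax.quantize import QuantizerFactory  # path assumed

q_set = QuantizerFactory.create_set(
    n_quantizer_sets=1,
    fp8_recipe=recipe.DelayedScaling(),  # scaling mode is inferred from the recipe
)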