Commit 87e3e56e authored by yuguo's avatar yuguo
Browse files

Merge commit '734bcedd' of...

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
...@@ -5,3 +5,58 @@ ...@@ -5,3 +5,58 @@
""" """
Utils for the debug features. Utils for the debug features.
""" """
import torch
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup):
    """
    Return statistics-reduction parameters for ``tensor_name``.

    Returns a ``(skip_reduction, reduction_group, reduce_within_microbatch)``
    tuple. Non-weight tensors are reduced within the microbatch using the
    globally configured reduction group. Weight tensors are instead reduced
    over the tensor-parallel group, or not reduced at all when TP-group
    reduction of weights is disabled in the debug state.
    """
    is_weight = tensor_name == "weight"
    reduction_group = debug_api.get_tensor_reduction_group()
    skip_reduction = False
    if is_weight:
        if TEDebugState.weight_tensor_tp_group_reduce:
            reduction_group = tp_group
        else:
            skip_reduction = True
    return skip_reduction, reduction_group, not is_weight
def next_enabled_iter(start_step, end_step, start_end_list, freq, iteration):
    """
    Determine whether the feature is enabled at the current iteration and
    compute the next iteration at which it will be enabled.

    Parameters
    ----------
    start_step : int or None
        First iteration of the enabled range; ``None`` means 0.
        Ignored when ``start_end_list`` is provided.
    end_step : int or None
        Last enabled iteration (inclusive); ``None`` means no upper bound.
        Ignored when ``start_end_list`` is provided.
    start_end_list : list of (int, int) or None
        Explicit ``(start, end)`` intervals; takes precedence over
        ``start_step``/``end_step``.
    freq : int
        The feature is enabled only at iterations divisible by ``freq``.
    iteration : int
        Current iteration index.

    Returns
    -------
    run_current : bool
        True if the feature should be enabled at the current iteration.
    next_iter : int or None
        The next iteration index (strictly greater than ``iteration``) at
        which the feature will be enabled, or ``None`` if no such iteration
        exists.
    """
    run_current = False
    if start_end_list:
        intervals = sorted(start_end_list)
    else:
        start = 0 if start_step is None else start_step
        end = float("inf") if end_step is None else end_step
        intervals = [(start, end)]

    for s, e in intervals:
        if iteration % freq == 0 and s <= iteration <= e:
            run_current = True
        # Smallest multiple of `freq` that is > iteration and >= s.
        first = max(iteration + 1, s)
        offset = first % freq
        candidate = first if offset == 0 else first + (freq - offset)
        if candidate <= e:
            return run_current, candidate
    return run_current, None  # No next enabled iteration found.
...@@ -10,6 +10,7 @@ When log() is called, they gather stats from all nodes, compute combined final s ...@@ -10,6 +10,7 @@ When log() is called, they gather stats from all nodes, compute combined final s
from collections import defaultdict from collections import defaultdict
from typing import Dict
import torch import torch
from nvdlfw_inspect.utils import gather_along_first_dim from nvdlfw_inspect.utils import gather_along_first_dim
...@@ -20,6 +21,7 @@ from transformer_engine.debug.features.utils.stats_computation import ( ...@@ -20,6 +21,7 @@ from transformer_engine.debug.features.utils.stats_computation import (
DEPENDENCIES, DEPENDENCIES,
stats_to_num, stats_to_num,
) )
from transformer_engine.debug.pytorch.debug_state import TEDebugState
class _Buffer: class _Buffer:
...@@ -65,14 +67,17 @@ class _Buffer: ...@@ -65,14 +67,17 @@ class _Buffer:
gathered_buffer, _ = gather_along_first_dim( gathered_buffer, _ = gather_along_first_dim(
self._buffer.unsqueeze(0), process_group=self.reduction_group self._buffer.unsqueeze(0), process_group=self.reduction_group
) )
return gathered_buffer[mask.to(bool)] return gathered_buffer[mask.to(torch.bool)]
def feed(self, tensor, iteration): def feed(self, tensor, iteration, aux_dict=None):
""" """
feed() is used to add tensor for computing the statistics. feed() is used to add tensor for computing the statistics.
Because of the microbatching, feed() can be used multiple Because of the microbatching, feed() can be used multiple
times for one log(). times for one log().
The aux_dict is used to share common computation between different stats.
For example for LogFp8TensorStats in can contain quantized tensors in different precisions.
The main reason of this design: need to combine results for already processed The main reason of this design: need to combine results for already processed
tensors with the result of the new tensor. tensors with the result of the new tensor.
""" """
...@@ -95,7 +100,7 @@ class _Buffer: ...@@ -95,7 +100,7 @@ class _Buffer:
# save stats for tensor to tmp buffer # save stats for tensor to tmp buffer
for stat_name in self.stats_to_compute: for stat_name in self.stats_to_compute:
fn, _ = STATS[stat_name] fn, _ = STATS[stat_name]
self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor) self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict)
# [num_buffers, num_stats] # [num_buffers, num_stats]
buffers = torch.cat((self._buffer.unsqueeze(0), self._tmp_buffer.unsqueeze(0)), dim=0) buffers = torch.cat((self._buffer.unsqueeze(0), self._tmp_buffer.unsqueeze(0)), dim=0)
...@@ -106,7 +111,7 @@ class _Buffer: ...@@ -106,7 +111,7 @@ class _Buffer:
self._new_buffer[stats_to_num[stat_name]] = combinator(buffers) self._new_buffer[stats_to_num[stat_name]] = combinator(buffers)
else: else:
fn = STATS[stat_name][0] fn = STATS[stat_name][0]
self._new_buffer[stats_to_num[stat_name]] = fn(tensor) self._new_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict)
self._buffer.copy_(self._new_buffer) self._buffer.copy_(self._new_buffer)
...@@ -125,7 +130,6 @@ class _Buffer: ...@@ -125,7 +130,6 @@ class _Buffer:
for stat_name in self.stats_to_log: for stat_name in self.stats_to_log:
combiner = STATS[stat_name][1] combiner = STATS[stat_name][1]
stat_value = combiner(gathered_helper_stats) stat_value = combiner(gathered_helper_stats)
MetricLogger.log_scalar( MetricLogger.log_scalar(
f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration
) )
...@@ -146,10 +150,41 @@ class StatsBuffers: ...@@ -146,10 +150,41 @@ class StatsBuffers:
self.buffers = {} # (layer_name, tensor_name) -> buffer self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list) self.reduction_group_to_buffer = defaultdict(list)
# Logging stats involves synchronization between nodes
# and non-trivial cpu overhead.
# It should be only done if absolutely necessary.
# This variables helps to determine if we can reduce.
self.at_least_one_layer_fed = False
self.layers_to_next_iter: Dict[str, int] = {}
def _if_run_reduction(self) -> bool:
    """
    Decide whether the cross-node stats reduction has to be executed.

    Reduction is needed when at least one local layer has fed stats, or
    when any known layer (possibly not executed on this node) could be
    logging at the current iteration according to its recorded next-log
    iteration. If we know that no such layer on any node logs this time,
    the reduction can be skipped. To ensure correctness we assume every
    layer is invoked at the first forward pass; if that assumption is
    violated, a hang might happen.
    """
    if self.at_least_one_layer_fed:
        return True
    cur_iter = TEDebugState.get_iteration()
    # A layer may be idle for many iterations; keep synchronizing every
    # step until we get fresh information from it.
    return any(cur_iter >= next_iter for next_iter in self.layers_to_next_iter.values())
def reset(self): def reset(self):
"""Resets all buffers.""" """Resets all buffers."""
self.buffers = {} # (layer_name, tensor_name) -> buffer self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list) self.reduction_group_to_buffer = defaultdict(list)
self.at_least_one_layer_fed = False
self.layers_to_next_iter: Dict[str, int] = {}
def try_add_buffer( def try_add_buffer(
self, layer_name, tensor_name, stats, options, reduction_group, reduce_within_microbatch self, layer_name, tensor_name, stats, options, reduction_group, reduce_within_microbatch
...@@ -161,14 +196,25 @@ class StatsBuffers: ...@@ -161,14 +196,25 @@ class StatsBuffers:
self.buffers[(layer_name, tensor_name, options)] = buffer self.buffers[(layer_name, tensor_name, options)] = buffer
self.reduction_group_to_buffer[reduction_group].append(buffer) self.reduction_group_to_buffer[reduction_group].append(buffer)
def feed(self, layer_name, tensor_name, options, tensor, iteration, skip_reduction): def feed(
"""Feeds the tensor into the respective buffer.""" self, layer_name, tensor_name, options, tensor, iteration, skip_reduction, aux_dict=None
):
"""
Feeds the tensor into the respective buffer.
The aux_dict is used to share common computation between different stats.
For example for LogFp8TensorStats in can contain quantized tensors in different precisions.
"""
self.at_least_one_layer_fed = True
buffer = self.buffers[(layer_name, tensor_name, options)] buffer = self.buffers[(layer_name, tensor_name, options)]
buffer.feed(tensor, iteration) buffer.feed(tensor, iteration, aux_dict)
buffer.skip_reduction = skip_reduction buffer.skip_reduction = skip_reduction
def log_stats(self): def log_stats(self):
"""Logs the stats from all the buffers.""" """Logs the stats from all the buffers."""
if not self._if_run_reduction():
return {}
output = {} output = {}
for reduction_group, buffers in self.reduction_group_to_buffer.items(): for reduction_group, buffers in self.reduction_group_to_buffer.items():
changed_buffers = [ changed_buffers = [
...@@ -181,7 +227,7 @@ class StatsBuffers: ...@@ -181,7 +227,7 @@ class StatsBuffers:
for _, buffer in changed_buffers: for _, buffer in changed_buffers:
stats = buffer.log() stats = buffer.log()
output.update(stats) output.update(stats)
self.at_least_one_layer_fed = False
return output return output
......
...@@ -8,8 +8,9 @@ Mathematical functions used to tensor statistics computation. ...@@ -8,8 +8,9 @@ Mathematical functions used to tensor statistics computation.
import math import math
import torch import torch
import torch.nn.functional as F
MAX_FP8_VALUE_INT8 = 126 import transformer_engine_torch as tex
from transformer_engine.common.recipe import Format
@torch.compile @torch.compile
...@@ -49,6 +50,29 @@ def compute_std(variances, numels, sums): ...@@ -49,6 +50,29 @@ def compute_std(variances, numels, sums):
return torch.sqrt(compute_variance(variances, numels, sums)) return torch.sqrt(compute_variance(variances, numels, sums))
def compute_fp8_delayed_scaling_overflows_num(tensor, quantized_tensor):
    """
    Count elements of ``tensor`` that fall outside the representable range
    of the delayed-scaling FP8 quantized tensor.

    An element overflows when its value lies outside
    ``[-max_fwd * scale_inv, max_fwd * scale_inv]``, where ``max_fwd`` is
    the largest representable magnitude of the FP8 format in use.
    """
    scale_inv = quantized_tensor._scale_inv
    fp8_dtype = quantized_tensor._fp8_dtype
    # Resolve the max forward value for the supported FP8 formats.
    if fp8_dtype == tex.DType.kFloat8E4M3:
        fp8_max = Format.E4M3.value.max_fwd
    elif fp8_dtype == tex.DType.kFloat8E5M2:
        fp8_max = Format.E5M2.value.max_fwd
    else:
        raise ValueError(
            f"Unsupported FP8 dtype {fp8_dtype} passed to"
            " compute_fp8_delayed_scaling_overflows_num()."
        )
    out_of_range = (tensor > fp8_max * scale_inv) | (tensor < -fp8_max * scale_inv)
    return out_of_range.sum()
# buffers is tensor of shape [nr_buffers, nr_stats] # buffers is tensor of shape [nr_buffers, nr_stats]
def _get(buffers, stat_name): def _get(buffers, stat_name):
stat_nr = stats_to_num[stat_name] stat_nr = stats_to_num[stat_name]
...@@ -68,10 +92,12 @@ stats_to_num = { ...@@ -68,10 +92,12 @@ stats_to_num = {
"cur_amax": 9, "cur_amax": 9,
"dynamic_range_top": 10, "dynamic_range_top": 10,
"dynamic_range_bottom": 11, "dynamic_range_bottom": 11,
"underflows_num": 12, "std": 12,
"std": 13, "dynamic_range": 13,
"dynamic_range": 14, "fp8_delayed_scaling_overflows_num": 14,
"underflows%": 15, "fp8_delayed_scaling_overflows%": 15,
"overflows_num": 16,
"overflows%": 17,
} }
DEPENDENCIES = { DEPENDENCIES = {
...@@ -87,62 +113,207 @@ DEPENDENCIES = { ...@@ -87,62 +113,207 @@ DEPENDENCIES = {
"cur_amax": {"cur_amax"}, "cur_amax": {"cur_amax"},
"dynamic_range_top": {"dynamic_range_top"}, "dynamic_range_top": {"dynamic_range_top"},
"dynamic_range_bottom": {"dynamic_range_bottom"}, "dynamic_range_bottom": {"dynamic_range_bottom"},
"underflows_num": {"underflows_num"},
"std": {"variance", "numel", "sum"}, "std": {"variance", "numel", "sum"},
"dynamic_range": {"dynamic_range_top", "dynamic_range_bottom"}, "dynamic_range": {"dynamic_range_top", "dynamic_range_bottom"},
"underflows%": {"underflows_num", "numel"}, "fp8_delayed_scaling_overflows_num": {"fp8_delayed_scaling_overflows_num"},
"fp8_delayed_scaling_overflows%": {"fp8_delayed_scaling_overflows_num", "numel"},
"overflows_num": {"overflows_num"},
"overflows%": {"overflows_num", "numel"},
} }
STATS = { STATS = {
"min": (torch.min, lambda buffers: min(_get(buffers, "min"))), "min": (lambda x, aux_dict: torch.min(x), lambda buffers: min(_get(buffers, "min"))),
"max": (torch.max, lambda buffers: max(_get(buffers, "max"))), "max": (lambda x, aux_dict: torch.max(x), lambda buffers: max(_get(buffers, "max"))),
"sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))), "sum": (lambda x, aux_dict: torch.sum(x), lambda buffers: sum(_get(buffers, "sum"))),
"mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))), "mean": (
lambda x, aux_dict: torch.mean(x),
lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel")),
),
"numel": ( "numel": (
lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(), lambda x, aux_dict: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(),
lambda buffers: sum(_get(buffers, "numel")), lambda buffers: sum(_get(buffers, "numel")),
), ),
"l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))), "l1_norm": (
lambda x, aux_dict: torch.norm(x, p=1),
lambda buffers: sum(_get(buffers, "l1_norm")),
),
"l2_norm_square": ( "l2_norm_square": (
lambda x: torch.sum(x**2), lambda x, aux_dict: torch.sum(x**2),
lambda buffers: sum(_get(buffers, "l2_norm_square")), lambda buffers: sum(_get(buffers, "l2_norm_square")),
), ),
"l2_norm": ( "l2_norm": (
lambda x: torch.norm(x, p=2), lambda x, aux_dict: torch.norm(x, p=2),
lambda buffers: math.sqrt(sum(_get(buffers, "l2_norm_square"))), lambda buffers: math.sqrt(sum(_get(buffers, "l2_norm_square"))),
), ),
"variance": ( "variance": (
torch.var, lambda x, aux_dict: torch.var(x),
lambda buffers: compute_variance( lambda buffers: compute_variance(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum") _get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
), ),
), ),
"cur_amax": (lambda x: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))), "cur_amax": (lambda x, aux_dict: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))),
"dynamic_range_top": ( "dynamic_range_top": (
_compute_dynamic_range_top, lambda x, aux_dict: _compute_dynamic_range_top(x),
lambda buffers: max(_get(buffers, "dynamic_range_top")), lambda buffers: max(_get(buffers, "dynamic_range_top")),
), ),
"dynamic_range_bottom": ( "dynamic_range_bottom": (
_compute_dynamic_range_bottom, lambda x, aux_dict: _compute_dynamic_range_bottom(x),
lambda buffers: min(_get(buffers, "dynamic_range_bottom")), lambda buffers: min(_get(buffers, "dynamic_range_bottom")),
), ),
"underflows_num": (
lambda x: (x.get_data_tensors()[0] == 0).sum(),
lambda buffers: sum(_get(buffers, "underflows_num")),
),
"std": ( "std": (
torch.std, lambda x, aux_dict: torch.std(x),
lambda buffers: compute_std( lambda buffers: compute_std(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum") _get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
), ),
), ),
"dynamic_range": ( "dynamic_range": (
lambda x: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x), lambda x, aux_dict: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x),
lambda buffers: max(_get(buffers, "dynamic_range_top")) lambda buffers: max(_get(buffers, "dynamic_range_top"))
- min(_get(buffers, "dynamic_range_bottom")), - min(_get(buffers, "dynamic_range_bottom")),
), ),
"underflows%": ( "fp8_delayed_scaling_overflows_num": (
lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100, lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(
lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")), x, aux_dict["fp8_delayed_scaling"]
),
lambda buffers: sum(_get(buffers, "fp8_delayed_scaling_overflows_num")),
),
"fp8_delayed_scaling_overflows%": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(
x, aux_dict["fp8_delayed_scaling"]
)
/ x.numel()
* 100,
lambda buffers: 100
* sum(_get(buffers, "fp8_delayed_scaling_overflows_num"))
/ sum(_get(buffers, "numel")),
),
"overflows_num": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""]),
lambda buffers: sum(_get(buffers, "overflows_num")),
),
"overflows%": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""])
/ x.numel()
* 100,
lambda buffers: 100 * sum(_get(buffers, "overflows_num")) / sum(_get(buffers, "numel")),
), ),
} }
def add_underflows_stats(recipe_name: str, columnwise: bool = False):
    """Register *both* underflow stats (num and %) for the given recipe.

    Adds ``<recipe>_underflows_num[_columnwise]`` and
    ``<recipe>_underflows%[_columnwise]`` entries to the module-level
    ``stats_to_num``, ``STATS`` and ``DEPENDENCIES`` registries. An empty
    ``recipe_name`` registers the un-prefixed default stats.
    """
    columnwise_suffix = "_columnwise" if columnwise else ""
    # Stat names
    stat_num = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows_num{columnwise_suffix}"
    stat_pct = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows%{columnwise_suffix}"
    # Indices are assigned in insertion order, so the registration order must
    # stay deterministic for buffers to line up.
    stats_to_num[stat_num] = len(stats_to_num)
    stats_to_num[stat_pct] = len(stats_to_num)
    # Per-tensor value: zeros in the quantized data minus zeros that were
    # already present in the original tensor `x`.
    STATS[stat_num] = (
        lambda x, aux_dict: (
            aux_dict[recipe_name].get_data_tensors(
                rowwise_data=not columnwise, columnwise_data=columnwise
            )
            == 0
        ).sum()
        - (x == 0).sum(),
        lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
    )
    # NOTE(review): unlike stat_num above, the per-tensor % below does NOT
    # subtract zeros already present in `x`, while its combiner is derived
    # from stat_num (which does). The two definitions can therefore disagree
    # — confirm which one is intended.
    STATS[stat_pct] = (
        lambda x, aux_dict: (
            aux_dict[recipe_name].get_data_tensors(
                rowwise_data=not columnwise, columnwise_data=columnwise
            )
            == 0
        ).sum()
        / aux_dict[recipe_name].numel()
        * 100,
        lambda buffers, _sn_num=stat_num: 100
        * sum(_get(buffers, _sn_num))
        / sum(_get(buffers, "numel")),
    )
    DEPENDENCIES[stat_num] = {stat_num}
    DEPENDENCIES[stat_pct] = {stat_num, "numel"}
def add_scale_inv_stats(recipe_name: str, columnwise: bool = False):
    """Register *both* scale-inv min and max stats for a given recipe.

    Adds ``<recipe>_scale_inv_min[_columnwise]`` and
    ``<recipe>_scale_inv_max[_columnwise]`` to the stat registries.
    """

    def _scale_inv_of(q_tensor, use_columnwise):
        # Some quantized tensor classes expose a single `_scale_inv`
        # attribute; others keep separate row-/column-wise factors.
        if hasattr(q_tensor, "_scale_inv"):
            return q_tensor._scale_inv
        if use_columnwise:
            return q_tensor._columnwise_scale_inv
        return q_tensor._rowwise_scale_inv

    suffix = "_columnwise" if columnwise else ""
    prefix = f"{recipe_name}_" if recipe_name != "" else ""
    name_min = f"{prefix}scale_inv_min{suffix}"
    name_max = f"{prefix}scale_inv_max{suffix}"

    # Indices follow the insertion order of the registry — keep deterministic.
    stats_to_num[name_min] = len(stats_to_num)
    stats_to_num[name_max] = len(stats_to_num)

    # Bind `columnwise` through a default argument so the lambdas do not
    # suffer from late binding.
    STATS[name_min] = (
        lambda x, aux_dict, _col=columnwise: _scale_inv_of(aux_dict[recipe_name], _col).min(),
        lambda buffers, _sn=name_min: min(_get(buffers, _sn)),
    )
    STATS[name_max] = (
        lambda x, aux_dict, _col=columnwise: _scale_inv_of(aux_dict[recipe_name], _col).max(),
        lambda buffers, _sn=name_max: max(_get(buffers, _sn)),
    )
    DEPENDENCIES[name_min] = {name_min}
    DEPENDENCIES[name_max] = {name_max}
def add_mse_stats(recipe_name: str, columnwise: bool = False):
    """Register mse and total_square_error stats for the recipe.

    ``mse`` is the mean squared error between the high-precision tensor and
    the dequantized quantized tensor; ``total_square_error`` is the
    non-normalized (summed) variant used to combine mse across buffers.
    """
    columnwise_suffix = "_columnwise" if columnwise else ""
    stat_mse = f"{recipe_name}{'_' if recipe_name != '' else ''}mse{columnwise_suffix}"
    stat_err = (
        f"{recipe_name}{'_' if recipe_name != '' else ''}total_square_error{columnwise_suffix}"
    )
    # Indices are assigned in insertion order — registration order matters.
    stats_to_num[stat_mse] = len(stats_to_num)
    stats_to_num[stat_err] = len(stats_to_num)
    # Combined mse is reconstructed from the accumulated total square error
    # divided by the accumulated element count.
    STATS[stat_mse] = (
        lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="mean"),
        lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err))
        / sum(_get(buffers, "numel")),
    )
    STATS[stat_err] = (
        lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="sum"),
        lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)),
    )
    DEPENDENCIES[stat_err] = {stat_err}
    # NOTE(review): stat_mse lists itself as a dependency, unlike sibling
    # derived stats (e.g. "std"), whose dependency sets contain only the
    # inputs their combiner reads — confirm the self-dependency is intended.
    DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"}
# Register the per-recipe quantization stats (rowwise and columnwise) for
# every supported recipe at import time. The empty recipe name registers the
# un-prefixed default variants. Keep this iteration order fixed: the helpers
# assign `stats_to_num` indices in insertion order, so the order must be
# deterministic across processes for the stat buffers to line up.
for _columnwise in [True, False]:
    for _recipe_name in [
        "",  # default recipe
        "fp8_delayed_scaling",
        "mxfp8",
        "fp8_current_scaling",
        "fp8_block_scaling",
    ]:
        add_underflows_stats(_recipe_name, _columnwise)
        add_scale_inv_stats(_recipe_name, _columnwise)
        add_mse_stats(_recipe_name, _columnwise)
...@@ -22,6 +22,7 @@ from transformer_engine.pytorch.tensor.quantized_tensor import ( ...@@ -22,6 +22,7 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from transformer_engine.debug.pytorch.debug_state import TEDebugState
aten = torch.ops.aten aten = torch.ops.aten
...@@ -53,14 +54,13 @@ class DebugQuantizer(Quantizer): ...@@ -53,14 +54,13 @@ class DebugQuantizer(Quantizer):
parent_quantizer: Optional[Quantizer], parent_quantizer: Optional[Quantizer],
tp_group: torch.distributed.ProcessGroup, tp_group: torch.distributed.ProcessGroup,
): ):
import nvdlfw_inspect.api as debug_api
super().__init__(rowwise=True, columnwise=True) super().__init__(rowwise=True, columnwise=True)
self.layer_name = layer_name self.layer_name = layer_name
self.tensor_name = tensor_name self.tensor_name = tensor_name
self.parent_quantizer = parent_quantizer self.parent_quantizer = parent_quantizer
self.tp_group = tp_group # used in inspect_tensor calls self.tp_group = tp_group # used in inspect_tensor calls
self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count self.iteration = TEDebugState.get_iteration()
# .internal = True is slightly faster, but results # .internal = True is slightly faster, but results
# in errors when caching the weights. # in errors when caching the weights.
...@@ -70,6 +70,12 @@ class DebugQuantizer(Quantizer): ...@@ -70,6 +70,12 @@ class DebugQuantizer(Quantizer):
self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]
# next iteration when this quantizer will call any API
# it is None at the init and it is computed after_enabled api calls.
# None at the beginning means that if nothing will be done,
# this quantizer will never call any API.
self.next_debug_iter = None
# The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled, # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
# rowwise_tensor_plan, and columnwise_tensor_plan are computed. # rowwise_tensor_plan, and columnwise_tensor_plan are computed.
# These fields indicate the path where API calls will be inserted. # These fields indicate the path where API calls will be inserted.
...@@ -102,15 +108,21 @@ class DebugQuantizer(Quantizer): ...@@ -102,15 +108,21 @@ class DebugQuantizer(Quantizer):
""" """
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( inspect_tensor_enabled = self.process_enabled_api_call(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
)
) )
modify_enabled = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name, modify_enabled = self.process_enabled_api_call(
gemm=self.rowwise_gemm_name, debug_api.transformer_engine.modify_tensor_enabled(
tensor_name=self.tensor_name, layer_name=self.layer_name,
iteration=self.iteration, gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
) )
plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION
return inspect_tensor_enabled, plan return inspect_tensor_enabled, plan
...@@ -121,10 +133,13 @@ class DebugQuantizer(Quantizer): ...@@ -121,10 +133,13 @@ class DebugQuantizer(Quantizer):
""" """
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( inspect_tensor_enabled = self.process_enabled_api_call(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
)
) )
inspect_tensor_postquantize_enabled_rowwise = (
inspect_tensor_postquantize_enabled_rowwise = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_postquantize_enabled( debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
layer_name=self.layer_name, layer_name=self.layer_name,
tensor_name=self.tensor_name, tensor_name=self.tensor_name,
...@@ -132,7 +147,8 @@ class DebugQuantizer(Quantizer): ...@@ -132,7 +147,8 @@ class DebugQuantizer(Quantizer):
gemm=self.rowwise_gemm_name, gemm=self.rowwise_gemm_name,
) )
) )
inspect_tensor_postquantize_enabled_columnwise = (
inspect_tensor_postquantize_enabled_columnwise = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_postquantize_enabled( debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
layer_name=self.layer_name, layer_name=self.layer_name,
tensor_name=self.tensor_name, tensor_name=self.tensor_name,
...@@ -140,7 +156,6 @@ class DebugQuantizer(Quantizer): ...@@ -140,7 +156,6 @@ class DebugQuantizer(Quantizer):
gemm=self.columnwise_gemm_name, gemm=self.columnwise_gemm_name,
) )
) )
return ( return (
inspect_tensor_enabled, inspect_tensor_enabled,
inspect_tensor_postquantize_enabled_rowwise, inspect_tensor_postquantize_enabled_rowwise,
...@@ -158,42 +173,54 @@ class DebugQuantizer(Quantizer): ...@@ -158,42 +173,54 @@ class DebugQuantizer(Quantizer):
rowwise_plan = None rowwise_plan = None
columnwise_plan = None columnwise_plan = None
modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled( modify_rowwise = self.process_enabled_api_call(
layer_name=self.layer_name, debug_api.transformer_engine.modify_tensor_enabled(
gemm=self.rowwise_gemm_name, layer_name=self.layer_name,
tensor_name=self.tensor_name, gemm=self.rowwise_gemm_name,
iteration=self.iteration, tensor_name=self.tensor_name,
iteration=self.iteration,
)
) )
if modify_rowwise: if modify_rowwise:
rowwise_plan = API_CALL_MODIFY rowwise_plan = API_CALL_MODIFY
else: else:
if self.parent_quantizer is not None: if self.parent_quantizer is not None:
fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( fp8_quantize = self.process_enabled_api_call(
layer_name=self.layer_name, debug_api.transformer_engine.fp8_gemm_enabled(
gemm=self.rowwise_gemm_name, layer_name=self.layer_name,
iteration=self.iteration, gemm=self.rowwise_gemm_name,
iteration=self.iteration,
)
) )
if fp8_quantize: if fp8_quantize:
rowwise_plan = STANDARD_FP8_QUANTIZE rowwise_plan = STANDARD_FP8_QUANTIZE
if rowwise_plan is None: if rowwise_plan is None:
rowwise_plan = HIGH_PRECISION rowwise_plan = HIGH_PRECISION
if self.columnwise_gemm_name is not None: if self.columnwise_gemm_name is not None:
modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled( modify_columnwise = self.process_enabled_api_call(
layer_name=self.layer_name, debug_api.transformer_engine.modify_tensor_enabled(
gemm=self.columnwise_gemm_name, layer_name=self.layer_name,
tensor_name=self.tensor_name, gemm=self.columnwise_gemm_name,
iteration=self.iteration, tensor_name=self.tensor_name,
iteration=self.iteration,
)
) )
if modify_columnwise: if modify_columnwise:
columnwise_plan = API_CALL_MODIFY columnwise_plan = API_CALL_MODIFY
else: else:
if self.parent_quantizer is not None: if self.parent_quantizer is not None:
fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( fp8_quantize = self.process_enabled_api_call(
layer_name=self.layer_name, debug_api.transformer_engine.fp8_gemm_enabled(
gemm=self.columnwise_gemm_name, layer_name=self.layer_name,
iteration=self.iteration, gemm=self.columnwise_gemm_name,
iteration=self.iteration,
)
) )
if fp8_quantize: if fp8_quantize:
columnwise_plan = STANDARD_FP8_QUANTIZE columnwise_plan = STANDARD_FP8_QUANTIZE
if columnwise_plan is None: if columnwise_plan is None:
...@@ -229,8 +256,11 @@ class DebugQuantizer(Quantizer): ...@@ -229,8 +256,11 @@ class DebugQuantizer(Quantizer):
"layer_name": self.layer_name, "layer_name": self.layer_name,
"tensor": tensor, "tensor": tensor,
"tensor_name": self.tensor_name, "tensor_name": self.tensor_name,
"iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count, "iteration": TEDebugState.get_iteration(),
"tp_group": self.tp_group, "tp_group": self.tp_group,
"columnwise_quantized_tensor": columnwise_gemm_tensor,
"rowwise_quantized_tensor": rowwise_gemm_tensor,
"quantizer": self.parent_quantizer,
} }
if tensor is not None and self.inspect_tensor_enabled: if tensor is not None and self.inspect_tensor_enabled:
debug_api.transformer_engine.inspect_tensor(**args) debug_api.transformer_engine.inspect_tensor(**args)
...@@ -238,6 +268,10 @@ class DebugQuantizer(Quantizer): ...@@ -238,6 +268,10 @@ class DebugQuantizer(Quantizer):
if self.output_tensor: if self.output_tensor:
return return
del args["columnwise_quantized_tensor"]
del args["rowwise_quantized_tensor"]
del args["quantizer"]
if ( if (
self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_rowwise and self.inspect_tensor_postquantize_enabled_rowwise
...@@ -245,6 +279,7 @@ class DebugQuantizer(Quantizer): ...@@ -245,6 +279,7 @@ class DebugQuantizer(Quantizer):
args["tensor"] = rowwise_gemm_tensor args["tensor"] = rowwise_gemm_tensor
args["rowwise"] = True args["rowwise"] = True
debug_api.transformer_engine.inspect_tensor_postquantize(**args) debug_api.transformer_engine.inspect_tensor_postquantize(**args)
if ( if (
self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_columnwise and self.inspect_tensor_postquantize_enabled_columnwise
...@@ -270,22 +305,14 @@ class DebugQuantizer(Quantizer): ...@@ -270,22 +305,14 @@ class DebugQuantizer(Quantizer):
# 1. If there is fp8 quantization in at least one of the gemms, # 1. If there is fp8 quantization in at least one of the gemms,
# the quantization using the self.parent_quantizer is performed. # the quantization using the self.parent_quantizer is performed.
# rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise self._update_parent_quantizer_usage()
rowwise_gemm_quantize = ( # Only columnwise quantization is not supported.
self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE if self.parent_quantizer is not None:
) if not self.parent_quantizer.rowwise_usage and self.parent_quantizer.columnwise_usage:
columnwise_gemm_quantize = ( self.parent_quantizer.set_usage(rowwise=True)
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
if columnwise_gemm_quantize and not rowwise_gemm_quantize:
rowwise_gemm_quantize = True # only columnwise quantization not implemented
rowwise_gemm_tensor, columnwise_gemm_tensor = None, None rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
self.parent_quantizer.set_usage(
rowwise=True,
columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported
)
quantized_tensor = self.parent_quantizer(tensor) quantized_tensor = self.parent_quantizer(tensor)
# if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
# one tensor with columnwise=True and rowwise=True is computed # one tensor with columnwise=True and rowwise=True is computed
...@@ -341,7 +368,6 @@ class DebugQuantizer(Quantizer): ...@@ -341,7 +368,6 @@ class DebugQuantizer(Quantizer):
quantizer=self, quantizer=self,
layer_name=self.layer_name, layer_name=self.layer_name,
tensor_name=self.tensor_name, tensor_name=self.tensor_name,
original_tensor=tensor,
) )
def process_gemm_output(self, tensor: torch.Tensor): def process_gemm_output(self, tensor: torch.Tensor):
...@@ -375,6 +401,26 @@ class DebugQuantizer(Quantizer): ...@@ -375,6 +401,26 @@ class DebugQuantizer(Quantizer):
return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device) return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device)
return torch.empty(shape, dtype=dtype, device=device) return torch.empty(shape, dtype=dtype, device=device)
def any_feature_enabled(self) -> bool:
    """Returns bool if there is at least one API call enabled for this tensor."""
    if self.output_tensor:
        # Output tensors only support inspection and rowwise-plan modification.
        return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY

    inspect_enabled = (
        self.inspect_tensor_enabled
        or self.inspect_tensor_postquantize_enabled_rowwise
        or self.inspect_tensor_postquantize_enabled_columnwise
    )
    if inspect_enabled:
        return True
    if API_CALL_MODIFY in (self.rowwise_tensor_plan, self.columnwise_tensor_plan):
        return True
    if self.parent_quantizer is None:
        return False
    # With a parent quantizer, any plan other than plain FP8 quantization
    # implies a debug feature is active.
    return (
        self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE
        or self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE
    )
def calibrate(self, tensor: torch.Tensor):
    """Calibration override; calibration is not supported in debug mode.

    Raises:
        RuntimeError: always — debug quantizers cannot be calibrated.
    """
    raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported")
...@@ -446,29 +492,70 @@ class DebugQuantizer(Quantizer): ...@@ -446,29 +492,70 @@ class DebugQuantizer(Quantizer):
self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor) self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor)
def get_next_debug_iter(self) -> Optional[int]:
    """
    Returns the next iteration for which the debug is enabled for this tensor.
    If the next iteration is None, then the debug is not enabled for this tensor.
    """
    return self.next_debug_iter
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
    """Probably not needed for debug quantizer; no recipe is associated, so returns None."""
    return None
def process_enabled_api_call(
self, enabled_call_output: bool | Tuple[bool, Optional[int]]
) -> bool:
"""
Process enabled API call output.
Updates self.next_debug_iter field accordingly.
Return the bool representing if the API call is enabled.
"""
if isinstance(enabled_call_output, tuple):
assert len(enabled_call_output) == 2, "Expected a tuple of length 2"
enabled_bool, next_iter = enabled_call_output
else:
enabled_bool = enabled_call_output
next_iter = self.iteration + 1
if self.next_debug_iter is None:
self.next_debug_iter = next_iter
elif next_iter is not None:
# If next iter is None, that means that call will never be enabled.
self.next_debug_iter = min(self.next_debug_iter, next_iter)
return enabled_bool
def supports_only_rowwise_all_gather(self) -> bool:
    """Whether only rowwise all-gather is supported; delegates to the parent quantizer."""
    parent = self.parent_quantizer
    if parent is None:
        return False
    return parent.supports_only_rowwise_all_gather()
def _update_parent_quantizer_usage(self):
    """
    Updates the usage of the parent quantizer.

    The parent quantizer is only touched when at least one GEMM plan
    requires a standard FP8 quantization.
    """
    rowwise_needed = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
    columnwise_needed = (
        self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
    )
    if STANDARD_FP8_QUANTIZE not in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
        return
    self.parent_quantizer.set_usage(rowwise=rowwise_needed, columnwise=columnwise_needed)
def set_usage(self, rowwise: bool = None, columnwise: bool = None):
    """
    Sets the usage of the quantizer.

    Also keeps the parent quantizer's usage in sync for non-output tensors.
    """
    super().set_usage(rowwise=rowwise, columnwise=columnwise)
    if self.output_tensor:
        # Parent usage is only mirrored for non-output tensors.
        return
    self._update_parent_quantizer_usage()
class DebugQuantizedTensor(QuantizedTensorBase): class DebugQuantizedTensor(QuantizedTensorBase):
""" """
...@@ -484,7 +571,6 @@ class DebugQuantizedTensor(QuantizedTensorBase): ...@@ -484,7 +571,6 @@ class DebugQuantizedTensor(QuantizedTensorBase):
quantizer, quantizer,
layer_name=None, layer_name=None,
tensor_name=None, tensor_name=None,
original_tensor=None,
): ):
self.rowwise_gemm_tensor = rowwise_gemm_tensor self.rowwise_gemm_tensor = rowwise_gemm_tensor
...@@ -492,7 +578,6 @@ class DebugQuantizedTensor(QuantizedTensorBase): ...@@ -492,7 +578,6 @@ class DebugQuantizedTensor(QuantizedTensorBase):
self.quantizer = quantizer self.quantizer = quantizer
self._layer_name = layer_name self._layer_name = layer_name
self._tensor_name = tensor_name self._tensor_name = tensor_name
self._original_tensor = original_tensor
def prepare_for_saving(self): def prepare_for_saving(self):
""" " Prepare for saving method override""" """ " Prepare for saving method override"""
...@@ -501,6 +586,7 @@ class DebugQuantizedTensor(QuantizedTensorBase): ...@@ -501,6 +586,7 @@ class DebugQuantizedTensor(QuantizedTensorBase):
if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor
else [self.rowwise_gemm_tensor] else [self.rowwise_gemm_tensor]
) )
tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save) tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save)
self.tensors_to_save = tensor_objects_list self.tensors_to_save = tensor_objects_list
# pylint: disable=unbalanced-tuple-unpacking # pylint: disable=unbalanced-tuple-unpacking
...@@ -519,6 +605,7 @@ class DebugQuantizedTensor(QuantizedTensorBase): ...@@ -519,6 +605,7 @@ class DebugQuantizedTensor(QuantizedTensorBase):
else: else:
self.rowwise_gemm_tensor = tensor_objects_list[0] self.rowwise_gemm_tensor = tensor_objects_list[0]
self.columnwise_gemm_tensor = self.rowwise_gemm_tensor self.columnwise_gemm_tensor = self.rowwise_gemm_tensor
return saved_tensors return saved_tensors
def quantize_(self, tensor, *, noop_flag=None): def quantize_(self, tensor, *, noop_flag=None):
...@@ -542,3 +629,27 @@ class DebugQuantizedTensor(QuantizedTensorBase): ...@@ -542,3 +629,27 @@ class DebugQuantizedTensor(QuantizedTensorBase):
def update_usage(self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None):
    """Update usage of the tensor.

    Setting a usage to False drops the corresponding GEMM tensor (when rowwise
    and columnwise are distinct objects) and forwards the change to any
    underlying QuantizedTensor objects.

    Raises:
        RuntimeError: when a usage is requested for a tensor that has already
            been dropped — debug mode cannot recreate one layout from the other.
    """
    if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor:
        # If the same object is used both for rowwise and columnwise gemms,
        # there is no benefit in erasing the usage of one of them.
        # And there are scenarios when not deleting the usage of one of them is needed.
        # For example when we want to recreate columnwise from rowwise.
        if rowwise_usage is False:
            self.rowwise_gemm_tensor = None
        if columnwise_usage is False:
            self.columnwise_gemm_tensor = None

    if isinstance(self.rowwise_gemm_tensor, QuantizedTensor):
        self.rowwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage)
    if isinstance(self.columnwise_gemm_tensor, QuantizedTensor):
        self.columnwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage)

    if rowwise_usage and self.rowwise_gemm_tensor is None:
        raise RuntimeError(
            "Cannot recreate rowwise tensor from columnwise tensor in debug mode."
        )
    if columnwise_usage and self.columnwise_gemm_tensor is None:
        # Fixed typo in the original message ("is debug mode").
        raise RuntimeError(
            "Cannot recreate columnwise tensor from rowwise tensor in debug mode."
        )
...@@ -62,6 +62,13 @@ class TEDebugState: ...@@ -62,6 +62,13 @@ class TEDebugState:
"""Sets weight tensor reduction mode.""" """Sets weight tensor reduction mode."""
cls.weight_tensor_tp_group_reduce = enabled cls.weight_tensor_tp_group_reduce = enabled
@classmethod
def get_iteration(cls):
    """Returns the current iteration as tracked by the debug manager."""
    import nvdlfw_inspect.api as debug_api

    # NOTE(review): reads a private attribute of the debug manager;
    # confirm there is no public accessor for the trainer iteration.
    manager = debug_api.DEBUG_MANAGER
    return manager._trainer_iteration_count
def set_weight_tensor_tp_group_reduce(enabled): def set_weight_tensor_tp_group_reduce(enabled):
"""Sets weight tensor reduction mode.""" """Sets weight tensor reduction mode."""
......
...@@ -4,6 +4,25 @@ ...@@ -4,6 +4,25 @@
"""Utils functions for the debug module.""" """Utils functions for the debug module."""
from typing import Optional
def next_iter_when_debug_should_be_run(quantizers) -> Optional[int]:
    """
    Returns next iteration at which the debug should be run.
    If debug will never be run for this layer, returns None.

    Args:
        quantizers: iterable of objects exposing ``get_next_debug_iter()``
            returning an int or None.
    """
    # Call the getter once per quantizer (the original called it up to three times each).
    candidates = [q.get_next_debug_iter() for q in quantizers]
    enabled_iters = [it for it in candidates if it is not None]
    return min(enabled_iters) if enabled_iters else None
def any_feature_enabled(quantizers): def any_feature_enabled(quantizers):
"""Returns True if at least one API call is made from DebugQuantizer.""" """Returns True if at least one API call is made from DebugQuantizer."""
......
...@@ -855,7 +855,7 @@ def fused_attn_thd( ...@@ -855,7 +855,7 @@ def fused_attn_thd(
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)) @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
def _fused_attn( def _fused_attn(
qkv: Tuple[jnp.ndarray, ...], qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
...@@ -872,6 +872,7 @@ def _fused_attn( ...@@ -872,6 +872,7 @@ def _fused_attn(
context_parallel_strategy: CPStrategy, context_parallel_strategy: CPStrategy,
context_parallel_causal_load_balanced: bool, context_parallel_causal_load_balanced: bool,
context_parallel_axis: str, context_parallel_axis: str,
context_checkpoint_name: str = "context",
): ):
output, _ = _fused_attn_fwd_rule( output, _ = _fused_attn_fwd_rule(
qkv, qkv,
...@@ -889,6 +890,7 @@ def _fused_attn( ...@@ -889,6 +890,7 @@ def _fused_attn(
context_parallel_strategy, context_parallel_strategy,
context_parallel_causal_load_balanced, context_parallel_causal_load_balanced,
context_parallel_axis, context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
) )
return output return output
...@@ -909,6 +911,7 @@ def _fused_attn_fwd_rule( ...@@ -909,6 +911,7 @@ def _fused_attn_fwd_rule(
context_parallel_strategy, context_parallel_strategy,
context_parallel_causal_load_balanced, context_parallel_causal_load_balanced,
context_parallel_axis, context_parallel_axis,
context_checkpoint_name,
): ):
output, softmax_aux, rng_state = tex.fused_attn_fwd( output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv, qkv,
...@@ -927,9 +930,9 @@ def _fused_attn_fwd_rule( ...@@ -927,9 +930,9 @@ def _fused_attn_fwd_rule(
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
) )
output = checkpoint_name(output, "context") output = checkpoint_name(output, context_checkpoint_name)
softmax_aux = checkpoint_name(softmax_aux, "context") softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
rng_state = checkpoint_name(rng_state, "context") rng_state = checkpoint_name(rng_state, context_checkpoint_name)
return output, ( return output, (
qkv, qkv,
bias, bias,
...@@ -952,9 +955,11 @@ def _fused_attn_bwd_rule( ...@@ -952,9 +955,11 @@ def _fused_attn_bwd_rule(
context_parallel_strategy, context_parallel_strategy,
context_parallel_causal_load_balanced, context_parallel_causal_load_balanced,
context_parallel_axis, context_parallel_axis,
context_checkpoint_name,
ctx, ctx,
dz, dz,
): ):
del context_checkpoint_name
( (
qkv, qkv,
bias, bias,
...@@ -1012,6 +1017,7 @@ def fused_attn( ...@@ -1012,6 +1017,7 @@ def fused_attn(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False, context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "", context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
): ):
""" """
Perform cuDNN fused attention. Perform cuDNN fused attention.
...@@ -1044,6 +1050,7 @@ def fused_attn( ...@@ -1044,6 +1050,7 @@ def fused_attn(
context_parallel_causal_load_balanced (bool): context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism. Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis. context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
Returns: Returns:
(jnp.ndarray): The output tensor from the fused attention. (jnp.ndarray): The output tensor from the fused attention.
...@@ -1116,6 +1123,7 @@ def fused_attn( ...@@ -1116,6 +1123,7 @@ def fused_attn(
context_parallel_strategy=context_parallel_strategy, context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis, context_parallel_axis=context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
) )
return output return output
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Checkpoint policies for Transformer Engine in JAX.
This module provides JAX checkpoint policies that are compatible with Transformer Engine's custom primitives.
"""
import jax
from .cpp_extensions.gemm import GemmPrimitive, GroupedGemmPrimitive
# Public checkpoint-policy API of this module.
__all__ = [
    "te_gemms_saveable",
    "dots_and_te_gemms_with_no_batch_dims",
    "checkpoint_dots_and_te_gemms",
]
def te_gemms_saveable(prim, *_, **__) -> bool:
    """Checkpoint policy that marks Transformer Engine GEMM primitives as saveable."""
    if prim in (GemmPrimitive.outer_primitive, GroupedGemmPrimitive.outer_primitive):
        return True
    # Workaround to include JAX's scaled_matmul until JAX checkpoint policies for dots are
    # updated to include it.
    return prim.name == "scaled_matmul_wrapper"
# JAX's "checkpoint dots with no batch dims" policy extended to also save
# TE GEMM (and scaled_matmul, via te_gemms_saveable) results.
dots_and_te_gemms_with_no_batch_dims = jax.checkpoint_policies.save_from_both_policies(
    jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
    te_gemms_saveable,
)

# JAX's "checkpoint dots" policy extended to also save TE GEMM results.
checkpoint_dots_and_te_gemms = jax.checkpoint_policies.save_from_both_policies(
    jax.checkpoint_policies.checkpoint_dots,
    te_gemms_saveable,
)
...@@ -915,11 +915,11 @@ register_primitive(BaseDActLuDBiasQuantizePrimitive) ...@@ -915,11 +915,11 @@ register_primitive(BaseDActLuDBiasQuantizePrimitive)
class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""JAX/TE base custom ops""" """JAX/TE base custom ops"""
import os import os
import re import re
import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from functools import partial from functools import partial
from packaging import version from packaging import version
...@@ -30,19 +31,77 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -30,19 +31,77 @@ class BasePrimitive(metaclass=ABCMeta):
name = None name = None
_is_enabled = True
# Default list of primitives to disable for all recipes
_default_disable_names = []
@classmethod
def enabled(cls):
    """
    Determines if a custom call is enabled based on a state variable and environment variables.

    Checks `NVTE_JAX_CUSTOM_CALLS` (key/value format) first, then falls back to the
    deprecated `NVTE_JAX_CUSTOM_CALLS_RE` (regex pattern), and finally to the internal
    state `_is_enabled` if neither is set.

    Environment Variables:
    1. `NVTE_JAX_CUSTOM_CALLS`: Preferred key/value format to enable/disable specific
       primitives, or a single 'true'/'false' value to enable/disable all primitives.
       - 'true' enables all primitives; 'false' disables all primitives.
       - 'DBiasQuantizePrimitive=false,GemmPrimitive=true' sets specific primitives,
         leaving others at their default state (set per class via _default_disable_names).
    2. `NVTE_JAX_CUSTOM_CALLS_RE`: Deprecated regex pattern matched against the class
       name. A DeprecationWarning is raised if used; it will be removed in future releases.
    """
    # Preferred: key/value environment variable.
    custom_calls_str = os.getenv("NVTE_JAX_CUSTOM_CALLS")
    if custom_calls_str is not None:
        custom_calls_str = custom_calls_str.strip()
        # Global on/off switch.
        if custom_calls_str.lower() == "true":
            return True
        if custom_calls_str.lower() == "false":
            return False
        # Parse 'Name=true/false' pairs; names not listed fall through to the
        # deprecated variable / internal state below.
        settings = {}
        for pair in custom_calls_str.split(","):
            pair = pair.strip()
            if "=" in pair:
                key, value = pair.split("=", 1)
                settings[key.strip()] = value.strip().lower() == "true"
        if cls.__name__ in settings:
            return settings[cls.__name__]

    # Deprecated: regex environment variable.
    pattern_str = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE")
    if pattern_str is not None:
        warnings.warn(
            "NVTE_JAX_CUSTOM_CALLS_RE is deprecated and will be removed in future releases. Use"
            " NVTE_JAX_CUSTOM_CALLS with key=value format instead (e.g.,"
            " 'DBiasQuantizePrimitive=false').",
            DeprecationWarning,
        )
        return re.compile(pattern_str).fullmatch(cls.__name__) is not None

    # Neither environment variable is set: fall back to the internal state.
    return cls._is_enabled
@classmethod
def set_enabled(cls, enabled: bool):
    """
    Sets the enabled state for this primitive.

    Note: the NVTE_JAX_CUSTOM_CALLS / NVTE_JAX_CUSTOM_CALLS_RE environment
    variables, when set, take precedence over this state (see ``enabled()``).
    """
    cls._is_enabled = enabled
@staticmethod @staticmethod
@abstractmethod @abstractmethod
...@@ -109,10 +168,19 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -109,10 +168,19 @@ class BasePrimitive(metaclass=ABCMeta):
return "... -> ..." return "... -> ..."
# Registry to store all registered primitive classes
_primitive_registry = {}
def register_primitive(cls): def register_primitive(cls):
""" """
register jax primitive Register a JAX primitive and add it to the internal registry.
""" """
_primitive_registry[cls.__name__] = cls
# Set default disabled state at class level based on _default_disable_names
if cls.__name__ in BasePrimitive._default_disable_names:
cls.set_enabled(False)
def name_of_wrapper_p(): def name_of_wrapper_p():
return cls.name + "_wrapper" return cls.name + "_wrapper"
...@@ -145,3 +213,48 @@ def register_primitive(cls): ...@@ -145,3 +213,48 @@ def register_primitive(cls):
for _name, _value in transformer_engine_jax.registrations().items(): for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="CUDA") ffi.register_ffi_target(_name, _value, platform="CUDA")
def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False):
    """
    Helper function to manage primitive states by name without modifying environment variables.

    Allows enabling specific primitives, disabling specific primitives, or disabling all
    primitives. This helper is used in the QuantizeConfig.initialize() methods.

    Args:
        enable_names: List of primitive class names to enable. Defaults to None.
        disable_names: List of primitive class names to disable. Defaults to None.
        disable_all_first: If True, disables all primitives before applying the
            enable/disable lists. Defaults to False.

    Raises:
        ValueError: if a requested name is not a registered BasePrimitive subclass.

    Note:
        1. If `disable_all_first` is True, all primitives are disabled first, then
           `enable_names` is applied.
        2. Conflicts (a primitive in both enable and disable lists) are resolved by
           applying disable last.
    """

    def _resolve(name):
        # Look up a registered BasePrimitive subclass by name or raise.
        cls = _primitive_registry.get(name)
        if not (cls and isinstance(cls, type) and issubclass(cls, BasePrimitive)):
            raise ValueError(f"Primitive not found in registry: {name}")
        return cls

    if disable_all_first:
        for cls in _primitive_registry.values():
            if isinstance(cls, type) and issubclass(cls, BasePrimitive) and cls is not BasePrimitive:
                cls.set_enabled(False)

    # Apply enables
    for name in set(enable_names or []):
        _resolve(name).set_enabled(True)

    # Apply disables last so they win over conflicting enables.
    for name in set(disable_names or []):
        _resolve(name).set_enabled(False)
...@@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi" name = "te_gemm_ffi"
multiple_results = True multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15) impl_static_args = (6, 7, 8, 9, 10, 11, 12)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -169,16 +169,13 @@ class GemmPrimitive(BasePrimitive): ...@@ -169,16 +169,13 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
): ):
del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator del use_split_accumulator
def _dims_are_consecutive(dims): def _dims_are_consecutive(dims):
if len(dims) <= 1: if len(dims) <= 1:
...@@ -201,27 +198,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -201,27 +198,6 @@ class GemmPrimitive(BasePrimitive):
f"{rhs_contracting_dims}." f"{rhs_contracting_dims}."
) )
(
lhs_batch_dims,
rhs_batch_dims,
) = map(sanitize_dims, operand_ndims, batched_dims)
assert _dims_are_consecutive(lhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got "
f"{lhs_batch_dims}."
)
assert _dims_are_consecutive(rhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got "
f"{rhs_batch_dims}."
)
if len(lhs_batch_dims) == 0:
assert (
len(rhs_batch_dims) == 0
), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched."
elif len(rhs_batch_dims) != 0:
assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all(
bdim in rhs_contracting_dims for bdim in rhs_batch_dims
), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched."
lhs_contracting_size, rhs_contracting_size = map( lhs_contracting_size, rhs_contracting_size = map(
lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
(lhs.shape, rhs.shape), (lhs.shape, rhs.shape),
...@@ -335,16 +311,14 @@ class GemmPrimitive(BasePrimitive): ...@@ -335,16 +311,14 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
): ):
del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype del out_dtype
lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout( lhs_transposed, rhs_transposed = _get_gemm_layout(
...@@ -385,9 +359,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -385,9 +359,6 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
...@@ -402,14 +373,14 @@ class GemmPrimitive(BasePrimitive): ...@@ -402,14 +373,14 @@ class GemmPrimitive(BasePrimitive):
lhs_scale_inv, lhs_scale_inv,
scaling_mode, scaling_mode,
lhs.shape, lhs.shape,
is_colwise=lhs_quantized_colwise, is_colwise=lhs_transposed,
flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
) )
rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv, rhs_scale_inv,
scaling_mode, scaling_mode,
rhs.shape, rhs.shape,
is_colwise=rhs_quantized_colwise, is_colwise=not rhs_transposed,
flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
) )
...@@ -422,9 +393,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -422,9 +393,6 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fuse_bias=fuse_bias, fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
...@@ -436,12 +404,9 @@ class GemmPrimitive(BasePrimitive): ...@@ -436,12 +404,9 @@ class GemmPrimitive(BasePrimitive):
@staticmethod @staticmethod
def batcher( def batcher(
batched_args, batched_args,
jax_batch_dims, batch_dims,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
...@@ -449,24 +414,13 @@ class GemmPrimitive(BasePrimitive): ...@@ -449,24 +414,13 @@ class GemmPrimitive(BasePrimitive):
use_split_accumulator, use_split_accumulator,
): ):
assert GemmPrimitive.outer_primitive is not None assert GemmPrimitive.outer_primitive is not None
lhs, _, rhs, *_ = batched_args lhs_bdims, _, rhs_bdims, *_ = batch_dims
lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims
arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
# Output is batched like the non-contracting batch dimensions of the LHS operand # Batched GEMM is not supported
lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims) assert (
lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims) lhs_bdims is None and rhs_bdims is None
out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})"
out_bdims = (None,)
# Bias gradient is never batched # Bias gradient is never batched
bias_bdims = (None,) bias_bdims = (None,)
...@@ -481,9 +435,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -481,9 +435,6 @@ class GemmPrimitive(BasePrimitive):
*batched_args, *batched_args,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fuse_bias=fuse_bias, fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
...@@ -494,168 +445,85 @@ class GemmPrimitive(BasePrimitive): ...@@ -494,168 +445,85 @@ class GemmPrimitive(BasePrimitive):
) )
@staticmethod @staticmethod
def _decompose_operand_specs(specs, contracting_dims, batch_dims): def _parse_operand_output_specs(
ndims = len(specs) arg_infos,
cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims)) contracting_dims,
):
# Batch specs
bspecs = tuple(specs[i] for i in bdims)
# Non-batch leading dimension specs
lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims)
# Non-batch contracting dimension specs
cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims)
return bspecs, lspecs, cspecs
@staticmethod
def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims)
sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims lhs_non_cdims, rhs_non_cdims = map(
) lambda ndim, cdims: tuple(i for i in range(ndim) if i not in cdims),
( (lhs_ndim, rhs_ndim),
(lhs_bspecs, lhs_lspecs, lhs_cspecs),
(rhs_bspecs, rhs_lspecs, rhs_cspecs),
) = map(
GemmPrimitive._decompose_operand_specs,
(lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims), (lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims), )
lhs_non_cspecs, lhs_cspecs, rhs_non_cspecs, rhs_cspecs = map(
lambda specs, dims: tuple(specs[i] for i in dims),
(lhs_specs, lhs_specs, rhs_specs, rhs_specs),
(lhs_non_cdims, lhs_cdims, rhs_non_cdims, rhs_cdims),
) )
# Batched dimensions must have the same sharding reduce_spec = None
if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: for l in lhs_cspecs:
assert all( for r in rhs_cspecs:
lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) if l is not None and l == r:
), ( assert reduce_spec is None, "Multiple reduce dimension is detected!"
"cuBLAS GEMM operand batch dimensions must have the same sharding: " reduce_spec = l
f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}."
if reduce_spec is not None:
# Other non-reduce cdims (if exists) need to be unsharded
lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)
# Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden
# No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim.
# In `rhs_specs`, the batch dim appears only in Wgrad GEMM under `rhs_cspecs`.
rhs_non_cspecs = tuple(
None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs
) )
# Only one each of the non-batched leading dimensions and non-batched contracting else:
# dimensions can be sharded # Otherwise, require contracting dims of both operands to be unsharded
lhs_ldims, rhs_ldims = map( lhs_cspecs = (None,) * len(lhs_cspecs)
lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), rhs_cspecs = (None,) * len(rhs_cspecs)
(lhs_ndim, rhs_ndim),
(lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims),
)
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map(
lambda specs: tuple(spec for spec in specs if spec is not None),
(lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs),
)
assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched leading dimension: "
f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}."
)
assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: "
f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}."
)
# Extract single leading and contracting dimension specs # Non-contracting dims of RHS always needs to be gathered along the FSDP axis
(lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( rhs_non_cspecs = tuple(
lambda specs: None if len(specs) == 0 else specs[0], None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), )
)
# Non-contracting dims of LHS to be gathered along the SP axis.
# Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for
# dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet.
lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs)
# Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts out_specs = lhs_non_cspecs + rhs_non_cspecs
# with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands.
# 1. K1 == K2 != None and N == None # specs = merge(cspecs, non_cspecs)
# LHS: (B, M, K) lhs_specs, rhs_specs = map(
# RHS: (B, None, K) lambda cdims, cspecs, non_cspecs: (
# OUT: (B, M, None) --(AR)-> (B, M, None) cspecs + non_cspecs if cdims[0] == 0 else non_cspecs + cspecs
# 2. K1 == K2 != None and M == N != None ),
# LHS: (B, M, K)
# RHS: (B, N, K)--(AG)->(B, None, K)
# OUT: (B, M, None) --(RS)--> (B, M, N)
# 3. M == N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, M, K)--(AG)->(B, None, None)
# OUT: (B, M, None)
# 4. M != N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, N, K)--(AG)->(B, N, None)
# OUT: (B, M, N)
reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec
all_reduce_output = reduce_flag and rhs_lspec is None
reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec
all_reduce_spec = reduce_scatter_spec = scatter_dim = None
lhs_non_contracting_specs, rhs_non_contracting_specs = map(
lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims),
(lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims), (lhs_cdims, rhs_cdims),
(lhs_cspecs, rhs_cspecs),
(lhs_non_cspecs, rhs_non_cspecs),
) )
out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs)
if reduce_scatter_output:
# All-gather (if necessary) the non-batch non-contracting dimension of RHS
# (B, N, K) --(AG)-> (B, None, K)
# (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N)
rhs_spec = tuple(
rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim)
)
reduce_scatter_spec = lhs_cspec
scatter_dim = out_specs.index(rhs_lspec)
elif all_reduce_output:
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
all_reduce_spec = lhs_cspec
else:
# All-gather (if necessary) the non-batch contracting dimensions
# (B, M, K) --(AG)-> (B, M, None)
# (B, N, K) --(AG)-> (B, N, None)
# (B, M, None) x (B, N, None)^T = (B, M, N)
lhs_specs = tuple(
None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i]
for i in range(lhs_ndim)
)
rhs_specs = tuple(
None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i]
for i in range(rhs_ndim)
)
# Check if RHS non-contracting spec also appears in the LHS non-contracting specs
if rhs_lspec is not None and rhs_lspec in tuple(
lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims
):
# All-gather (if necessary) the non-batch non-contracting dimensions of RHS
# (B, N, None) --(AG)-> (B, None, None)
# (B, M, None) x (B, None, None)^T = (B, M, None)
rhs_specs = tuple(
None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i]
for i in range(rhs_ndim)
)
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
# Bias and Pre-GeLU sharding is based on GEMM output # Bias and Pre-GeLU sharding is based on GEMM output before any scatter
bias_specs = out_specs[len(lhs_non_contracting_specs) :] bias_specs = tuple(list(rhs_non_cspecs).copy())
gelu_specs = out_specs gelu_specs = tuple(list(out_specs).copy())
return ( return (
(lhs_specs, rhs_specs, bias_specs, gelu_specs), (lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs),
all_reduce_spec, reduce_spec,
reduce_scatter_spec,
scatter_dim,
) )
@staticmethod @staticmethod
def infer_sharding_from_operands( def infer_sharding_from_operands(
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
...@@ -667,15 +535,13 @@ class GemmPrimitive(BasePrimitive): ...@@ -667,15 +535,13 @@ class GemmPrimitive(BasePrimitive):
): ):
del ( del (
out_dtype, out_dtype,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
grad, grad,
) )
del use_split_accumulator, result_infos del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( (_, (out_specs, dbias_specs, pre_gelu_specs), _) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
...@@ -695,9 +561,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -695,9 +561,6 @@ class GemmPrimitive(BasePrimitive):
def partition( def partition(
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
...@@ -712,10 +575,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -712,10 +575,8 @@ class GemmPrimitive(BasePrimitive):
( (
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs), (out_specs, dbias_specs, pre_gelu_specs),
all_reduce_spec, reduce_spec,
reduce_scatter_spec, ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
scatter_dim,
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
# Assemble argument shardings # Assemble argument shardings
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
...@@ -762,9 +623,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -762,9 +623,6 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fuse_bias=fuse_bias, fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
...@@ -772,19 +630,9 @@ class GemmPrimitive(BasePrimitive): ...@@ -772,19 +630,9 @@ class GemmPrimitive(BasePrimitive):
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
) )
# All-Reduce/Reduce-Scatter GEMM output # All-Reduce GEMM output
if all_reduce_spec is not None: if reduce_spec is not None:
outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec)
elif reduce_scatter_spec is not None:
outputs[0] = jax.lax.psum_scatter(
outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum_scatter(
outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
return outputs return outputs
...@@ -794,9 +642,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -794,9 +642,6 @@ class GemmPrimitive(BasePrimitive):
def shardy_sharding_rule( def shardy_sharding_rule(
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
...@@ -806,40 +651,33 @@ class GemmPrimitive(BasePrimitive): ...@@ -806,40 +651,33 @@ class GemmPrimitive(BasePrimitive):
operand_types, operand_types,
result_types, result_types,
): ):
del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator del out_dtype, grad, use_split_accumulator
del mesh, result_types del mesh, result_types
prefix = "GemmPrimitive_" prefix = "GemmPrimitive_"
def _generate_operand_rules(name, ndim, cdims, bdims): def _generate_operand_rules(name, ndim, cdims):
specs = [] specs = []
ldims = tuple(i for i in range(ndim) if i not in bdims + cdims) ldims = tuple(i for i in range(ndim) if i not in cdims)
for i in range(ndim): for i in range(ndim):
dim_name = None dim_name = None
if i in bdims: if i in cdims:
dim_idx = bdims.index(i) if len(bdims) > 1 else "" dim_idx = cdims.index(i)
dim_name = f"b{dim_idx}"
elif i in cdims:
dim_idx = cdims.index(i) if len(cdims) > 1 else ""
dim_name = f"k{dim_idx}" dim_name = f"k{dim_idx}"
else: else:
dim_idx = ldims.index(i) if len(ldims) > 1 else "" dim_idx = ldims.index(i)
dim_name = f"{name}_l{dim_idx}" dim_name = f"{name}_l{dim_idx}"
specs.append(prefix + dim_name) specs.append(prefix + dim_name)
return specs return specs
lhs, _, rhs, *_ = operand_types lhs, _, rhs, *_ = operand_types
operand_ndims = (len(lhs.shape), len(rhs.shape)) operand_ndims = (len(lhs.shape), len(rhs.shape))
(lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map( (lhs_cdims, rhs_cdims) = map(sanitize_dims, operand_ndims, contracting_dims)
lambda dims: map(sanitize_dims, operand_ndims, dims),
(contracting_dims, batched_dims),
)
lhs_specs, rhs_specs = map( lhs_specs, rhs_specs = map(
_generate_operand_rules, _generate_operand_rules,
("lhs", "rhs"), ("lhs", "rhs"),
operand_ndims, operand_ndims,
(lhs_cdims, rhs_cdims), (lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
) )
lhs_scale_specs = ("…1",) lhs_scale_specs = ("…1",)
rhs_scale_specs = ("…2",) rhs_scale_specs = ("…2",)
...@@ -891,7 +729,6 @@ def _te_gemm( ...@@ -891,7 +729,6 @@ def _te_gemm(
lhs_quantizer: Quantizer = None, lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
fuse_bias: bool = False, fuse_bias: bool = False,
fuse_gelu: bool = False, fuse_gelu: bool = False,
grad: bool = False, grad: bool = False,
...@@ -906,7 +743,6 @@ def _te_gemm( ...@@ -906,7 +743,6 @@ def _te_gemm(
scaling_mode = ScalingMode.NO_SCALING scaling_mode = ScalingMode.NO_SCALING
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
# Quantize operands (if necessary) # Quantize operands (if necessary)
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
...@@ -925,7 +761,6 @@ def _te_gemm( ...@@ -925,7 +761,6 @@ def _te_gemm(
lhs_scale_inv = lhs_q.scale_inv lhs_scale_inv = lhs_q.scale_inv
if lhs_q.data_layout == "T": if lhs_q.data_layout == "T":
lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis)
if isinstance(rhs_q, ScaledTensor): if isinstance(rhs_q, ScaledTensor):
assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
...@@ -943,7 +778,6 @@ def _te_gemm( ...@@ -943,7 +778,6 @@ def _te_gemm(
rhs_scale_inv = rhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv
if rhs_q.data_layout == "T": if rhs_q.data_layout == "T":
rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis)
# Dummy empties for bias and gelu # Dummy empties for bias and gelu
out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
...@@ -961,9 +795,6 @@ def _te_gemm( ...@@ -961,9 +795,6 @@ def _te_gemm(
gelu_input, gelu_input,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=(lhs_cdims, rhs_cdims), contracting_dims=(lhs_cdims, rhs_cdims),
batched_dims=(lhs_bdims, rhs_bdims),
lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False,
rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fuse_bias=fuse_bias, fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
...@@ -1171,10 +1002,8 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): ...@@ -1171,10 +1002,8 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T": if lhs.data_layout == "T":
lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis) lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis)
lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis)
if rhs.data_layout == "T": if rhs.data_layout == "T":
rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis) rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis)
rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis)
dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
...@@ -1264,7 +1093,7 @@ def _jax_gemm( ...@@ -1264,7 +1093,7 @@ def _jax_gemm(
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
...@@ -1286,7 +1115,6 @@ def gemm( ...@@ -1286,7 +1115,6 @@ def gemm(
lhs: Union[jnp.ndarray, ScaledTensor], lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
lhs_quantizer: Quantizer = None, lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None,
**kwargs, **kwargs,
...@@ -1305,11 +1133,6 @@ def gemm( ...@@ -1305,11 +1133,6 @@ def gemm(
Object for down-casting the RHS operand for quantized GEMM. Object for down-casting the RHS operand for quantized GEMM.
contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, )) contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, ))
Tuple of sequences representing the contracting dimensions of the operands. Tuple of sequences representing the contracting dimensions of the operands.
batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()),
Tuple of sequences representing the batched dimensions of the operands. This is *not* used
to perform a batched matrix multiplication, but it is required to avoid a potentially
undesirable reduction in any batched contracting dimensions when invoked with sharded
operands (e.g. when computing weight gradients in a Flax module).
bias: jax.Array, default = None bias: jax.Array, default = None
Optional additive bias term, required for forward GEMM with bias fusion. Only supported Optional additive bias term, required for forward GEMM with bias fusion. Only supported
with TE's custom call to cuBLAS GEMM. with TE's custom call to cuBLAS GEMM.
...@@ -1327,7 +1150,8 @@ def gemm( ...@@ -1327,7 +1150,8 @@ def gemm(
TE's custom call to cuBLAS GEMM. TE's custom call to cuBLAS GEMM.
use_split_accumulator: bool, default = True use_split_accumulator: bool, default = True
Enable promoting some intermediate sums to higher precision when accumulating the result in Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only
supported with TE's custom call to cuBLAS GEMM.
Returns Returns
------- -------
...@@ -1358,12 +1182,12 @@ def gemm( ...@@ -1358,12 +1182,12 @@ def gemm(
if not GemmPrimitive.enabled(): if not GemmPrimitive.enabled():
assert kwargs.get("bias", None) is None and not fuse_gelu, ( assert kwargs.get("bias", None) is None and not fuse_gelu, (
"TE GEMM was invoked with bias fusion options that are not supported by the " "TE GEMM was invoked with bias fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled." "GEMM primitive is disabled."
) )
assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( assert kwargs.get("gelu_input", None) is None and not fuse_bias, (
"TE GEMM was invoked with GeLU fusion options that are not supported by the " "TE GEMM was invoked with GeLU fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled." "GEMM primitive is disabled."
) )
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
...@@ -1374,7 +1198,6 @@ def gemm( ...@@ -1374,7 +1198,6 @@ def gemm(
lhs_quantizer=lhs_quantizer, lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer, rhs_quantizer=rhs_quantizer,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
batched_dims=batched_dims,
**kwargs, **kwargs,
) )
......
...@@ -519,11 +519,11 @@ register_primitive(BaseDBiasQuantizePrimitive) ...@@ -519,11 +519,11 @@ register_primitive(BaseDBiasQuantizePrimitive)
class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive): class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
class QuantizePrimitive(BaseDBiasQuantizePrimitive): class QuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
def _jax_quantize( def _jax_quantize(
......
...@@ -592,8 +592,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, ...@@ -592,8 +592,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
.Attr<bool>("rhs_is_trans") .Attr<bool>("rhs_is_trans")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("has_bias") .Attr<bool>("has_bias")
.Attr<bool>("is_grouped_dense_wgrad"), .Attr<bool>("is_grouped_dense_wgrad"));
FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -410,8 +410,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, ...@@ -410,8 +410,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
.Ret<Buffer_Type>() // amax .Ret<Buffer_Type>() // amax
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("q_layout") .Attr<int64_t>("q_layout")
.Attr<int64_t>("flatten_axis"), .Attr<int64_t>("flatten_axis"));
FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation. ...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations. customizable contracting dimensions for flexible tensor operations.
""" """
import warnings
from typing import Tuple, Sequence from typing import Tuple, Sequence
from functools import partial from functools import partial
import jax import jax
...@@ -23,16 +23,6 @@ from .quantize import ( ...@@ -23,16 +23,6 @@ from .quantize import (
) )
DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global DENSE_BATCH_FIRST_WARNING_ISSUED
if not DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
DENSE_BATCH_FIRST_WARNING_ISSUED = True
def dense( def dense(
x: jnp.ndarray, x: jnp.ndarray,
kernel: jnp.ndarray, kernel: jnp.ndarray,
...@@ -40,7 +30,6 @@ def dense( ...@@ -40,7 +30,6 @@ def dense(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None, input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -54,7 +43,6 @@ def dense( ...@@ -54,7 +43,6 @@ def dense(
kernel: Weight matrix for the dense layer transformation kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension.
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
...@@ -69,13 +57,34 @@ def dense( ...@@ -69,13 +57,34 @@ def dense(
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
else: else:
output = _dense( output = _dense(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
quantizer_set,
) )
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) @partial(
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set): jax.custom_vjp,
nondiff_argnums=(
3,
4,
5,
),
)
def _dense(
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
quantizer_set,
):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support This function implements the core dense layer transformation logic with support
...@@ -89,19 +98,30 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir ...@@ -89,19 +98,30 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
output, _ = _dense_fwd_rule( output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
quantizer_set,
) )
return output return output
def _dense_fwd_rule( def _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
quantizer_set,
): ):
"""Forward pass rule for dense layer transformation. """Forward pass rule for dense layer transformation.
...@@ -119,23 +139,6 @@ def _dense_fwd_rule( ...@@ -119,23 +139,6 @@ def _dense_fwd_rule(
not x_is_transposed and not k_is_transposed not x_is_transposed and not k_is_transposed
), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."
# Determine X batch dimension
# - If `batch_first=True` -> (batch, leading..., contracting...)
# - Otherwise -> (leading..., batch, contracting...)
# NOTE: Always assume a single batch dimension
x_bdim = None
num_cdims = len(x_contracting_dims)
if x.ndim >= num_cdims + 2:
# Assume X is batched if it has at least +2 dimensions more than the number of contracting
# dimensions.
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `dense()` layer implementation does not officially support sequence-first "
"inputs and may produce incorrect results when `batch_first=False`. Use "
"sequence-first inputs at your own discretion.",
)
x_bdim = 0 if batch_first else x.ndim - num_cdims - 1
flatten_axis_x = -len(x_contracting_dims) flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
...@@ -158,7 +161,6 @@ def _dense_fwd_rule( ...@@ -158,7 +161,6 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS), casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS), casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
...@@ -175,13 +177,12 @@ def _dense_fwd_rule( ...@@ -175,13 +177,12 @@ def _dense_fwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
x_bdim,
) )
return output, ctx return output, ctx
def _dense_bwd_rule( def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
...@@ -196,7 +197,6 @@ def _dense_bwd_rule( ...@@ -196,7 +197,6 @@ def _dense_bwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
x_bdim,
) = ctx ) = ctx
fwd_x_contracting_dims, fwd_k_contracting_dims = map( fwd_x_contracting_dims, fwd_k_contracting_dims = map(
...@@ -220,11 +220,11 @@ def _dense_bwd_rule( ...@@ -220,11 +220,11 @@ def _dense_bwd_rule(
k_contracting_dim = tuple( k_contracting_dim = tuple(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
) )
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_tensor(usage=TensorUsage.LHS), casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs, casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim), contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
...@@ -238,7 +238,6 @@ def _dense_bwd_rule( ...@@ -238,7 +238,6 @@ def _dense_bwd_rule(
casted_x_lhs, casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS), casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim), contracting_dims=(x_contracting_dim, g_contracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
...@@ -15,12 +15,14 @@ from jax import lax ...@@ -15,12 +15,14 @@ from jax import lax
from jax import random as jax_random from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
from ..dense import dense, _issue_batch_first_warning as _dense_warning from transformer_engine.common import recipe
from ..dense import dense
from ..layernorm import canonicalize_norm_type from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning from ..layernorm_mlp import layernorm_mlp
from ..activation import activation from ..activation import activation
from ..softmax import softmax, SoftmaxType from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes from ..sharding import with_sharding_constraint_by_logical_axes
...@@ -31,7 +33,6 @@ from ..cpp_extensions import ( ...@@ -31,7 +33,6 @@ from ..cpp_extensions import (
jax_scaled_upper_triang_masked_softmax, jax_scaled_upper_triang_masked_softmax,
) )
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -274,10 +275,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods ...@@ -274,10 +275,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
""" """
epsilon: float = 1e-6 epsilon: float = 1e-6
...@@ -288,7 +285,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods ...@@ -288,7 +285,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ("embed",) bias_axes: Tuple[str, ...] = ("embed",)
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
self.scale_init = _obtain_default_layernorm_scale_init_if_need( self.scale_init = _obtain_default_layernorm_scale_init_if_need(
...@@ -343,21 +339,28 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -343,21 +339,28 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Base class of transformer engine Base class of transformer engine
""" """
def generate_quantizer_set(self, postfix: str = ""): def generate_quantizer_set(
self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
):
""" """
Generate a set of FP8 meta for a GEMM. Generate a set of FP8 meta for a GEMM.
""" """
def generate_quantize_meta(quantizer_name: str): def generate_quantize_meta(quantizer_name: str):
collection_name = (
variable_collection
if variable_collection is not None
else QuantizeConfig.COLLECTION_NAME
)
scale = self.variable( scale = self.variable(
QuantizeConfig.COLLECTION_NAME, collection_name,
f"{quantizer_name}{postfix}_scale", f"{quantizer_name}{postfix}_scale",
jnp.ones, jnp.ones,
(1,), (1,),
jnp.float32, jnp.float32,
).value ).value
amax_history = self.variable( amax_history = self.variable(
QuantizeConfig.COLLECTION_NAME, collection_name,
f"{quantizer_name}{postfix}_amax_history", f"{quantizer_name}{postfix}_amax_history",
jnp.zeros, jnp.zeros,
(QuantizeConfig.AMAX_HISTORY_LEN,), (QuantizeConfig.AMAX_HISTORY_LEN,),
...@@ -365,7 +368,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -365,7 +368,9 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
).value ).value
return QuantizeMeta(scale=scale, amax_history=amax_history) return QuantizeMeta(scale=scale, amax_history=amax_history)
if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(
fp8_recipe, recipe.DelayedScaling
):
x_meta = generate_quantize_meta("x") x_meta = generate_quantize_meta("x")
kernel_meta = generate_quantize_meta("kernel") kernel_meta = generate_quantize_meta("kernel")
grad_meta = generate_quantize_meta("grad") grad_meta = generate_quantize_meta("grad")
...@@ -374,7 +379,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method ...@@ -374,7 +379,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
else: else:
kwargs = {} kwargs = {}
quantizer_set = QuantizerFactory.create_set(**kwargs) quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs)
return quantizer_set return quantizer_set
...@@ -420,10 +425,6 @@ class DenseGeneral(TransformerEngineBase): ...@@ -420,10 +425,6 @@ class DenseGeneral(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -437,16 +438,9 @@ class DenseGeneral(TransformerEngineBase): ...@@ -437,16 +438,9 @@ class DenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = () input_axes: Tuple[str, ...] = ()
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_dense_warning(
"TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype 1.0, "fan_in", "truncated_normal", dtype=self.dtype
...@@ -628,10 +622,6 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -628,10 +622,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
depth_scaling: float, default = None depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied. value or None. When None is set, then no scaling is applied.
...@@ -657,18 +647,11 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -657,18 +647,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None depth_scaling: float = None
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_ln_dense_warning(
"TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
"inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
"Use sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, 1.0,
...@@ -936,15 +919,16 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -936,15 +919,16 @@ class LayerNormMLP(TransformerEngineBase):
Indicate the logical axes of sharding constraint to the input of 2nd dot, like Indicate the logical axes of sharding constraint to the input of 2nd dot, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint. sharding constraint.
ffn1_ckpt_name: str = "ffn1"
Checkpoint name for the output of the first fully-connected layer in the MLP block.
ffn2_ckpt_name: str = "ffn2"
Checkpoint name for the output of the second fully-connected layer in the MLP block.
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
""" """
intermediate_dim: int = 2048 intermediate_dim: int = 2048
...@@ -973,18 +957,13 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -973,18 +957,13 @@ class LayerNormMLP(TransformerEngineBase):
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] = None
ffn1_ckpt_name: str = "ffn1"
ffn2_ckpt_name: str = "ffn2"
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_ln_mlp_warning(
"TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype 1.0, "fan_in", "truncated_normal", dtype=self.dtype
...@@ -1146,9 +1125,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1146,9 +1125,6 @@ class LayerNormMLP(TransformerEngineBase):
bias_1 = None bias_1 = None
bias_2 = None bias_2 = None
ffn1_ckpt_name = "ffn1"
ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp: if use_fused_layernorm_mlp:
out = layernorm_mlp( out = layernorm_mlp(
y, y,
...@@ -1164,8 +1140,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1164,8 +1140,8 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes=self.dot_2_input_axes, dot_2_input_axes=self.dot_2_input_axes,
kernel_1_axes=self.kernel_axes_1, kernel_1_axes=self.kernel_axes_1,
kernel_2_axes=self.kernel_axes_2, kernel_2_axes=self.kernel_axes_2,
ffn1_ckpt_name=ffn1_ckpt_name, ffn1_ckpt_name=self.ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name, ffn2_ckpt_name=self.ffn2_ckpt_name,
activation_type=normalized_acts, activation_type=normalized_acts,
quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set),
) )
...@@ -1198,15 +1174,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1198,15 +1174,6 @@ class LayerNormMLP(TransformerEngineBase):
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
) )
if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(
kernel_1.ndim, self.kernel_axes_1, contract_ind
),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
wi_lora_a_kernel_each_shape = ( wi_lora_a_kernel_each_shape = (
kernel_1_each_shape[: len(axis)], kernel_1_each_shape[: len(axis)],
...@@ -1247,7 +1214,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1247,7 +1214,7 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias: if self.use_bias:
x += jnp.reshape(bias_1, bias_1_shape) x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name) x = checkpoint_name(x, self.ffn1_ckpt_name)
if is_act_implemented: if is_act_implemented:
z = activation(x, normalized_acts) z = activation(x, normalized_acts)
else: else:
...@@ -1310,7 +1277,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1310,7 +1277,7 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias: if self.use_bias:
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name) out = checkpoint_name(out, self.ffn2_ckpt_name)
assert out.dtype == input_dtype assert out.dtype == input_dtype
return out, ln_output # Output, layner_norm_output return out, ln_output # Output, layner_norm_output
...@@ -274,6 +274,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -274,6 +274,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq: Optional[int] = 1 max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = "" context_parallel_axis: str = ""
context_checkpoint_name: str = "context"
@nn.compact @nn.compact
def __call__( def __call__(
...@@ -322,6 +323,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -322,6 +323,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq, max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis, context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
) )
elif self.qkv_layout.is_kvpacked(): elif self.qkv_layout.is_kvpacked():
"""kvpacked format, treat """kvpacked format, treat
...@@ -348,6 +350,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -348,6 +350,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq, max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis, context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
) )
elif self.qkv_layout.is_separate(): elif self.qkv_layout.is_separate():
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
...@@ -369,6 +372,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -369,6 +372,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq, max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis, context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
) )
else: else:
raise ValueError(f"Unsupported {self.qkv_layout=}.") raise ValueError(f"Unsupported {self.qkv_layout=}.")
...@@ -501,6 +505,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -501,6 +505,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
context_parallel_causal_load_balanced (bool): context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism. Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis. context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -524,6 +529,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -524,6 +529,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq: Optional[int] = 1 max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = "" context_parallel_axis: str = ""
context_checkpoint_name: str = "context"
@nn.compact @nn.compact
def __call__( def __call__(
...@@ -690,6 +696,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -690,6 +696,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq=self.max_segments_per_seq, max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis, context_parallel_axis=self.context_parallel_axis,
context_checkpoint_name=self.context_checkpoint_name,
)( )(
query, query,
key, key,
...@@ -1160,7 +1167,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1160,7 +1167,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=(3, self.num_attention_heads * self.head_dim), features=(3, self.num_attention_heads * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.return_layernorm_output, return_layernorm_output=self.return_layernorm_output,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
...@@ -1187,7 +1193,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1187,7 +1193,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_attention_heads * self.head_dim, features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=(self.return_layernorm_output or is_self_attn), return_layernorm_output=(self.return_layernorm_output or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
...@@ -1212,7 +1217,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1212,7 +1217,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kv_proj = DenseGeneral( kv_proj = DenseGeneral(
axis=-1, axis=-1,
features=(2, self.num_gqa_groups * self.head_dim), features=(2, self.num_gqa_groups * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init, kernel_init=kv_init,
use_bias=self.use_bias, use_bias=self.use_bias,
...@@ -1231,7 +1235,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1231,7 +1235,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
DenseGeneral, DenseGeneral,
axis=-1, axis=-1,
features=self.num_gqa_groups * self.head_dim, features=self.num_gqa_groups * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
...@@ -1248,7 +1251,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1248,7 +1251,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_attention_heads * self.head_dim, features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True, return_layernorm_output=True,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
...@@ -1413,7 +1415,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1413,7 +1415,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
out = DenseGeneral( out = DenseGeneral(
features=inputs_q.shape[-1], features=inputs_q.shape[-1],
transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1, axis=-1,
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
kernel_axes=(W_TP_AXES, W_FSDP_AXES), kernel_axes=(W_TP_AXES, W_FSDP_AXES),
...@@ -2015,7 +2016,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2015,7 +2016,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size, intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations, activations=self.mlp_activations,
...@@ -2070,7 +2070,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2070,7 +2070,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
bias_axes=(W_NO_SHARD_AXES,), bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype, dtype=self.dtype,
name="output_layernorm", name="output_layernorm",
)(z) )(z)
......
...@@ -9,7 +9,6 @@ architectures. It supports various normalization types, quantization, and ...@@ -9,7 +9,6 @@ architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints. distributed training through sharding constraints.
""" """
import warnings
from functools import partial from functools import partial
from typing import Tuple from typing import Tuple
...@@ -26,16 +25,6 @@ from .quantize import ( ...@@ -26,16 +25,6 @@ from .quantize import (
) )
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_dense( def layernorm_dense(
x: jnp.ndarray, x: jnp.ndarray,
kernel: jnp.ndarray, kernel: jnp.ndarray,
...@@ -48,7 +37,6 @@ def layernorm_dense( ...@@ -48,7 +37,6 @@ def layernorm_dense(
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation. """Apply layer normalization followed by dense layer transformation.
...@@ -69,7 +57,6 @@ def layernorm_dense( ...@@ -69,7 +57,6 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: Set of quantizers for different tensor types quantizer_set: Set of quantizers for different tensor types
Returns: Returns:
...@@ -93,7 +80,6 @@ def layernorm_dense( ...@@ -93,7 +80,6 @@ def layernorm_dense(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -108,7 +94,6 @@ def layernorm_dense( ...@@ -108,7 +94,6 @@ def layernorm_dense(
8, 8,
9, 9,
10, 10,
11,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -123,7 +108,6 @@ def _layernorm_dense( ...@@ -123,7 +108,6 @@ def _layernorm_dense(
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...],
batch_first: bool,
quantizer_set, quantizer_set,
): ):
"""Internal implementation of layernorm_dense with custom VJP. """Internal implementation of layernorm_dense with custom VJP.
...@@ -143,7 +127,6 @@ def _layernorm_dense( ...@@ -143,7 +127,6 @@ def _layernorm_dense(
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding dot_input_axes: Logical axes for matrix multiplication sharding
batch_first: Assume that X is batched in the first dimension.
quantizer_set: Set of quantizers quantizer_set: Set of quantizers
Returns: Returns:
...@@ -161,7 +144,6 @@ def _layernorm_dense( ...@@ -161,7 +144,6 @@ def _layernorm_dense(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -179,7 +161,6 @@ def _layernorm_dense_fwd_rule( ...@@ -179,7 +161,6 @@ def _layernorm_dense_fwd_rule(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
): ):
"""Forward pass rule for layernorm_dense. """Forward pass rule for layernorm_dense.
...@@ -197,17 +178,6 @@ def _layernorm_dense_fwd_rule( ...@@ -197,17 +178,6 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0] assert x.shape[-1] == kernel.shape[0]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_dense()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd( casted_ln_out, mu, rsigma = tex.normalization_fwd(
...@@ -236,7 +206,6 @@ def _layernorm_dense_fwd_rule( ...@@ -236,7 +206,6 @@ def _layernorm_dense_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS), casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
...@@ -260,7 +229,6 @@ def _layernorm_dense_fwd_rule( ...@@ -260,7 +229,6 @@ def _layernorm_dense_fwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis, flatten_axis,
x_bdim,
) )
return output, ctx return output, ctx
...@@ -271,9 +239,8 @@ def _layernorm_dense_bwd_rule( ...@@ -271,9 +239,8 @@ def _layernorm_dense_bwd_rule(
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument dot_input_axes,
kernel_axes, kernel_axes,
batch_first, # pylint: disable=unused-argument
ctx, ctx,
grad, grad,
): ):
...@@ -288,6 +255,7 @@ def _layernorm_dense_bwd_rule( ...@@ -288,6 +255,7 @@ def _layernorm_dense_bwd_rule(
Returns: Returns:
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
del dot_input_axes
( (
casted_ln_out, casted_ln_out,
casted_kernel, casted_kernel,
...@@ -303,7 +271,6 @@ def _layernorm_dense_bwd_rule( ...@@ -303,7 +271,6 @@ def _layernorm_dense_bwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis, flatten_axis,
x_bdim,
) = ctx ) = ctx
casted_grad, dbias = tex.quantize_dbias( casted_grad, dbias = tex.quantize_dbias(
...@@ -328,7 +295,6 @@ def _layernorm_dense_bwd_rule( ...@@ -328,7 +295,6 @@ def _layernorm_dense_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel, casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
...@@ -342,7 +308,6 @@ def _layernorm_dense_bwd_rule( ...@@ -342,7 +308,6 @@ def _layernorm_dense_bwd_rule(
casted_ln_out, casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim), contracting_dims=(x_constracting_dim, g_constracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
...@@ -13,7 +13,6 @@ The implementation supports various normalization types, activation functions, ...@@ -13,7 +13,6 @@ The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints. quantization, and distributed training through sharding constraints.
""" """
import warnings
from typing import List, Tuple, Sequence, Union, Callable from typing import List, Tuple, Sequence, Union, Callable
from functools import partial from functools import partial
...@@ -29,17 +28,6 @@ from .quantize import ( ...@@ -29,17 +28,6 @@ from .quantize import (
noop_quantizer_set, noop_quantizer_set,
TensorUsage, TensorUsage,
) )
from .sharding import get_non_contracting_logical_axes
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_mlp( def layernorm_mlp(
...@@ -59,7 +47,6 @@ def layernorm_mlp( ...@@ -59,7 +47,6 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1", ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2", ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
batch_first: bool = True,
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block. """Apply layer normalization followed by MLP block.
...@@ -91,7 +78,6 @@ def layernorm_mlp( ...@@ -91,7 +78,6 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation activation_type: Activation function(s) to apply after the first dense layer transformation
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns: Returns:
...@@ -137,13 +123,12 @@ def layernorm_mlp( ...@@ -137,13 +123,12 @@ def layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
quantizer_sets, quantizer_sets,
) )
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp( def _layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -163,7 +148,6 @@ def _layernorm_mlp( ...@@ -163,7 +148,6 @@ def _layernorm_mlp(
ffn1_ckpt_name: str, ffn1_ckpt_name: str,
ffn2_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
batch_first: bool,
quantizer_sets, quantizer_sets,
): ):
"""Internal implementation of layernorm_mlp with custom VJP. """Internal implementation of layernorm_mlp with custom VJP.
...@@ -189,7 +173,6 @@ def _layernorm_mlp( ...@@ -189,7 +173,6 @@ def _layernorm_mlp(
ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s) activation_type: Activation function(s)
batch_first: Assume that X is batched in the first dimension.
quantizer_sets: Tuple of quantizer sets quantizer_sets: Tuple of quantizer sets
Returns: Returns:
...@@ -214,7 +197,6 @@ def _layernorm_mlp( ...@@ -214,7 +197,6 @@ def _layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
quantizer_sets, quantizer_sets,
) )
return output return output
...@@ -239,7 +221,6 @@ def _layernorm_mlp_fwd_rule( ...@@ -239,7 +221,6 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
quantizer_sets, quantizer_sets,
): ):
"""Forward pass rule for layernorm_mlp. """Forward pass rule for layernorm_mlp.
...@@ -256,7 +237,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -256,7 +237,7 @@ def _layernorm_mlp_fwd_rule(
Returns: Returns:
Tuple of (output, context) for automatic differentiation Tuple of (output, context) for automatic differentiation
""" """
del kernel_2_axes del kernel_1_axes, kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
...@@ -272,17 +253,6 @@ def _layernorm_mlp_fwd_rule( ...@@ -272,17 +253,6 @@ def _layernorm_mlp_fwd_rule(
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
use_bias_1 = bias_1 is not None use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None use_bias_2 = bias_1 is not None
...@@ -310,18 +280,10 @@ def _layernorm_mlp_fwd_rule( ...@@ -310,18 +280,10 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS), casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_1 if not tex.gemm_uses_jax_dot() else None, bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
) )
if dot_1_input_axes is not None and kernel_1_axes is not None:
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1 and tex.gemm_uses_jax_dot(): if use_bias_1 and tex.gemm_uses_jax_dot():
bias_1_shape = bias_1.shape bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
...@@ -346,7 +308,6 @@ def _layernorm_mlp_fwd_rule( ...@@ -346,7 +308,6 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS), casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS), casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_2 if not tex.gemm_uses_jax_dot() else None, bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
) )
...@@ -376,7 +337,6 @@ def _layernorm_mlp_fwd_rule( ...@@ -376,7 +337,6 @@ def _layernorm_mlp_fwd_rule(
use_bias_1, use_bias_1,
use_bias_2, use_bias_2,
quantizer_sets, quantizer_sets,
x_bdim,
) )
return dot_2_output, ctx return dot_2_output, ctx
...@@ -394,7 +354,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -394,7 +354,6 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
ctx, ctx,
grad, grad,
): ):
...@@ -411,7 +370,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -411,7 +370,7 @@ def _layernorm_mlp_bwd_rule(
Returns: Returns:
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
( (
x, x,
mu, mu,
...@@ -430,7 +389,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -430,7 +389,6 @@ def _layernorm_mlp_bwd_rule(
use_bias_1, use_bias_1,
use_bias_2, use_bias_2,
quantizer_sets, quantizer_sets,
x_bdim,
) = ctx ) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
...@@ -457,7 +415,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -457,7 +415,6 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2, casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
batched_dims=((x_bdim,), ()),
) )
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
...@@ -472,7 +429,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -472,7 +429,6 @@ def _layernorm_mlp_bwd_rule(
casted_act_out, casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
...@@ -500,7 +456,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -500,7 +456,6 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS), casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1, casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
batched_dims=((x_bdim,), ()),
) )
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
...@@ -511,7 +466,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -511,7 +466,6 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out, casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS), casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
...@@ -16,12 +16,14 @@ import jax ...@@ -16,12 +16,14 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import ( from .helper import (
QuantizeConfig, QuantizeConfig,
AmaxComputeAlgo, AmaxComputeAlgo,
_get_scaling_mode,
) )
from .device_utils import is_fp8_gemm_with_all_layouts_supported from .device_utils import is_fp8_gemm_with_all_layouts_supported
...@@ -878,11 +880,12 @@ class QuantizerFactory: ...@@ -878,11 +880,12 @@ class QuantizerFactory:
@staticmethod @staticmethod
def create_set( def create_set(
n_quantizer_sets: int = 1, n_quantizer_sets: int = 1,
scaling_mode: ScalingMode = None, scaling_mode: Optional[ScalingMode] = None,
fwd_dtype: jnp.dtype = None, fwd_dtype: jnp.dtype = None,
bwd_dtype: jnp.dtype = None, bwd_dtype: jnp.dtype = None,
is_2x2x: bool = None, is_2x2x: bool = None,
n_groups: int = None, n_groups: int = None,
fp8_recipe: Optional[recipe.Recipe] = None,
**kwargs, **kwargs,
) -> tuple[Union[tuple[Quantizer], None]]: ) -> tuple[Union[tuple[Quantizer], None]]:
"""Create one or more sets of quantizers. """Create one or more sets of quantizers.
...@@ -894,12 +897,25 @@ class QuantizerFactory: ...@@ -894,12 +897,25 @@ class QuantizerFactory:
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
n_groups: n_groups:
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
**kwargs: Additional arguments for quantizer initialization **kwargs: Additional arguments for quantizer initialization
Returns: Returns:
A single quantizer set or tuple of quantizer sets A single quantizer set or tuple of quantizer sets
""" """
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
assert scaling_mode is None or fp8_recipe is None, (
"Cannot specify both scaling_mode and fp8_recipe when creating a quantizer set. Scaling"
" mode can be specified directly via the scaling_mode parameter or indirectly via"
" recipe. Recipe is preferred as it will support additional recipes in future where"
" scaling mode differs between x, kernel, and grad in the quantizer set."
)
if fp8_recipe is not None:
# TODO(jberchtold): once recipe and scaling mode are decoupled update this logic
scaling_mode = _get_scaling_mode(fp8_recipe)
else:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
if is_2x2x is None: if is_2x2x is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment