Commit ab3e5a92 authored by yuguo's avatar yuguo
Browse files

Merge commit '04c730c0' of...

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
......@@ -37,8 +37,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, nullptr, output,
dbias, workspace, stream);
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
workspace, nullptr, stream);
}
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
......@@ -46,6 +46,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
NVTE_API_CALL(nvte_quantize_noop);
using namespace transformer_engine;
// Create config with noop tensor
QuantizationConfig quant_config;
quant_config.noop_tensor = noop;
nvte_quantize_v2(input, output, reinterpret_cast<NVTEQuantizationConfig>(&quant_config), stream);
}
void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_v2);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
......@@ -53,8 +65,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, noop, output,
dbias, workspace, stream);
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
input, grad, output, dbias, workspace, quant_config, stream);
}
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
......@@ -68,7 +80,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
constexpr const NVTETensor activation_input = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
......@@ -82,7 +94,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
......@@ -96,7 +108,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
......@@ -110,7 +122,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
......@@ -124,7 +136,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
......@@ -138,7 +150,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream);
activation_input, input, output, dbias, workspace, nullptr, stream);
}
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
......
......@@ -99,8 +99,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr size_t in_mem = in_act_mem + in_gate_mem;
constexpr size_t out_act_mem = buff_size_aligned_out;
constexpr size_t out_gate_mem = buff_size_aligned_out;
constexpr size_t out_mem = out_act_mem + out_gate_mem;
// const size_t in_transaction_size = grad_mem + in_mem;
constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
......@@ -111,7 +109,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem);
OType *out_act_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem);
OType *out_gate_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem);
// uint64_t *mbar = reinterpret_cast<uint64_t *>(dshmem + grad_mem + in_mem + out_mem);
const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad);
const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act);
......@@ -294,7 +291,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32
......@@ -839,8 +835,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1;
size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
e8m0_t *const scales_rowwise_ptr =
USE_ROWWISE_SCALING ? reinterpret_cast<e8m0_t *>(output->scale_inv.dptr) : nullptr;
e8m0_t *const scales_colwise_ptr =
......
......@@ -145,7 +145,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1);
const bool is_master_thread = (threadIdx.x == 0);
......@@ -518,7 +517,6 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
__shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1);
const bool is_master_thread = (threadIdx.x == 0);
......@@ -940,7 +938,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
bool use_colwise_scaling = output->has_columnwise_data();
checkCuDriverContext(stream);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
const auto &input_shape = input.data.shape;
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
if (use_rowwise_scaling) {
......@@ -1250,9 +1247,9 @@ namespace detail {
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output,
NVTETensor dbias, NVTETensor workspace,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
const Tensor *input_tensor;
const Tensor *activation_input_tensor;
if constexpr (IS_DBIAS || IS_DACT) {
......@@ -1267,6 +1264,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
auto output_tensor = reinterpret_cast<Tensor *>(output);
auto dbias_tensor = reinterpret_cast<Tensor *>(dbias);
auto workspace_tensor = reinterpret_cast<Tensor *>(workspace);
const QuantizationConfig *quant_config_cpp =
reinterpret_cast<const QuantizationConfig *>(quant_config);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor();
switch (output_tensor->scaling_mode) {
......@@ -1294,6 +1297,36 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
workspace_tensor, stream);
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data()
? FP8BlockwiseRowwiseOption::ROWWISE
: FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option =
output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE
: FP8BlockwiseColumnwiseOption::NONE;
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
output_tensor->columnwise_scale_inv, output_tensor->data,
output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
......
......@@ -59,7 +59,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t scales_stride) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32
......@@ -68,8 +67,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128
constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
......@@ -357,6 +355,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
} else {
// TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}
......
......@@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
: "memory");
}
// Non-blocking probe of an mbarrier phase: returns true once the barrier has
// completed the phase identified by `parity` (0/1 alternating per phase).
// `mbar_ptr` must be a shared::cta address of the 64-bit mbarrier object,
// already converted with __cvta_generic_to_shared() (see mbarrier_wait_parity).
// PTX: mbarrier.try_wait.parity.shared::cta.b64
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
  uint32_t waitComplete;
  asm volatile(
      "{\n\t .reg .pred P_OUT; \n\t"
      "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
      "selp.b32 %0, 1, 0, P_OUT; \n"
      "}"
      : "=r"(waitComplete)
      : "r"(mbar_ptr), "r"(parity)
      : "memory");
  return static_cast<bool>(waitComplete);
}

// Blocking wait on an mbarrier phase: converts the generic pointer to a
// shared-memory address and spins on the non-blocking probe until the phase
// with the given `parity` completes.
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
  }
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr,
......@@ -106,30 +129,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
: "memory");
}
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
uint32_t waitComplete;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P_OUT; \n"
"}"
: "=r"(waitComplete)
: "r"(mbar_ptr), "r"(parity)
: "memory");
return static_cast<bool>(waitComplete);
}
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
__device__ __forceinline__ void cp_async_bulk_wait_group() {
asm volatile("cp.async.bulk.wait_group 0;");
......@@ -158,13 +157,19 @@ __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
asm volatile("cp.async.bulk.wait_group.read 4;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
// Proxy fence (bi-directional):
__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); }
__device__ __forceinline__ void fence_proxy_async_shared_cta() {
asm volatile("fence.proxy.async.shared::cta;");
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} // namespace ptx
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level package for numerical debugging."""
# Import the PyTorch debug extensions when available. Debug support is
# optional, so a missing torch dependency must not break package import.
try:
    from . import pytorch
    from .pytorch.debug_state import set_weight_tensor_tp_group_reduce
except ImportError:
    # PyTorch (or its debug extension) is not installed; the PyTorch
    # debugging features are simply unavailable in this environment.
    # (Fix: the exception was previously bound to an unused name `e`.)
    pass
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains DebugQuantizer and DebugQuantizedTensor objects,
which are wrappers over Quantizer and QuantizedTensor.
These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
aten = torch.ops.aten
_tensor_to_gemm_names_map = {
"weight": ["fprop", "dgrad"],
"activation": ["fprop", "wgrad"],
"output": ["fprop", None],
"gradient": ["dgrad", "wgrad"],
"wgrad": ["wgrad", None],
"dgrad": ["dgrad", None],
}
API_CALL_MODIFY = "modify_tensor()"
STANDARD_FP8_QUANTIZE = "FP8 Quantize"
HIGH_PRECISION = "High Precision"
class DebugQuantizer(Quantizer):
"""
DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect.
It allows adding custom calls inside the quantization process - which enables modifying tensors
or gathering tensor stats.
"""
def __init__(
    self,
    layer_name: str,
    tensor_name: str,
    parent_quantizer: Optional[Quantizer],
    tp_group: torch.distributed.ProcessGroup,
):
    """Build a debug quantizer for one named tensor of one named layer.

    parent_quantizer is the real (e.g. FP8) quantizer to delegate to, or
    None when the tensor is kept in high precision.
    """
    # Imported lazily so this module stays importable without nvdlfw_inspect.
    import nvdlfw_inspect.api as debug_api

    super().__init__(rowwise=True, columnwise=True)
    self.layer_name = layer_name
    self.tensor_name = tensor_name
    self.parent_quantizer = parent_quantizer
    self.tp_group = tp_group  # used in inspect_tensor calls
    # NOTE(review): reads a private counter of the debug manager — assumes
    # debug_api was initialized before any TE module; confirm with callers.
    self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count
    self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]

    # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
    # rowwise_tensor_plan, and columnwise_tensor_plan are computed.
    # These fields indicate the path where API calls will be inserted.
    #
    # inspect_tensor*_enabled are bool fields,
    # indicating whether some feature will need to run inspect_tensor_* calls.
    #
    # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION]
    # determining what will happen when the quantizer is used for that tensor.
    self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
    if self.output_tensor:
        # Output-like tensors only have a rowwise plan (no columnwise gemm).
        self.inspect_tensor_enabled, self.rowwise_tensor_plan = (
            self.get_plans_for_output_tensors()
        )
    else:
        (
            self.inspect_tensor_enabled,
            self.inspect_tensor_postquantize_enabled_rowwise,
            self.inspect_tensor_postquantize_enabled_columnwise,
        ) = self.get_enabled_look_at_tensors()
        self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan()
        # Only on the non-output path: logging reads columnwise_tensor_plan,
        # which is not set for output tensors.
        self.log_messages_about_plans()
def get_plans_for_output_tensors(self) -> Tuple[bool, str]:
    """
    Returns tuple (inspect_tensor_enabled: bool, plan: str). Plan is one of the
    API_CALL_MODIFY or HIGH_PRECISION, because debug quantizer does not support
    gemm output in FP8.
    """
    # Imported lazily so this module stays importable without nvdlfw_inspect.
    import nvdlfw_inspect.api as debug_api

    inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
        layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
    )
    # Output tensors only participate in the rowwise gemm.
    modify_enabled = debug_api.transformer_engine.modify_tensor_enabled(
        layer_name=self.layer_name,
        gemm=self.rowwise_gemm_name,
        tensor_name=self.tensor_name,
        iteration=self.iteration,
    )
    plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION
    return inspect_tensor_enabled, plan
def get_enabled_look_at_tensors(self):
    """
    Returns a tuple of booleans determining which inspect_tensor_*(...) calls
    are enabled for this layer/tensor at the current iteration.
    """
    # Imported lazily so this module stays importable without nvdlfw_inspect.
    import nvdlfw_inspect.api as debug_api

    inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
        layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
    )
    # Post-quantize inspection is queried per gemm (rowwise vs columnwise).
    inspect_tensor_postquantize_enabled_rowwise = (
        debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            iteration=self.iteration,
            gemm=self.rowwise_gemm_name,
        )
    )
    inspect_tensor_postquantize_enabled_columnwise = (
        debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            iteration=self.iteration,
            gemm=self.columnwise_gemm_name,
        )
    )
    return (
        inspect_tensor_enabled,
        inspect_tensor_postquantize_enabled_rowwise,
        inspect_tensor_postquantize_enabled_columnwise,
    )
def get_tensors_plan(self):
    """
    Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
    API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior
    of this quantizer with respect to these tensors.
    """
    # Imported lazily so this module stays importable without nvdlfw_inspect.
    import nvdlfw_inspect.api as debug_api

    rowwise_plan = None
    columnwise_plan = None
    # Priority order: modify_tensor hook > FP8 quantize > high precision.
    modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled(
        layer_name=self.layer_name,
        gemm=self.rowwise_gemm_name,
        tensor_name=self.tensor_name,
        iteration=self.iteration,
    )
    if modify_rowwise:
        rowwise_plan = API_CALL_MODIFY
    else:
        # FP8 is only possible when a parent quantizer exists AND the
        # feature config enables fp8 for this gemm.
        if self.parent_quantizer is not None:
            fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
                layer_name=self.layer_name,
                gemm=self.rowwise_gemm_name,
                iteration=self.iteration,
            )
            if fp8_quantize:
                rowwise_plan = STANDARD_FP8_QUANTIZE
        if rowwise_plan is None:
            rowwise_plan = HIGH_PRECISION
    # Some tensors (see _tensor_to_gemm_names_map) have no columnwise gemm;
    # their columnwise plan stays None.
    if self.columnwise_gemm_name is not None:
        modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled(
            layer_name=self.layer_name,
            gemm=self.columnwise_gemm_name,
            tensor_name=self.tensor_name,
            iteration=self.iteration,
        )
        if modify_columnwise:
            columnwise_plan = API_CALL_MODIFY
        else:
            if self.parent_quantizer is not None:
                fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
                    layer_name=self.layer_name,
                    gemm=self.columnwise_gemm_name,
                    iteration=self.iteration,
                )
                if fp8_quantize:
                    columnwise_plan = STANDARD_FP8_QUANTIZE
            if columnwise_plan is None:
                columnwise_plan = HIGH_PRECISION
    return rowwise_plan, columnwise_plan
def log_messages_about_plans(self):
    """
    Logs the messages about the plans for each of the tensors.
    """
    # Imported lazily so this module stays importable without nvdlfw_inspect.
    import nvdlfw_inspect.api as debug_api

    debug_api.log_message(
        f"Tensor: {self.tensor_name}, gemm {self.rowwise_gemm_name} -"
        f" {self.rowwise_tensor_plan}",
        layer_name=self.layer_name,
        extra_cachable_args=(self.rowwise_gemm_name, self.tensor_name),
    )
    # NOTE(review): reads columnwise_tensor_plan, so this method must only be
    # called on the non-output-tensor path (see __init__).
    debug_api.log_message(
        f"Tensor: {self.tensor_name}, gemm {self.columnwise_gemm_name} -"
        f" {self.columnwise_tensor_plan}",
        layer_name=self.layer_name,
        extra_cachable_args=(self.columnwise_gemm_name, self.tensor_name),
    )
def _call_inspect_tensor_api(
    self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None
):
    # Runs the enabled inspect hooks: inspect_tensor on the pre-quantization
    # tensor, then inspect_tensor_postquantize on each per-gemm tensor whose
    # plan produced a quantized/modified tensor.
    import nvdlfw_inspect.api as debug_api

    args = {
        "layer_name": self.layer_name,
        "tensor": tensor,
        "tensor_name": self.tensor_name,
        "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count,
        "tp_group": self.tp_group,
    }
    if tensor is not None and self.inspect_tensor_enabled:
        debug_api.transformer_engine.inspect_tensor(**args)
    if self.output_tensor:
        # Output tensors have no post-quantize stage (and the
        # *_postquantize_enabled flags are not set for them).
        return
    if (
        self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
        and self.inspect_tensor_postquantize_enabled_rowwise
    ):
        args["tensor"] = rowwise_gemm_tensor
        args["rowwise"] = True
        debug_api.transformer_engine.inspect_tensor_postquantize(**args)
    if (
        self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
        and self.inspect_tensor_postquantize_enabled_columnwise
    ):
        args["tensor"] = columnwise_gemm_tensor
        args["rowwise"] = False
        debug_api.transformer_engine.inspect_tensor_postquantize(**args)
def quantize(
    self,
    tensor: torch.Tensor,
    *,
    out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Quantize `tensor` according to the per-gemm plans of this quantizer.

    Returns a DebugQuantizedTensor wrapping the tensors destined for the
    rowwise and columnwise gemms (or a plain tensor for output-like names).
    When `out` is given, updates it in place via update_quantized().
    """
    import nvdlfw_inspect.api as debug_api

    assert not self.output_tensor
    if out is not None:
        # Fix: quantize into the caller-provided destination. The previous
        # code passed `self` (the quantizer) as the destination, but
        # update_quantized() reads dst.rowwise_gemm_tensor /
        # dst.columnwise_gemm_tensor, which a DebugQuantizer does not have.
        return self.update_quantized(tensor, out)

    # 1. If there is fp8 quantization in at least one of the gemms,
    # the quantization using the self.parent_quantizer is performed.
    # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise
    rowwise_gemm_quantize = (
        self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
    )
    columnwise_gemm_quantize = (
        self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
    )
    if columnwise_gemm_quantize and not rowwise_gemm_quantize:
        rowwise_gemm_quantize = True  # only columnwise quantization not implemented

    rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
    if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
        self.parent_quantizer.set_usage(
            rowwise=True,
            columnwise=columnwise_gemm_quantize,  # columnwise usage only is not supported
        )
        quantized_tensor = self.parent_quantizer(tensor)
        # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
        # one tensor with columnwise=True and rowwise=True is computed
        # and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
        if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
            rowwise_gemm_tensor = quantized_tensor
        if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
            columnwise_gemm_tensor = quantized_tensor

    # 2. modify_tensor() is called, if it is used.
    if self.columnwise_tensor_plan == API_CALL_MODIFY:
        columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            gemm=self.columnwise_gemm_name,
            tensor=tensor,
            default_quantizer=self.parent_quantizer,
            iteration=self.iteration,
            dtype=dtype,
        )
        if columnwise_gemm_tensor.dtype != dtype:
            raise ValueError("Dtype does not match the output of the modify_tensor call")
    if self.rowwise_tensor_plan == API_CALL_MODIFY:
        rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            gemm=self.rowwise_gemm_name,
            tensor=tensor,
            default_quantizer=self.parent_quantizer,
            iteration=self.iteration,
            dtype=dtype,
        )
        if rowwise_gemm_tensor.dtype != dtype:
            raise ValueError("Dtype does not match the output of the modify_tensor call")

    # 3. If some tensors still are not defined we use high precision tensor.
    if self.rowwise_tensor_plan == HIGH_PRECISION:
        rowwise_gemm_tensor = tensor.to(dtype)
    if self.columnwise_tensor_plan == HIGH_PRECISION:
        columnwise_gemm_tensor = tensor.to(dtype)

    self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor)

    # sometimes we may want to return simple tensor with only rowwise_gemm
    # NOTE(review): output-like names are asserted out above, so this branch
    # looks unreachable — confirm whether callers bypass the assert.
    if self.tensor_name in ["wgrad", "dgrad", "output"]:
        return rowwise_gemm_tensor
    return DebugQuantizedTensor(
        rowwise_gemm_tensor=rowwise_gemm_tensor,
        columnwise_gemm_tensor=columnwise_gemm_tensor,
        quantizer=self,
        layer_name=self.layer_name,
        tensor_name=self.tensor_name,
    )
def process_gemm_output(self, tensor: torch.Tensor):
    """This call is invoked after the gemm to inspect and modify the output tensor."""
    # Imported lazily so this module stays importable without nvdlfw_inspect.
    import nvdlfw_inspect.api as debug_api

    assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
    assert self.output_tensor
    # Output tensors map to exactly one gemm each.
    tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
    if self.rowwise_tensor_plan == API_CALL_MODIFY:
        tensor = debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            gemm=tensor_to_gemm[self.tensor_name],
            tensor_name=self.tensor_name,
            tensor=tensor,
            iteration=self.iteration,
            default_quantizer=self.parent_quantizer,
        )
    self._call_inspect_tensor_api(tensor)
    return tensor
def make_empty(
    self,
    shape: Iterable[int],
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> QuantizedTensor:
    """Override make_empty() from Quantizer class."""
    # Delegate to the real quantizer when one exists (FP8 path); otherwise
    # fall back to a plain high-precision tensor.
    if self.parent_quantizer is not None:
        return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device)
    return torch.empty(shape, dtype=dtype, device=device)
def calibrate(self, tensor: torch.Tensor):
    """Calibration override; calibration is unsupported under debug and always raises."""
    message = "[NVTORCH-INSPECT ERROR] Calibration with debug is not supported"
    raise RuntimeError(message)
def update_quantized(
    self,
    src: torch.Tensor,
    dst: QuantizedTensor,
    *,
    noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
    """Update quantized tensor - used in weight caching."""
    # Imported lazily so this module stays importable without nvdlfw_inspect.
    import nvdlfw_inspect.api as debug_api

    assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
    updated_rowwise_gemm = False
    if self.parent_quantizer is not None:
        if (
            dst.rowwise_gemm_tensor is not None
            and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
        ):
            # Prefer the tensor's own in-place quantize when available,
            # otherwise go through the C++ extension.
            if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
                dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
            else:
                tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None)
            updated_rowwise_gemm = True
        # Skip if the rowwise update already refreshed the shared tensor
        # (both gemm views can point at the same object).
        if (
            dst.columnwise_gemm_tensor is not None
            and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
            and not updated_rowwise_gemm
        ):
            if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
                dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None)
            else:
                tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None)
    if self.columnwise_tensor_plan == API_CALL_MODIFY:
        out = debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            gemm=self.columnwise_gemm_name,
            tensor=src,
            default_quantizer=self.parent_quantizer,
            out=dst.columnwise_gemm_tensor,
            iteration=self.iteration,
        )
        assert out is None, (
            "API call debug_api.transformer_engine.modify_tensor with out != None should"
            " return None"
        )
    if self.rowwise_tensor_plan == API_CALL_MODIFY:
        debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            gemm=self.rowwise_gemm_name,
            tensor=src,
            default_quantizer=self.parent_quantizer,
            out=dst.rowwise_gemm_tensor,
            iteration=self.iteration,
        )
    if self.rowwise_tensor_plan == HIGH_PRECISION:
        dst.rowwise_gemm_tensor.copy_(src)
    if self.columnwise_tensor_plan == HIGH_PRECISION:
        # if they are the same tensor object, it is sufficient to update one
        if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor:
            dst.columnwise_gemm_tensor.copy_(src)
    self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor)
def any_feature_enabled(self) -> bool:
    """Returns True when at least one debug API call will run for this quantizer."""
    # Output tensors: only pre-gemm inspection or a rowwise modify hook apply.
    if self.output_tensor:
        return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY
    triggers = (
        self.inspect_tensor_enabled,
        self.inspect_tensor_postquantize_enabled_rowwise,
        self.inspect_tensor_postquantize_enabled_columnwise,
        self.rowwise_tensor_plan == API_CALL_MODIFY,
        self.columnwise_tensor_plan == API_CALL_MODIFY,
    )
    if any(triggers):
        return True
    # With a parent quantizer present, any non-FP8 plan means debug changed
    # behavior (fell back to high precision), which also counts as a feature.
    if self.parent_quantizer is not None:
        return (
            self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE
            or self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE
        )
    return False
class DebugQuantizedTensor:
"""
Class containing quantized tensors after debug. Depending on configuration
it can contain one or two different objects. These objects can be accessed by the method
get_tensor().
"""
def __init__(
    self,
    rowwise_gemm_tensor,
    columnwise_gemm_tensor,
    quantizer,
    layer_name=None,
    tensor_name=None,
):
    # Tensor used by the rowwise gemm; may be the very same object as
    # columnwise_gemm_tensor when one quantized tensor serves both gemms.
    self.rowwise_gemm_tensor = rowwise_gemm_tensor
    self.columnwise_gemm_tensor = columnwise_gemm_tensor
    # Owning DebugQuantizer; quantize_() delegates back to it.
    self.quantizer = quantizer
    self._layer_name = layer_name
    self._tensor_name = tensor_name
def prepare_for_saving(self):
    """Prepare-for-saving method override.

    Deduplicates the rowwise/columnwise tensors when they are the same
    object so the shared tensor is saved only once, then delegates to the
    module-level prepare_for_saving() helper.
    (Fix: repaired the malformed docstring opener.)
    """
    self.tensors_to_save = (
        [self.rowwise_gemm_tensor, self.columnwise_gemm_tensor]
        if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor
        else [self.rowwise_gemm_tensor]
    )
    tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save)
    # Keep only the lightweight tensor objects; the raw tensors travel
    # through autograd saving via tensor_list.
    self.tensors_to_save = tensor_objects_list
    # pylint: disable=unbalanced-tuple-unpacking
    return tensor_list, self
def restore_from_saved(self, tensors):
    """Restore from saved method override"""
    # The helper returns the reconstructed tensor objects plus the tail of
    # `tensors` it did not consume.
    tensor_objects_list, saved_tensors = restore_from_saved(
        self.tensors_to_save,
        tensors,
        return_saved_tensors=True,
    )
    if len(tensor_objects_list) == 2:
        # pylint: disable=unbalanced-tuple-unpacking
        self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list
    else:
        # Only one object was saved: both gemm views share the same tensor
        # (see prepare_for_saving).
        self.rowwise_gemm_tensor = tensor_objects_list[0]
        self.columnwise_gemm_tensor = self.rowwise_gemm_tensor
    return saved_tensors
def quantize_(self, tensor, *, noop_flag=None):
""" " quantize_ method override"""
assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
self.quantizer.update_quantized(tensor, self)
def dequantize(self, *, dtype=None):
""" " dequantize method override"""
if dtype is None:
dtype = self.rowwise_gemm_tensor.dtype
return self.rowwise_gemm_tensor.dequantize().to(dtype)
def get_tensor(self, transpose: bool):
    """Is used in the python gemm() to get tensor or transpose of the tensor."""
    if transpose:
        return self.columnwise_gemm_tensor
    return self.rowwise_gemm_tensor
def size(self):
    """Size of the tensor."""
    # Reported from the rowwise tensor; the columnwise tensor may hold the
    # same data in a transposed layout.
    return self.rowwise_gemm_tensor.size()
def update_usage(self, rowwise_usage: bool, columnwise_usage: bool):
"""Update usage of the tensor."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Managing the state of all the debugged layers.
"""
import sys
class TEDebugState:
    """
    A class to manage the state of debug layers.

    All state is class-level: this type is used as a process-wide singleton
    and is never instantiated.
    """

    # Counter handed out to TE modules that were not given explicit names.
    layer_count = 1
    # Registry of layers that have already been set up for debugging.
    layers_initialized = {}
    # Whether weight-tensor statistics are reduced across the TP group.
    weight_tensor_tp_group_reduce = True
    # None: not yet determined; True/False: debug_api presence at first init.
    debug_enabled = None

    @classmethod
    def initialize(cls):
        """
        If debug_api module is initialized, then sets cls.debug_enabled to True.
        """
        if "nvdlfw_inspect" not in sys.modules:
            # debug_api was never imported — leave debug_enabled untouched.
            return
        import nvdlfw_inspect.api as debug_api

        manager_present = debug_api.DEBUG_MANAGER is not None
        if cls.debug_enabled is False and manager_present:
            # This method is invoked when initializing TE modules.
            # If this error is thrown, it means that some TE module had been initialized before
            # debug_api was initialized, and now a new TE module is being initialized.
            # This is likely to be a bug.
            raise RuntimeError(
                "[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before"
                " initialization of the first TE module"
            )
        cls.debug_enabled = manager_present

    @classmethod
    def _reset(cls):
        """Resets layer count and stats buffers."""
        from ..features.utils.stats_buffer import STATS_BUFFERS

        STATS_BUFFERS.reset()
        # Bug fix: the docstring promises the layer count is reset, but the
        # previous implementation never did so, leaking numbering across
        # resets. Restore the counter to its initial value.
        cls.layer_count = 1
        cls.debug_enabled = None
        cls.layers_initialized.clear()

    @classmethod
    def get_layer_count(cls):
        """
        Layer counter is used when layer names are not provided to modules by the user.
        """
        current = cls.layer_count
        cls.layer_count += 1
        return current

    @classmethod
    def set_weight_tensor_tp_group_reduce(cls, enabled):
        """Sets weight tensor reduction mode."""
        cls.weight_tensor_tp_group_reduce = enabled
def set_weight_tensor_tp_group_reduce(enabled):
    """Sets weight tensor reduction mode."""
    # Module-level convenience wrapper that forwards to the classmethod of the
    # same name on the singleton debug-state holder.
    TEDebugState.set_weight_tensor_tp_group_reduce(enabled)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils functions for the debug module."""
def any_feature_enabled(quantizers):
    """Returns True if at least one API call is made from DebugQuantizer."""
    # Short-circuit on the first quantizer reporting an enabled feature.
    for quantizer in quantizers:
        if quantizer.any_feature_enabled():
            return True
    return False
......@@ -83,7 +83,8 @@ _load_library()
from . import flax
from . import quantize
from .quantize import fp8_autocast
from .quantize import fp8_autocast, update_collections, get_delayed_scaling
from .quantize import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType
......@@ -101,11 +102,14 @@ ShardingResource = deprecate_wrapper(
)
__all__ = [
"NVTE_FP8_COLLECTION_NAME",
"fp8_autocast",
"update_collections",
"get_delayed_scaling",
"MeshResource",
"MajorShardingType",
"ShardingResource",
"ShardingType",
"flax",
"praxis",
"quantize",
]
......@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g):
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx, None)
......
......@@ -10,6 +10,7 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec
import transformer_engine_jax
......@@ -26,12 +27,12 @@ from .misc import (
should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding,
)
from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias
from .quantization import _jax_dbias, _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
Quantizer,
QuantizeAxis,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
)
......@@ -110,41 +111,31 @@ class ActLuPrimitive(BasePrimitive):
"""
te_act_lu_p abstract
"""
del act_enum, act_len, scale_shapes
del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
out_shape = (
*x_aval.shape[:-2],
1,
x_aval.shape[-1],
assert x_aval.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x_aval.shape} and act_len {act_len}"
)
out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer)
if len(rowwise_scale_inv_shape) > 1:
rowwise_scale_inv_shape = (
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
)
if len(colwise_scale_inv_shape) > 1:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
if not is_2x:
out_shape = (1,)
colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
if is_2x:
colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval
......@@ -172,7 +163,7 @@ class ActLuPrimitive(BasePrimitive):
assert scale_aval is None or scale_aval.dtype == jnp.float32
out = ffi.ffi_lowering(ActLuPrimitive.name)(
ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x
ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
)
return out
......@@ -211,15 +202,8 @@ class ActLuPrimitive(BasePrimitive):
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
rowwise_scale_inv_shape = (
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
)
if is_2x:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
# Slice out padding for MXFP8, noop for DelayedScaling
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
......@@ -227,6 +211,7 @@ class ActLuPrimitive(BasePrimitive):
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax
@staticmethod
......@@ -292,11 +277,14 @@ class ActLuPrimitive(BasePrimitive):
is_outer,
) # Unused.
x_spec = get_padded_spec(arg_infos[0])
out_spec = (*x_spec[:-2], None, x_spec[-2])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-2], x_spec[-1])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else:
colwise_out_spec = out_spec
else:
......@@ -304,18 +292,24 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
)
return (
out_sharding,
colwise_out_sharding,
......@@ -340,14 +334,14 @@ class ActLuPrimitive(BasePrimitive):
):
del result_infos, is_outer # Unused.
x_spec = get_padded_spec(arg_infos[0])
out_spec = (*x_spec[:-1], x_spec[-1])
if act_len == 2 and x_spec[-1] is None:
# Ensure last axis is partitioned and not the gating axis
out_spec = (*x_spec[:-2], None, x_spec[-2])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-2], x_spec[-1])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else:
colwise_out_spec = out_spec
else:
......@@ -355,21 +349,25 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec))
arg_shardings = tuple(arg_shardings)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
out_sharding,
colwise_out_sharding,
......@@ -394,7 +392,7 @@ class ActLuPrimitive(BasePrimitive):
)
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
......@@ -409,10 +407,59 @@ class ActLuPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
    out_dtype,
    act_enum,
    act_len,
    scaling_mode,
    is_2x,
    scale_dtype,
    scale_shapes,
    is_outer,
    mesh,
    value_types,
    result_types,
):
    """Builds the Shardy sharding rule for the ActLu primitive.

    Only ``scaling_mode``, ``is_2x`` and the rank of the first operand are
    used; the remaining parameters mirror the primitive's signature.
    Returns an SdyShardingRule mapping (x, scale) inputs to
    (out, colwise_out, scale_inv, colwise_scale_inv, amax) outputs.
    """
    del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types
    x_rank = len(value_types[0].shape)
    # Scale rules are derived for rank x_rank - 1 — presumably the output
    # rank after the activation axis is dropped (TODO confirm against
    # get_shardy_sharding_rules).
    scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
        x_rank - 1, unique_var="i", flatten_axis=-2
    )
    # Append a fresh factor name for the input's trailing axis, which the
    # scale rules do not cover.
    x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
    # Output drops the activation (second-to-last) axis of the input.
    out = (*x_axes[:-2], x_axes[-1])
    scale_inv = scale_rules.rowwise_rule
    colwise_scale_inv = scale_rules.colwise_rule
    if is_2x:
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            # Delayed tensor scaling keeps the colwise copy in transposed
            # layout, so transpose the axis factors accordingly.
            colwise_out = tuple(
                multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
            )
        else:
            colwise_out = out
    else:
        # No colwise outputs are produced: the abstract eval emits unit
        # tensors, each of which needs its own unique factor name.
        colwise_out = ("j",)
        colwise_scale_inv = ("k",)
    # amax is always a unit tensor.
    amax = ("l",)
    return SdyShardingRule(
        (
            x_axes,
            "…1",  # scale input — placeholder batching factor
        ),
        (out, colwise_out, scale_inv, colwise_scale_inv, amax),
        **scale_rules.factor_sizes,
    )
register_primitive(ActLuPrimitive)
# TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive):
"""
DActLu DBias Cast Transpose Primitive
......@@ -445,42 +492,41 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p abstract
"""
del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_dtype
assert x_aval.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x_aval.shape} and act_len {act_len}"
)
assert scale_aval.dtype == jnp.float32
ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
gi_hidden_size = act_len * x_aval.shape[-1]
assert act_len * ir_hidden_size == gi_hidden_size
out_shape = x_aval.shape
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
if is_2x:
# Don't transpose output for MXFP8
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
t_shape = out_shape
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
else:
t_shape = multidim_transpose(out_shape)
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
colwise_out_shape = out_shape
else:
colwise_out_shape = (1,)
colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
if is_dbias:
dbias_shape = gi_hidden_size
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
dbias_shape = (act_len, ir_hidden_size)
(wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
......@@ -489,9 +535,14 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
else:
dbias_shape = (1,)
wkspace_shape = (1,)
wkspace_dtype = jnp.float32
dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
return (
out_aval,
......@@ -543,7 +594,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
dz,
x,
scale,
scaling_mode=scaling_mode,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
is_dbias=is_dbias,
act_enum=int(act_enum),
......@@ -587,23 +638,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
# Slice out padding for MXFP8, noop for DelayedScaling
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return (
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) # Exclude wkspace
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
@staticmethod
def batcher(
......@@ -670,15 +714,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
result_infos,
):
del out_dtype, result_infos, act_enum
del scale_dtype, scale_shapes, is_dbias, act_len, is_outer
del scale_dtype, scale_shapes, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else:
colwise_x_spec = x_spec
else:
......@@ -687,23 +732,32 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
)
dbias_shaprding = NamedSharding(
dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_scale_inv_spec),
desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
)
return (
out_sharding,
......@@ -711,7 +765,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_shaprding,
dbias_sharding,
)
@staticmethod
......@@ -731,10 +785,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec), desc="out")
scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else:
colwise_x_spec = x_spec
else:
......@@ -743,38 +802,39 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
)
dbias_shaprding = NamedSharding(
dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = (
arg_shardings[1],
arg_shardings[1],
*arg_shardings[2:],
) # dz and x are the same
out_shardings = (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_shaprding,
dbias_sharding,
)
def sharded_impl(dz, x, scale):
......@@ -799,7 +859,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else:
global_dbias = local_dbias
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
......@@ -808,6 +868,46 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
    out_dtype,
    scaling_mode,
    is_2x,
    scale_dtype,
    scale_shapes,
    is_dbias,
    act_enum,
    act_len,
    is_outer,
    mesh,
    value_types,
    result_types,
):
    """Builds the Shardy sharding rule for the DActLu+DBias+Quantize primitive.

    Uses ``scaling_mode``, ``is_2x``, ``is_dbias`` and the rank of the second
    operand (the forward activation input x); all other parameters mirror the
    primitive's signature. Returns an SdyShardingRule mapping (dz, x, scale)
    inputs to (out, colwise_out, scale_inv, colwise_scale_inv, amax, dbias)
    outputs.
    """
    del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types
    # value_types[1] is x; the quantized output has the same shape as x.
    x_rank = len(value_types[1].shape)
    scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
        x_rank, unique_var="i", flatten_axis=-2
    )
    x_axes = scale_rules.input_spec
    # Rowwise output shares x's axis factors exactly.
    out = x_axes
    if is_2x:
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            # Delayed tensor scaling stores the colwise copy transposed
            # around the activation axis.
            colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
        else:
            colwise_out = tuple(x_axes)
    else:
        # Colwise output collapses to a unit tensor — give it a unique factor.
        colwise_out = ("j",)
    # dbias keeps the trailing (act, hidden) axes of x when requested;
    # otherwise it is a unit tensor with its own factor name.
    dbias = x_axes[-2:] if is_dbias else ("k",)
    # amax is a unit tensor; "…4" is a placeholder batching factor.
    amax = ("…4",)
    return SdyShardingRule(
        (("…0",), tuple(x_axes), ("…2",)),
        (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
        **scale_rules.factor_sizes,
    )
register_primitive(DActLuDBiasQuantizePrimitive)
......@@ -816,14 +916,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S
"""
JAX native activation implementation
"""
x = jnp.split(inputs, len(activation_type), axis=-1)
act_len = len(activation_type)
assert inputs.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {inputs.shape} and act_len {act_len}"
)
x = jnp.split(inputs, act_len, axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = reduce(operator.mul, acts)
x = jnp.squeeze(x, axis=-2)
if quantizer:
return quantizer.quantize(x)
return quantizer.quantize(x, flatten_axis=-1)
return x
......@@ -837,6 +944,12 @@ def _jax_quantize_dact_dbias(
"""
JAX implementation of dact_lu and dbias with optional quantization
"""
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
_, vjp_func = jax.vjp(
partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
)
......@@ -844,10 +957,10 @@ def _jax_quantize_dact_dbias(
dbias = None
if is_dbias:
dbias = _jax_dbias(dx).astype(x.dtype)
dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)
if quantizer is not None:
dx = quantizer.quantize(dx, dq_dtype=x.dtype)
dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
else:
dx = dx.astype(x.dtype)
......@@ -863,6 +976,7 @@ def act_lu(
Args:
x: Input tensor to be processed.
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output.
......@@ -873,12 +987,17 @@ def act_lu(
A ScaledTensor containing the quantized activated input.
"""
act_type_id = ActivationEnum[activation_type].value
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
if not ActLuPrimitive.enabled():
return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support 2x quantization for DelayedScaling yet
......@@ -889,17 +1008,16 @@ def act_lu(
return war_output
scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type))
output_shape = (*x.shape[:-2], x.shape[-1])
if quantizer is None:
x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type)))
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
x,
scale,
out_dtype=x.dtype,
act_enum=act_type_id,
act_len=len(activation_type),
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
act_len=act_len,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
......@@ -911,7 +1029,6 @@ def act_lu(
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type)))
(
rowwise_casted_output,
colwise_casted_output,
......@@ -923,25 +1040,15 @@ def act_lu(
scale,
out_dtype=quantizer.q_dtype,
act_enum=act_type_id,
act_len=len(activation_type),
act_len=act_len,
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(output_shape),
# output does not have act axis
scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
is_outer=True,
)
rowwise_casted_output = rowwise_casted_output.reshape(output_shape)
if len(rowwise_scale_inv.shape) > 1:
rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2) # Remove act axis
if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE):
colwise_output_shape = output_shape
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
colwise_output_shape = multidim_transpose(output_shape)
colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape)
if len(colwise_scale_inv.shape) > 1:
colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2) # Remove act axis
quantizer.update(updated_amax)
return ScaledTensorFactory.create(
......@@ -951,8 +1058,8 @@ def act_lu(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
)
......@@ -968,7 +1075,7 @@ def quantize_dact_dbias(
Args:
dz: Gradient of the output with respect to the activation output.
x: Input tensor that was processed by the forward pass.
Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output.
......@@ -979,21 +1086,25 @@ def quantize_dact_dbias(
- The gradient of the activation with respect to the bias.
"""
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
if not DActLuDBiasQuantizePrimitive.enabled():
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = quantize_dact_dbias(
dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None
)
return quantize_dbias(out, is_dbias=True, quantizer=quantizer)
out = dact_lu(dz, x, activation_type, quantizer=None)
return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2)
is_gated = len(activation_type) == 2
is_gated = act_len == 2
# TE/common does not support DelayedScaling2x for gated-act yet
if is_gated:
war_output = try_apply_delayed_scaling_2x_war(
......@@ -1003,6 +1114,7 @@ def quantize_dact_dbias(
activation_type=activation_type,
is_dbias=is_dbias,
quantizer=quantizer,
flatten_axis=-2,
)
if war_output is not None:
return war_output
......@@ -1019,18 +1131,18 @@ def quantize_dact_dbias(
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
scale_shapes=((), ()), # unused
is_dbias=False,
act_enum=act_type_id,
act_len=len(activation_type),
act_len=act_len,
is_outer=True,
)
dbias = None
if is_dbias:
dbias = _jax_dbias(output).astype(x.dtype)
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
return output.astype(x.dtype), dbias
if isinstance(quantizer, DelayedScaleQuantizer):
......@@ -1041,16 +1153,9 @@ def quantize_dact_dbias(
dgated = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
)
# TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype)
else:
out, dbias = quantize_dbias(
dgated,
quantizer=quantizer,
is_dbias=True,
dq_dtype=x.dtype,
)
out, dbias = _quantize_dbias_impl(
dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
)
return out, dbias
out_shape = x.shape
......@@ -1070,15 +1175,16 @@ def quantize_dact_dbias(
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(out_shape),
# output has act axis
scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
is_dbias=is_dbias,
act_enum=act_type_id,
act_len=len(activation_type),
act_len=act_len,
is_outer=True,
)
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
......@@ -1090,8 +1196,9 @@ def quantize_dact_dbias(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
flatten_axis=-2, # as output has act axis
)
return out, dbias
......
......@@ -14,6 +14,7 @@ import jax
import jax.numpy as jnp
from jax import dtypes, lax
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.custom_partitioning import SdyShardingRule
import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
......@@ -42,6 +43,7 @@ from ..sharding import (
get_mesh_axis_rank,
get_all_mesh_axes,
num_of_devices,
with_sharding_constraint,
)
......@@ -618,6 +620,35 @@ class FusedAttnFwdPrimitive(BasePrimitive):
impl = partial(FusedAttnFwdPrimitive.impl, config=config)
return mesh, impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types):
del mesh, result_types
# Keep in sync with `infer_sharding_from_operands`.
# We only need the first input. Fill up the rest with placeholders.
input_spec = [(f"…{x}",) for x in range(len(value_types))]
# The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint
# instead. This has to happen outside of the primitive, see `fused_attn_fwd`.
rng_sharding = (f"…{len(value_types)}",)
if config.qkv_layout.is_qkvpacked():
input_spec[0] = ("…0", "seqlen", "three", "head", "hidden")
elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
input_spec[0] = ("…0", "seqlen", "head", "hidden")
else:
raise ValueError(f"Unsupported {config.qkv_layout=}")
is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
out_sharding = ("…0", "seqlen", "head", "hidden")
if is_packed_softmax:
softmax_aux_sharding = ("…0", "seqlen", "head", "i")
else:
softmax_aux_sharding = ("…0", "head", "seqlen", "i")
return SdyShardingRule(
tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding)
)
register_primitive(FusedAttnFwdPrimitive)
......@@ -998,6 +1029,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types):
del config, mesh
# We only care about the four first arguments.
# Keep in sync with `infer_sharding_from_operands`.
input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
return SdyShardingRule(input_spec, output_spec)
register_primitive(FusedAttnBwdPrimitive)
......@@ -2436,13 +2476,15 @@ def fused_attn_fwd(
primitive = FusedRingAttnFwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
return primitive.bind(
output, softmax_aux, rng_state = primitive.bind(
*qkv_for_primitive,
bias,
seed,
*seq_desc_flatten,
config=fused_config,
)
rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
return (output, softmax_aux, rng_state)
def fused_attn_bwd(
......
......@@ -98,6 +98,15 @@ class BasePrimitive(metaclass=ABCMeta):
"""
return NotImplemented
@staticmethod
@abstractmethod
def shardy_sharding_rule(*args):
"""
Returns the sharding rule for this primitive.
"""
del args
return "... -> ..."
def register_primitive(cls):
"""
......@@ -123,7 +132,9 @@ def register_primitive(cls):
batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition(
infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition
infer_sharding_from_operands=cls.infer_sharding_from_operands,
partition=cls.partition,
sharding_rule=cls.shardy_sharding_rule,
)
mlir.register_lowering(
outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
......
......@@ -6,9 +6,9 @@
from typing import Tuple, Sequence, Union, Dict, List
from functools import partial, reduce
import operator
from transformer_engine_jax import get_device_compute_capability
import jax
import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability
from .base import BasePrimitive, register_primitive
......@@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive):
name = "te_grouped_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9)
impl_static_args = ()
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
lhs_contig_aval,
lhs_scale_contig_aval,
rhs_contig_aval,
rhs_scale_contig_aval,
bias_contig_aval,
dim_list_aval,
*,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
):
del lhs_contig_aval, lhs_scale_contig_aval
del rhs_contig_aval, rhs_scale_contig_aval
del bias_contig_aval, dim_list_aval
del num_gemms, scaling_mode
out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype)
wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8)
return (out_flat_aval, wkspace_aval)
def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
"""
Args:
*args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
args[ 0 : num_gemms] are the lhs tensors,
args[ num_gemms : 2*num_gemms] are the rhs tensors,
args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
num_gemms: Number of GEMM operations to perform.
scaling_mode: Scaling mode for the GEMM operations.
out_dtype: Data type of the output tensors.
has_bias: Boolean indicating if bias tensors are provided.
Returns:
A tuple of ShapedArray objects of size num_gemms+1:
ret[0 : num_gemms]: GEMM output tensors,
ret[num_gemms]:workspace tensor.
"""
del scaling_mode
expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
assert (
len(args) == expected_num_args
), f"Expected {expected_num_args} input arguments, but got {len(args)}"
A_list = args[0:num_gemms]
B_list = args[num_gemms : 2 * num_gemms]
# A and B have shapes [1, m, k] and [1, n, k]
out_list_aval = tuple(
jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
for A, B in zip(A_list, B_list)
)
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return (*out_list_aval, workspace_aval)
@staticmethod
def outer_abstract(*args, **kwargs):
......@@ -74,60 +87,27 @@ class GroupedGemmPrimitive(BasePrimitive):
return out_aval
@staticmethod
def lowering(
ctx,
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
*,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
) -> jnp.ndarray:
del out_dtype, out_flat_size
def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
del out_dtype
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
ctx,
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
*args,
num_gemms=num_gemms,
scaling_mode=int(scaling_mode),
has_bias=has_bias,
)
@staticmethod
def impl(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
) -> jnp.ndarray:
def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
assert GroupedGemmPrimitive.inner_primitive is not None
out = GroupedGemmPrimitive.inner_primitive.bind(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
*args,
num_gemms=num_gemms,
scaling_mode=scaling_mode.value,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
has_bias=has_bias,
)
return out[0] # out is [out_flat, wkspace], only return out_flat
return out[:-1] # out is [out_list, wkspace], only return out_list
register_primitive(GroupedGemmPrimitive)
......@@ -183,10 +163,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T")
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
# _shape_normalization ensures contracting_dims=2 and batch_dims=0
dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
......@@ -199,13 +178,13 @@ def _jax_gemm_delayed_scaling_fp8(
):
"""FP8 GEMM for XLA pattern match"""
assert (
rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.layout == "T":
if lhs.data_layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
if rhs.layout == "T":
if rhs.data_layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
lhs_dn = (lhs_contract, lhs_batch)
......@@ -231,7 +210,7 @@ def _jax_gemm_mxfp8_1d(
JAX GEMM for MXFP8 via scaled_matmul
"""
assert (
rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING
rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode"
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
......@@ -292,10 +271,10 @@ def _jax_gemm(
def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}")
......@@ -367,6 +346,7 @@ def swizzled_scale(scales):
rows, cols = scales.shape
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
scales = scales.reshape(rows, cols)
return scales
......@@ -381,18 +361,12 @@ def grouped_gemm(
len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length"
# Flatten inputs and save their shapes
num_gemms = len(lhs_list)
out_flat_size = 0
dims = []
lhs_contig_ = []
rhs_contig_ = []
lhs_scale_inv_contig_ = []
rhs_scale_inv_contig_ = []
bias_contig_ = []
out_offsets = []
remain_shape_list = []
num_gemms = len(lhs_list)
lhs_list_ = []
rhs_list_ = []
lhs_sinv_list_ = []
rhs_sinv_list_ = []
bias_list_ = []
for i in range(num_gemms):
lhs = lhs_list[i]
rhs = rhs_list[i]
......@@ -403,20 +377,20 @@ def grouped_gemm(
lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
# For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.layout == "T":
if lhs.data_layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.layout == "T":
if rhs.data_layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else:
# For jnp.ndarray, only consider contracting_dims, layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING
# For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NO_SCALING
lhs_shape = lhs.shape
rhs_shape = rhs.shape
out_dtype = lhs.dtype
......@@ -428,24 +402,25 @@ def grouped_gemm(
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
# Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
if scaling_mode == ScalingMode.NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
# swizzled_scale requires a matrix
lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
else:
raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}")
# Note: if _shape_normalization() is updated to support non-TN, need to update here
# already_transposed doesn't matter for the output shape
# Note: already_transposed doesn't matter for the output shape
# x.shape = [B, D1, D2]
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
......@@ -456,61 +431,37 @@ def grouped_gemm(
bn = rhs_remain_shape[0]
kl = lhs_3d.shape[-1]
kr = rhs_3d.shape[-1]
remain_shape_list.append(((bm,), (bn,)))
assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})"
k = kl
if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0):
print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ")
print(
f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples"
" of 16"
)
assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0
dims.append((bm, bn, k))
lhs_contig_.append(lhs_3d.reshape(-1))
rhs_contig_.append(rhs_3d.reshape(-1))
if scaling_mode == ScalingMode.NVTE_NO_SCALING:
lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}"
if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0):
print("grouped_gemm input pair {i} has invalid problem shape for lowering: ")
print(f"m = {bm}, n = {bn}, k = {kl}; ")
print("cuBLAS requires the problem shapes being multiples of 16")
assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0)
lhs_list_.append(lhs_3d)
rhs_list_.append(rhs_3d)
if scaling_mode == ScalingMode.NO_SCALING:
lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
lhs_sinv_list_.append(lhs.scale_inv)
rhs_sinv_list_.append(rhs.scale_inv)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_sinv_list_.append(lhs_scale_inv)
rhs_sinv_list_.append(rhs_scale_inv)
if bias_list is not None:
bias_contig_.append(bias_list[i].reshape(-1))
out_flat_size += bm * bn
out_offsets.append(out_flat_size)
lhs_contig = jnp.concatenate(lhs_contig_)
rhs_contig = jnp.concatenate(rhs_contig_)
lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_)
rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_)
bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
dim_list = jnp.array(dims, dtype=jnp.int32)
# Perform batched GEMM on flattened inputs
out_contig = GroupedGemmPrimitive.outer_primitive.bind(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
bias_list_.append(bias_list[i])
out_list = GroupedGemmPrimitive.outer_primitive.bind(
*lhs_list_,
*rhs_list_,
*lhs_sinv_list_,
*rhs_sinv_list_,
*bias_list_,
num_gemms=num_gemms,
scaling_mode=scaling_mode,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
has_bias=1 if bias_list is not None else 0,
)
# Split the output back into tensors
out_offsets = jnp.array(out_offsets)
out_flat_list = jnp.split(out_contig, out_offsets[:-1])
out_tensors = []
for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list):
out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape))
return out_tensors
return out_list
......@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout
TEDType = transformer_engine_jax.DType
......@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim):
return axis if axis >= 0 else ndim + axis
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1):
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
"""
te_cast_transpose_p multi-dims transpose
static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
involved into transpose, -1 means all axes involve into transpose.
transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary
transpose_axis: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose. Note, transpose_axis should be greater than static_axis_boundary
examples:
X in shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis_boundary == 2
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 2
static_axis_boundary == 0, transpose_axis == 2
Xt = (dim0, dim2, dim3, dim4, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 3
static_axis_boundary == 0, transpose_axis == 3
Xt = (dim0, dim3, dim4, dim1. dim2)
"""
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose.
transpose_start_idx = static_axis_boundary + 1
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape))
assert transpose_start_idx < transpose_axis_boundary
transpose_axis = normalize_axis_boundary(transpose_axis, len(shape))
assert transpose_start_idx < transpose_axis
return (
*shape[:transpose_start_idx],
*shape[transpose_axis_boundary:],
*shape[transpose_start_idx:transpose_axis_boundary],
*shape[transpose_axis:],
*shape[transpose_start_idx:transpose_axis],
)
......@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break
return (
quantizer is not None
and quantizer.q_axis == QuantizeAxis.ROWWISE
and quantizer.q_layout == QuantizeLayout.ROWWISE
and arch_l_100
and is_dbias
)
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs):
"""
Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result.
......@@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
"""
should_apply_war = (
quantizer is not None
and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
and quantizer.is_2x2x()
)
if not should_apply_war:
......@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
# 2x is not supported by TE kernels for delayed scaling
# so revert to 1x and transpose in JAX
quantizer.q_axis = QuantizeAxis.ROWWISE
quantizer.q_layout = QuantizeLayout.ROWWISE
rowwise = f(*args, **kwargs, quantizer=quantizer)
other_outputs = None
if isinstance(rowwise, tuple):
other_outputs = rowwise[1:]
rowwise = rowwise[0]
quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE
colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1)))
quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
if flatten_axis < 0:
flatten_axis += rowwise.data.ndim
assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
colwise_data = jnp.transpose(
rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis))
)
output_2x = ScaledTensorFactory.create(
data=rowwise.data,
scale_inv=rowwise.scale_inv,
......@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
colwise_scale_inv=rowwise.scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=rowwise.dq_dtype,
q_axis=QuantizeAxis.ROWWISE_COLWISE,
layout=quantizer.get_layout(),
q_layout=QuantizeLayout.ROWWISE_COLWISE,
data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
)
if other_outputs is not None:
return (output_2x,) + other_outputs
......
......@@ -12,6 +12,7 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec
......@@ -30,7 +31,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
Quantizer,
QuantizeAxis,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
)
......@@ -63,6 +64,27 @@ def get_backward_sm_margin():
return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
@cache
def is_norm_fwd_cudnn_enabled(scaling_mode: ScalingMode) -> bool:
"""Retrieves whether CuDNN norm fwd is enabled."""
# MXFP8_1D_SCALING always uses CuDNN currently
return (
int(os.getenv("NVTE_NORM_FWD_USE_CUDNN", "0")) == 1
or scaling_mode == ScalingMode.MXFP8_1D_SCALING
)
@cache
def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bool:
"""Retrieves whether norm should compute `gamma += 1.0` for zero-centered gamma
in weight dtype as opposed to compute dtype."""
if not is_norm_fwd_cudnn_enabled(scaling_mode):
# If CuDNN is not enabled, we use the TE backend which uses the compute dtype not weight dtype
# Remove this when TE supports gamma += 1.0 in weight dtype
return False
return int(os.getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE", "0")) == 1
class NormFwdPrimitive(BasePrimitive):
"""
Layer Normalization Forward FP8 Primitive
......@@ -105,6 +127,26 @@ class NormFwdPrimitive(BasePrimitive):
if norm_type == NVTE_Norm_Type.LayerNorm:
assert gamma_aval.size == beta_aval.size
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_out_shape = x_aval.shape if is_2x else (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
(wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size
......@@ -112,33 +154,13 @@ class NormFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype
jax_dtype_to_te_dtype(out_dtype),
norm_type,
scaling_mode.value,
scaling_mode,
zero_centered_gamma,
epsilon,
get_forward_sm_margin(),
is_2x,
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x_aval.shape, is_padded=not is_outer
)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
colwise_out_aval = jax.core.ShapedArray(
shape=x_aval.shape if is_2x else (1,), dtype=out_dtype
)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = x_aval.update(
wkspace_aval = jax.core.ShapedArray(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
......@@ -274,17 +296,17 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes=scale_shapes,
is_outer=False,
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x.shape, is_padded=False
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
# slice out padding for mxfp8, noop for DelayedScaling
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv = scale_inv.flatten()[
: reduce(operator.mul, rowwise_scale_inv_shape)
].reshape(rowwise_scale_inv_shape)
if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape)
].reshape(colwise_scale_inv_shape)
if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape)
return (
out,
colwise_out,
......@@ -364,6 +386,8 @@ class NormFwdPrimitive(BasePrimitive):
del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, scale_shapes, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
......@@ -371,34 +395,27 @@ class NormFwdPrimitive(BasePrimitive):
"and hurt performance."
)
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out"
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
)
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
scale_inv_spec = amax_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
output = (
out_sharding,
colwise_out_sharding,
......@@ -427,8 +444,11 @@ class NormFwdPrimitive(BasePrimitive):
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
g_spec = get_padded_spec(arg_infos[2])
b_spec = get_padded_spec(arg_infos[3])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None:
warnings.warn(
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
......@@ -445,43 +465,30 @@ class NormFwdPrimitive(BasePrimitive):
f"{NormFwdPrimitive.name} does not support sharding of parameter beta "
"Enforcing no sharding of parameters hidden dim! "
)
x_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x"
)
g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma")
b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta")
out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out")
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding(
mesh,
PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]),
desc="NormFwdPrimitive.rsigma",
mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
)
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
scale_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale"
)
scale_inv_sharding = scale_sharding.duplicate_with_new_description(
"NormFwdPrimitive.scale_inv"
scale_inv_spec = amax_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
out_sharding,
colwise_out_sharding,
......@@ -517,7 +524,7 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes=scale_shapes,
is_outer=True,
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
......@@ -534,6 +541,57 @@ class NormFwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    out_dtype,
    scaling_mode,
    is_2x,
    scale_dtype,
    scale_shapes,
    is_outer,
    mesh,
    value_types,
    result_types,
):
    """Build the Shardy sharding rule for the norm-forward primitive.

    Only ``norm_type``, ``scaling_mode``, ``is_2x`` and ``value_types`` affect
    the resulting rule; the remaining static parameters are accepted (and
    discarded) so the signature matches the primitive's static-argument list.
    """
    # Static arguments that play no role in the sharding rule.
    del zero_centered_gamma, epsilon, out_dtype
    del scale_dtype, scale_shapes, is_outer, mesh, result_types

    input_rank = len(value_types[0].shape)
    scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
        input_rank, unique_var="i", flatten_axis=-1
    )
    input_axes = scale_rules.input_spec

    # Row-wise output keeps the leading input axes; the hidden dim gets its own factor.
    rowwise_out = input_axes[:-1] + ("k",)
    if is_2x:
        colwise_out = rowwise_out
    else:
        colwise_out = ("…4",)

    # rsigma shards like the input with the hidden dim dropped; mu matches it
    # for LayerNorm but is a dummy factor for RMSNorm (which has no mean).
    rsigma_axes = input_axes[:-1]
    if norm_type == NVTE_Norm_Type.RMSNorm:
        mu_axes = ("…5",)
    else:
        mu_axes = rsigma_axes
    amax_axes = ("…6",)

    return SdyShardingRule(
        (input_axes, ("…1",), ("…2",), ("…3",)),
        (
            rowwise_out,
            colwise_out,
            scale_rules.rowwise_rule,
            scale_rules.colwise_rule,
            amax_axes,
            mu_axes,
            rsigma_axes,
        ),
        **scale_rules.factor_sizes,
    )
register_primitive(NormFwdPrimitive)
......@@ -737,6 +795,11 @@ class NormBwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(*args):
    """Shardy rule for the norm-backward primitive.

    The static arguments do not influence the rule, so they are ignored:
    dx shards like the incoming gradient; the reduced outputs (dgamma,
    dbeta) get fresh factors.
    """
    return "...0, ...1 i, ...2, ...3, ...4 -> ...1 j, k, l"
register_primitive(NormBwdPrimitive)
......@@ -746,6 +809,10 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
JAX native layernorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
if not is_norm_zero_centered_gamma_in_weight_dtype(
quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
):
gamma = gamma.astype(jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + epsilon)
......@@ -767,6 +834,10 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
JAX native rmsnorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
if not is_norm_zero_centered_gamma_in_weight_dtype(
quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
):
gamma = gamma.astype(jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + epsilon)
normed_input = x_ * rsigma
......@@ -816,7 +887,7 @@ def layernorm_fwd(
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -824,7 +895,6 @@ def layernorm_fwd(
if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32)
)
if quantizer is None:
output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
x,
......@@ -835,7 +905,7 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((1,), (1,)),
......@@ -845,7 +915,7 @@ def layernorm_fwd(
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
is_2x2x = False
(
rowwise_casted_output,
......@@ -864,7 +934,7 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
......@@ -873,7 +943,7 @@ def layernorm_fwd(
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......@@ -882,7 +952,7 @@ def layernorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False
)
......@@ -900,8 +970,8 @@ def layernorm_fwd(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
)
return scaled_tensor, mu, rsigma
......@@ -997,7 +1067,7 @@ def rmsnorm_fwd(
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -1017,7 +1087,7 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
......@@ -1027,7 +1097,7 @@ def rmsnorm_fwd(
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
is_2x2x = False
(
rowwise_casted_output,
......@@ -1046,7 +1116,7 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
......@@ -1055,7 +1125,7 @@ def rmsnorm_fwd(
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......@@ -1064,7 +1134,7 @@ def rmsnorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False
)
......@@ -1082,8 +1152,8 @@ def rmsnorm_fwd(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
)
return scaled_tensor, rsigma
......
......@@ -2,12 +2,15 @@
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec
import transformer_engine_jax
......@@ -24,7 +27,7 @@ from .misc import (
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode
from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
......@@ -50,7 +53,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
6,
7,
8,
) # out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer
9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
inner_primitive = None
outer_primitive = None
......@@ -61,7 +65,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*,
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -73,49 +78,56 @@ class DBiasQuantizePrimitive(BasePrimitive):
del scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = x_aval.shape
assert scale_aval is None or scale_aval.dtype == jnp.float32
rowwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
rowwise_out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
rowwise_out_shape = out_shape
else:
rowwise_out_shape = (1,)
rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
colwise_out_shape = out_shape
else:
colwise_out_shape = (1,)
colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
t_shape = multidim_transpose(x_aval.shape)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
# Don't transpose output for MXFP8
t_shape = x_aval.shape
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
if is_dbias:
gi_hidden_size = x_aval.shape[-1]
dbias_shape = (gi_hidden_size,)
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
dbias_shape = x_aval.shape[flatten_axis:]
gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
(wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
scaling_mode,
QuantizeLayout(
q_layout
), # For now until we have auto-decoding for QuantizeLayout enum
)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
else:
dbias_shape = (1,)
wkspace_shape = (1,)
wkspace_dtype = jnp.float32
dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
return (
rowwise_out_aval,
......@@ -151,7 +163,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*,
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -168,8 +181,9 @@ class DBiasQuantizePrimitive(BasePrimitive):
ctx,
x,
scale,
scaling_mode=scaling_mode,
q_axis=q_axis,
scaling_mode=scaling_mode.value,
q_layout=q_layout,
flatten_axis=flatten_axis,
is_dbias=is_dbias,
)
......@@ -179,7 +193,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale,
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -203,7 +218,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_axis=q_axis,
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
......@@ -211,16 +227,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return (
out,
colwise_out,
......@@ -237,7 +251,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*,
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -260,7 +275,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_axis=q_axis,
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
......@@ -272,7 +288,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def infer_sharding_from_operands(
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -281,16 +298,17 @@ class DBiasQuantizePrimitive(BasePrimitive):
arg_infos,
result_infos,
):
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer) # Unused.
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused.
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec[:-1], x_spec[-1]),
PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
else:
......@@ -300,26 +318,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
)
scale_inv_sharding = NamedSharding(
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])),
desc="DBiasQuantizePrimitive.scale_inv",
PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding",
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding"
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
)
dbias_sharding = NamedSharding(
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DBiasQuantizePrimitive.dbias_sharding",
PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv",
)
return (
out_sharding,
colwise_out_sharding,
......@@ -333,7 +360,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def partition(
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -344,14 +372,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec[:-1], x_spec[-1]),
PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
else:
......@@ -361,26 +390,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
)
scale_inv_sharding = NamedSharding(
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])),
desc="DBiasQuantizePrimitive.scale_inv",
PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding",
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding"
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
)
dbias_sharding = NamedSharding(
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DBiasQuantizePrimitive.dbias_sharding",
PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv",
)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
out_sharding,
......@@ -404,14 +442,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_axis=q_axis,
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
is_outer=True,
)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else:
global_updated_amax = local_amax
......@@ -432,53 +471,91 @@ class DBiasQuantizePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
    out_dtype,
    scaling_mode,
    q_layout,
    flatten_axis,
    scale_dtype,
    scale_shapes,
    is_dbias,
    is_outer,
    mesh,
    value_types,
    result_types,
):
    """Build the Shardy sharding rule for the dbias-quantize primitive.

    Only ``scaling_mode``, ``q_layout``, ``flatten_axis``, ``is_dbias`` and
    ``value_types`` shape the rule; the other static parameters are discarded.
    """
    # Static arguments irrelevant to the sharding rule.
    del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types

    input_rank = len(value_types[0].shape)
    scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
        input_rank, unique_var="i", flatten_axis=flatten_axis
    )
    input_axes = scale_rules.input_spec

    # Row-wise output always shards like the input.
    rowwise_out = input_axes
    wants_colwise = q_layout in (
        QuantizeLayout.COLWISE.value,
        QuantizeLayout.ROWWISE_COLWISE.value,
    )
    if wants_colwise:
        colwise_scale_inv = scale_rules.colwise_rule
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            # Delayed tensor scaling transposes the colwise output at flatten_axis.
            colwise_out = tuple(multidim_transpose(input_axes, transpose_axis=flatten_axis))
        else:
            colwise_out = input_axes
    else:
        # No colwise output: give the dummy (1,)-shaped results fresh factors.
        colwise_out = ("j",)
        colwise_scale_inv = ("k",)

    if is_dbias:
        dbias_axes = input_axes[flatten_axis:]
    else:
        dbias_axes = ("l",)
    amax_axes = ("m",)

    return SdyShardingRule(
        (input_axes, ("…1",)),
        (rowwise_out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax_axes, dbias_axes),
        **scale_rules.factor_sizes,
    )
register_primitive(DBiasQuantizePrimitive)
def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    """Pure-JAX quantization fallback.

    Args:
        x: Input tensor.
        quantizer: Quantizer to apply; when ``None`` the input is returned unchanged.
        dq_dtype: Optional dtype to use when later dequantizing.
        flatten_axis: Axis at which the input may be flattened to 2D for quantization.

    Returns:
        The quantized tensor, or ``x`` itself when no quantizer is given.
    """
    # Merge residue fix: the diff left both the old (no flatten_axis) and new
    # definitions interleaved; this is the post-merge definition only.
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def _jax_dbias(dx: jnp.ndarray):
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
assert flatten_axis < 0
dtype = dtype or dx.dtype
dbias = jnp.sum(
dx,
axis=tuple(range(dx.ndim - 1)),
dx.astype(jnp.float32),
axis=tuple(range(dx.ndim + flatten_axis)),
keepdims=False,
)
dbias = dbias.ravel() # C++ function returns an 1D array for dbias
return dbias
return dbias.astype(dtype)
def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    """Pure-JAX fused quantize + bias-gradient fallback.

    Args:
        x: Input tensor.
        quantizer: Quantizer to apply; when ``None``, ``(x, None)`` is returned.
        dq_dtype: Optional dtype for dequantization; also used for the dbias result.
        flatten_axis: Axis at which the input may be flattened to 2D for quantization.

    Returns:
        Tuple of (quantized tensor, bias gradient).
    """
    # Merge residue fix: the diff spliced a removed duplicate ``_jax_dbias``
    # definition and the old return into this body; this is the post-merge
    # definition only.
    if quantizer is None:
        return x, None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )
def _quantize_impl(
def _quantize_dbias_impl(
x: jnp.ndarray,
quantizer: Quantizer,
is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""
Cast wrapper
......@@ -488,40 +565,51 @@ def _quantize_impl(
quantizer is not None
), "quantizer must be provided if dq_dtype is provided"
dq_dtype = dq_dtype or x.dtype
if not DBiasQuantizePrimitive.enabled():
if is_dbias:
return _jax_quantize_dbias(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None
return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
# TE/common doesn't support colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if is_dbias:
return _jax_quantize_dbias(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None
return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
scale = jnp.empty((), jnp.float32)
# TE/common dbias_quantize does not support 1x on arch < 100
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = _quantize_impl(
out, _ = _quantize_dbias_impl(
x=x,
is_dbias=False,
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
dbias = _jax_dbias(x)
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias
if quantizer is None:
if is_dbias:
return x, _jax_dbias(x)
return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None
if isinstance(quantizer, DelayedScaleQuantizer):
......@@ -539,14 +627,15 @@ def _quantize_impl(
scale,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_axis=quantizer.q_axis.value,
q_layout=quantizer.q_layout.value,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
is_dbias=is_dbias,
is_outer=True,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
......@@ -557,18 +646,18 @@ def _quantize_impl(
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=dq_dtype if dq_dtype is not None else x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
dq_dtype=dq_dtype,
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
)
return out, dbias
return out, dbias.astype(dq_dtype)
# TODO(Phuong): do not expose dq_dtype to users
def quantize(
x: jnp.ndarray,
quantizer: Quantizer,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
......@@ -576,26 +665,25 @@ def quantize(
x: Input tensor to be quantized.
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
dq_dtype: Optional dtype for dequantization.
If None, uses the same dtype as the input tensor.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
Returns:
A ScaledTensor containing the quantized input tensor.
"""
out, _ = _quantize_impl(
out, _ = _quantize_dbias_impl(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
return out
# TODO(Phuong): do not expose dq_dtype to users
def quantize_dbias(
dz: jnp.ndarray,
quantizer: Quantizer,
is_dbias: bool = True,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
......@@ -604,8 +692,8 @@ def quantize_dbias(
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
is_dbias: If True, compute bias gradient. Defaults to True.
dq_dtype: Optional dtype for dequantization.
If None, uses the same dtype as the input tensor.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
Returns:
A tuple containing:
......@@ -614,9 +702,6 @@ def quantize_dbias(
- The bias gradient tensor.
Shape: (K,) or empty if is_dbias is False.
"""
return _quantize_impl(
dz,
quantizer=quantizer,
is_dbias=is_dbias,
dq_dtype=dq_dtype,
return _quantize_dbias_impl(
dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
)
......@@ -31,6 +31,9 @@ __all__ = [
"scaled_upper_triang_masked_softmax_fwd",
"scaled_upper_triang_masked_softmax_bwd",
"is_softmax_kernel_available",
"jax_scaled_softmax",
"jax_scaled_masked_softmax",
"jax_scaled_upper_triang_masked_softmax",
]
......@@ -330,6 +333,11 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def shardy_sharding_rule(*args):
    """Shardy rule for scaled-softmax forward: the single output shards
    exactly like the single input; static args are irrelevant."""
    del args  # static parameters do not influence the sharding rule
    return "... -> ..."
register_primitive(ScaledSoftmaxFwdPrimitive)
......@@ -400,6 +408,11 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def shardy_sharding_rule(*args):
    """Shardy rule for scaled-softmax backward: two operands, output shards
    like the batch dims; static args are ignored."""
    return "..., ... -> ..."
register_primitive(ScaledSoftmaxBwdPrimitive)
......@@ -412,7 +425,7 @@ def scaled_softmax_bwd(
Return FP16/BF16 tensor
"""
if not ScaledSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits)
_, vjp_func = jax.vjp(partial(jax_scaled_softmax, scale_factor=scale_factor), logits)
return vjp_func(dz)[0]
return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
......@@ -525,6 +538,11 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def shardy_sharding_rule(*args):
    """Shardy rule for scaled-masked-softmax forward: output shards like the
    logits operand (...1); the mask (...2) has independent factors."""
    del args  # static parameters do not influence the sharding rule
    return "...1, ...2 -> ...1"
register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
......@@ -596,6 +614,11 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)
@staticmethod
def shardy_sharding_rule(*args):
    """Shardy rule for scaled-masked-softmax backward: two operands, output
    shards like the batch dims; static args are ignored."""
    return "..., ... -> ..."
register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
......@@ -682,6 +705,11 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
result_infos,
)
@staticmethod
def shardy_sharding_rule(*args):
    """Shardy rule for scaled-upper-triangular-masked-softmax forward:
    single input, output sharded identically; static args are ignored."""
    return "... -> ..."
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
......@@ -761,15 +789,26 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
result_infos,
)
@staticmethod
def shardy_sharding_rule(*args):
    """Shardy rule for scaled-upper-triangular-masked-softmax backward:
    two operands, output shards like the batch dims; static args are ignored."""
    del args  # static parameters do not influence the sharding rule
    return "..., ... -> ..."
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    """
    JAX based implementation of scaled softmax.

    Args:
        logits: Input tensor; softmax is taken over the last axis.
        scale_factor: Multiplier applied to the logits before the softmax.

    Returns:
        Softmax of ``scale_factor * logits``.
    """
    # Merge residue fix: the diff left the removed private ``_jax_scaled_softmax``
    # def line fused onto the renamed public definition; only the public name remains.
    return jax.nn.softmax(scale_factor * logits)
def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
"""
JAX based implementation of scaled and masked softmax
"""
if mask is not None:
logits += jax.lax.select(
mask > 0,
......@@ -779,7 +818,10 @@ def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fac
return jax.nn.softmax(logits * scale_factor)
def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
"""
JAX based implementation of scaled and upper triangle masked softmax
"""
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
......@@ -795,7 +837,7 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
Return FP16/BF16 tensor
"""
if not ScaledSoftmaxFwdPrimitive.enabled():
return _jax_scaled_softmax(logits, scale_factor)
return jax_scaled_softmax(logits, scale_factor)
return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
......@@ -807,7 +849,7 @@ def scaled_masked_softmax_fwd(
Return FP16/BF16 tensor
"""
if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_masked_softmax(logits, mask, scale_factor)
return jax_scaled_masked_softmax(logits, mask, scale_factor)
return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, mask, scale_factor=scale_factor
)
......@@ -826,7 +868,7 @@ def scaled_masked_softmax_bwd(
"""
if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
partial(jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
)
return vjp_func(dz)[0]
return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
......@@ -840,7 +882,7 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl
Return FP16/BF16 tensor
"""
if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, scale_factor=scale_factor
)
......@@ -855,7 +897,7 @@ def scaled_upper_triang_masked_softmax_bwd(
"""
if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
partial(jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
)
return vjp_func(dz)[0]
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment