Unverified Commit 7948779c authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] GroupedQuantizer and GroupedScaledTensor (#1666)



* refactor the multi_stream utils + implement nvte_multi_tensor_quantize in TE/Common

* implement GroupedQuantizer and grouped_quantize in jax

* fix logical_axes_names for transpose tensor in ScaledTensor
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarHua Huang <huah@nvidia.com>
Co-authored-by: default avatarMing Huang <mingh@nvidia.com>
parent 9985b02c
......@@ -8,6 +8,7 @@ import numpy as np
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator
from utils import (
......@@ -33,6 +34,9 @@ from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
DelayedScaleQuantizer,
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
ScalingMode,
QuantizerFactory,
QuantizeLayout,
......@@ -41,7 +45,6 @@ from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
GEMM_CASES = [
(256, 256, 512),
......@@ -113,6 +116,38 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
pytest.fail("a must be a ScaledTensor object")
def assert_dequantized_grouped_scaled_tensor(
    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
    """Assert that dequantizing the grouped tensor ``a`` reproduces reference ``b``.

    ``b`` is split along axis 0 according to ``a.group_sizes`` and each group is
    compared against the corresponding dequantized group. For a ``ScaledTensor2x``,
    both the rowwise and colwise sub-tensors are checked recursively.
    """
    if isinstance(a, GroupedScaledTensor1x):
        # The group sizes must partition b along its leading axis.
        assert a.group_sizes.sum() == b.shape[0]
        # cumulative_sum(...)[:-1] gives the split points between groups.
        b = jnp.split(b, jnp.cumulative_sum(a.group_sizes)[:-1], axis=0)
        dq_a = a.dequantize()
        for dq_a_i, b_i in zip(dq_a, b):
            # Empty groups have nothing to compare.
            if len(dq_a_i) == 0:
                continue
            if a.data_layout == "T":
                # Transposed layout: bring the reference into the same axis order
                # as the colwise (transposed) quantized data before comparing.
                data_ndim = len(a.original_shape)
                flatten_axis = a.flatten_axis
                if b_i.shape[0] == 1:
                    # Single-row group: keep axis 0 in place, rotate the rest.
                    b_i = jnp.transpose(
                        b_i, (0, *range(flatten_axis, data_ndim), *range(1, flatten_axis))
                    )
                else:
                    # Move the trailing (flattened) axes in front of the leading axes.
                    b_i = jnp.transpose(
                        b_i, (*range(flatten_axis, data_ndim), *range(flatten_axis))
                    )
                dq_a_i = dq_a_i.reshape(b_i.shape)
            # Tolerance is chosen from the quantized data dtype.
            assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
    elif isinstance(a, ScaledTensor2x):
        # 2x tensors must wrap grouped 1x tensors; check both directions.
        assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
        assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
        assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
        assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
    else:
        pytest.fail("a must be a GroupedScaledTensor object")
ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
ALL_ACTIVATION_TYPES = [
("gelu",),
......@@ -602,6 +637,57 @@ class TestQuantize:
assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
    "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
class TestGroupedQuantize:
    """Round-trip tests for grouped quantize -> dequantize."""

    def test_grouped_qdq(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis, with_group_sizes
    ):
        """Quantize a (possibly randomly grouped) tensor and verify dequantization."""
        n_groups, m, n = input_shape
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        # *32 so that the input shape satisfies the MXFP8 block-size requirement
        input_shape = (m * 32, n)
        if with_group_sizes:
            # Draw n_groups-1 sorted cut points in [0, m), then take differences
            # to get random group sizes that sum exactly to m.
            group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
            group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
            group_sizes = jnp.diff(group_sizes)
            assert group_sizes.sum() == m
            # With the fixed PRNG seed above this draw contains a zero-sized
            # group, which exercises the empty-group code path.
            assert jnp.any(group_sizes == 0)  # make sure that at least one group has 0 row
            group_sizes = group_sizes * 32
        else:
            # No explicit sizes: reshape so groups form the leading axis instead.
            group_sizes = None
            input_shape = (n_groups, input_shape[0] // n_groups, input_shape[1])
        if flatten_axis == -2:
            # Insert an extra axis so there is something behind the flatten axis.
            input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
        x = jax.random.uniform(subkeys[1], input_shape, in_dtype)
        grouped_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
            n_groups=n_groups,
        )
        scaled_tensor = tex.grouped_quantize(
            x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
        )
        assert_dequantized_grouped_scaled_tensor(scaled_tensor, x)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:
......
......@@ -92,6 +92,7 @@ list(APPEND transformer_engine_SOURCES
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
......
......@@ -8,6 +8,7 @@
#include <cublas_v2.h>
#include <cuda.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
......@@ -16,6 +17,7 @@
#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "common/util/cuda_runtime.h"
namespace {
......@@ -568,18 +570,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
}
static std::once_flag init_flag;
static cudaStream_t compute_streams[num_streams];
static cudaEvent_t cublas_event[num_streams];
// Warning: only call once per device!
static void init_streams_and_events() {
for (int i = 0; i < num_streams; i++) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
}
}
} // namespace transformer_engine
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
......@@ -641,29 +631,31 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
using namespace transformer_engine;
// Inits streams and events (once, globally)
std::call_once(init_flag, init_streams_and_events);
int num_streams = nvte_get_num_compute_streams();
int num_stream_used = std::min(num_streams, num_gemms);
// wait for current stream to finish
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
}
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
compute_streams[i % num_streams]);
detail::get_compute_stream(i % num_streams));
}
// record events on compute streams
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
NVTE_CHECK_CUDA(
cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
}
// wait for all compute streams to finish
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
}
}
......
......@@ -259,6 +259,17 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp
*/
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Casts multiple input tensors to quantized output tensors.
 *
 * \param[in] inputs List of input tensors to be cast.
 * \param[in,out] outputs List of output quantized tensors.
 * \param[in] quant_config (Optional) Quantization configurations.
 * \param[in] num_tensors Number of tensors in the input/output lists.
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
const NVTEQuantizationConfig quant_config, const size_t num_tensors,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -117,8 +117,6 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVT
*/
namespace transformer_engine {
constexpr int num_streams = 4;
/*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing
* region. This function is a helper to call cublasCreate() which allocate memory for the handle.
* The function will be called in the initialize phase of the related XLA custom calls.
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file multi_stream.h
 *  \brief Functions for multi streams executions.
 */

#ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H
#define TRANSFORMER_ENGINE_MULTI_STREAM_H

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Returns the number of CUDA compute streams used by multi-stream operations. */
int nvte_get_num_compute_streams();

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_MULTI_STREAM_H
......@@ -8,13 +8,16 @@
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>
#include <cfloat>
#include <limits>
#include <mutex>
#include <string>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
......@@ -156,3 +159,43 @@ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t str
using namespace transformer_engine;
detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}
// Quantizes num_tensors independent tensors, round-robining the work over the
// shared pool of compute streams, then re-joins all work onto `stream`.
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
                                const NVTEQuantizationConfig quant_configs,
                                const size_t num_tensors, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_quantize);
  using namespace transformer_engine;

  // Plain quantization only: no dbias/dact fusion, so the fused-path inputs
  // passed to quantize_helper below are all null.
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  constexpr const NVTETensor grad = nullptr;

  const size_t num_streams = nvte_get_num_compute_streams();
  // NOTE(review): size_t min narrowed into an int; fine for small pool sizes.
  int num_stream_used = std::min(num_streams, num_tensors);
  // wait for current stream to finish
  NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
  }

  // Dispatch tensor i on compute stream (i mod num_streams).
  for (int i = 0; i < num_tensors; i++) {
    detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
        inputs[i], grad, outputs[i], dbias, workspace, nullptr,
        detail::get_compute_stream(i % num_streams));
  }

  // record events on compute streams
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(
        cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
  }
  // wait for all compute streams to finish
  for (int s = 0; s < num_stream_used; s++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
  }
}
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

// Fix: this translation unit previously wrapped its contents in the header's
// include guard (TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_). Defining that macro
// before `#include "multi_stream.h"` suppressed the header's declarations, and
// include guards do not belong in a .cpp file, so the guard has been removed.

#include "multi_stream.h"

#include <transformer_engine/multi_stream.h>

#include <mutex>
#include <vector>

#include "cuda_runtime.h"
#include "logging.h"

namespace transformer_engine::detail {

// Returns the idx-th shared compute stream. The pool is created lazily, once
// per process, as non-blocking streams with priority -1.
cudaStream_t get_compute_stream(int idx) {
  const size_t num_streams = nvte_get_num_compute_streams();
  // Validate sign before comparing against the unsigned pool size.
  NVTE_CHECK(0 <= idx && static_cast<size_t>(idx) < num_streams,
             "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams,
             " streams)");
  static std::vector<cudaStream_t> streams(num_streams);
  static std::once_flag stream_init_flag;
  auto init = [&]() {
    for (size_t i = 0; i < num_streams; i++) {
      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1));
    }
  };
  std::call_once(stream_init_flag, init);
  return streams[idx];
}

// Returns the idx-th shared CUDA event used to synchronize the compute-stream
// pool with a caller stream. Created lazily, once per process.
cudaEvent_t get_compute_stream_event(int idx) {
  const size_t num_streams = nvte_get_num_compute_streams();
  NVTE_CHECK(0 <= idx && static_cast<size_t>(idx) < num_streams,
             "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams,
             " streams)");
  static std::vector<cudaEvent_t> events(num_streams);
  static std::once_flag event_init_flag;
  auto init = [&]() {
    for (size_t i = 0; i < num_streams; i++) {
      NVTE_CHECK_CUDA(cudaEventCreate(&events[i]));
    }
  };
  std::call_once(event_init_flag, init);
  return events[idx];
}

// Size of the compute-stream pool (compile-time constant).
int get_num_compute_streams() {
  static constexpr int num_compute_streams = 4;
  return num_compute_streams;
}

}  // namespace transformer_engine::detail

int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); }
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
#define TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
namespace transformer_engine::detail {
int get_num_compute_streams();
cudaStream_t get_compute_stream(int idx);
cudaEvent_t get_compute_stream_event(int idx);
} // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
......@@ -5,6 +5,7 @@
import operator
from functools import reduce
from typing import Tuple, Optional
import math
from packaging import version
import jax
......@@ -27,9 +28,13 @@ from .misc import (
NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import (
ScaledTensor2x,
ScaledTensor,
ScaledTensorFactory,
GroupedScaledTensor1x,
Quantizer,
GroupedQuantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
......@@ -42,7 +47,7 @@ else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["quantize", "quantize_dbias"]
__all__ = ["quantize", "quantize_dbias", "grouped_quantize"]
class BaseDBiasQuantizePrimitive(BasePrimitive):
......@@ -740,3 +745,290 @@ def quantize_dbias(
return _quantize_dbias_impl(
dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
)
class GroupedQuantizePrimitive(BasePrimitive):
    """
    Grouped-quantize primitive wrapping the ``te_grouped_quantize_ffi`` custom
    call, which quantizes all groups of one tensor in a single kernel launch.
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        Abstract evaluation: declares the shapes/dtypes of the five outputs
        (rowwise data, colwise data, rowwise scale_inv, colwise scale_inv, amax).
        Unused outputs are declared with placeholder shape (1,).
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        # NOTE(review): out_shape here is a scalar element count; ShapedArray
        # below normally expects a tuple — confirm this is intended (likely
        # (math.prod(x_aval.shape),)).
        out_shape = math.prod(x_aval.shape)
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )
        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            # Rowwise output unused for this layout; allocate a placeholder.
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
        # One amax entry per group.
        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_out_shape = out_shape
        else:
            # Colwise output unused for this layout; allocate a placeholder.
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Outer-primitive abstract evaluation; drops the dbias/workspace outputs.
        """
        # Phuong: keeping outer abstract so that we can add fuse dbias later
        # NOTE(review): this delegates to DBiasQuantizePrimitive.abstract with
        # grouped-quantize kwargs (group_axis/scale_dtype) — verify that the
        # delegate accepts them.
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            _dbias,
            _wkspace,
        ) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        Lowering rule: emits the te_grouped_quantize_ffi custom call.
        """
        # out_dtype/scale_dtype are encoded in the output avals, not the FFI attrs.
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        # The FFI kernel currently only supports grouping along the leading axis.
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        Implementation: binds the inner primitive and forwards its outputs.
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)
def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, the tensor will be split into equal-sized groups
          along the group_axis
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
    """
    if quantizer is None:
        return x
    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0
    if group_sizes is None:
        # Default: every row along the group axis is its own group.
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)

    if not GroupedQuantizePrimitive.enabled():
        # Primitive disabled: fall back to the pure-Python quantizer path.
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )

    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"

    # Gather one fp32 scale per group for the tensor-scaling modes.
    scale = jnp.empty((n_groups,), jnp.float32)
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        # Delayed scaling: reuse each sub-quantizer's existing scale.
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling: compute per-group amax (via segment_max over rows),
        # then derive each group's scale from its amax.
        row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        segment_ids = jnp.repeat(jnp.arange(n_groups), group_sizes)
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
    # So we perform ROWWISE_COLWISE and use the colwise_tensor_output
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    # NOTE(review): parses as (is_tensor_scaling and is_2x2x()) or apply_colwise_war;
    # confirm that grouping is the intended one.
    if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        # Feed the freshly computed per-group amax back into each sub-quantizer.
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out
......@@ -68,6 +68,8 @@ pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_s
// Quantization
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
......
......@@ -10,6 +10,8 @@
#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "transformer_engine/multi_stream.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......@@ -169,6 +171,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
Result_Type workspace = workspace_get.value();
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
auto num_streams = nvte_get_num_compute_streams();
size_t workspace_size = workspace->dimensions()[0] / num_streams;
auto workspace_shape = std::vector<size_t>{workspace_size};
for (int i = 0; i < num_streams; i++) {
......
......@@ -26,5 +26,19 @@ std::vector<size_t> Shape::to_vector() const {
return shape;
}
// Computes the (padded) scale-inverse buffer shape for an M x N MXFP8 tensor.
// Colwise scaling swaps the roles of the x/y block and alignment components.
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise) {
  auto block_x = is_colwise ? MXFP8_BLOCK_SIZE.y : MXFP8_BLOCK_SIZE.x;
  auto block_y = is_colwise ? MXFP8_BLOCK_SIZE.x : MXFP8_BLOCK_SIZE.y;
  auto alignment_x = is_colwise ? MXFP8_ALIGNMENT.y : MXFP8_ALIGNMENT.x;
  auto alignment_y = is_colwise ? MXFP8_ALIGNMENT.x : MXFP8_ALIGNMENT.y;

  // Fix: NVTE_CHECK concatenates its variadic arguments into the message (it
  // is not printf-style), so "%zu" placeholders would appear verbatim in the
  // error text. Build the message from parts instead; also fix "divisble".
  NVTE_CHECK(M % block_x == 0, "M must be divisible by ", block_x, " (got ", M, ")");
  NVTE_CHECK(N % block_y == 0, "N must be divisible by ", block_y, " (got ", N, ")");

  // Round each scale dimension up to its alignment requirement.
  size_t scale_x = DIVUP((M / block_x), alignment_x) * alignment_x;
  size_t scale_y = DIVUP((N / block_y), alignment_y) * alignment_y;

  return {scale_x, scale_y};
}
} // namespace jax
} // namespace transformer_engine
......@@ -67,5 +67,16 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
}
}
// MXFP8 quantization block size: 1 element along x, 32 along y (rowwise view).
constexpr struct BlockSize {
  size_t x;
  size_t y;
} MXFP8_BLOCK_SIZE{1, 32};

// Alignment (in scale elements) that MXFP8 scale-inverse buffers are padded to.
// NOTE(review): presumably dictated by the cuBLAS MXFP8 scale layout — confirm.
constexpr struct Alignment {
  size_t x;
  size_t y;
} MXFP8_ALIGNMENT{128, 4};

// Computes the padded scale-inverse buffer shape for an (M, N) MXFP8 tensor.
std::vector<size_t> get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise);
} // namespace jax
} // namespace transformer_engine
......@@ -25,6 +25,7 @@ pybind11::dict Registrations() {
// Quantization
dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler);
dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
// Softmax
......
......@@ -8,6 +8,7 @@
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......@@ -226,5 +227,182 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
.Ret<Buffer_Type>(), // output
FFI_CudaGraph_Traits);
// FFI entry point for grouped quantization: walks the flat input buffer group
// by group (sizes read back from device), wraps each non-empty group in
// TensorWrapper views, and quantizes them all via nvte_multi_tensor_quantize.
Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scales,
                              Buffer_Type group_sizes, Result_Type outputs,
                              Result_Type colwise_outputs, Result_Type scale_invs,
                              Result_Type colwise_scale_invs, Result_Type amaxs,
                              JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
                              int64_t flatten_axis) {
  NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING,
             "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
  auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(outputs->element_type());
  NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");
  auto scale_dtype = convert_ffi_datatype_to_te_dtype(scales.element_type());
  auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type());
  auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type());
  auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type());
  auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);

  // Raw byte pointers into the flat XLA buffers; advanced per group below.
  auto *input_ptr = reinterpret_cast<uint8_t *>(inputs.untyped_data());
  auto *scale_ptr = reinterpret_cast<uint8_t *>(scales.untyped_data());
  auto *output_ptr = reinterpret_cast<uint8_t *>(outputs->untyped_data());
  auto *colwise_output_ptr = reinterpret_cast<uint8_t *>(colwise_outputs->untyped_data());
  auto *sinv_ptr = reinterpret_cast<uint8_t *>(scale_invs->untyped_data());
  auto *colwise_sinv_ptr = reinterpret_cast<uint8_t *>(colwise_scale_invs->untyped_data());
  auto *amax_ptr = reinterpret_cast<uint8_t *>(amaxs->untyped_data());

  bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE ||
                     quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
  bool has_colwise = quantize_layout == QuantizeLayout::COLWISE ||
                     quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
  bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
  bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;

  // Per-element byte sizes; zeroed when a buffer is unused so the pointer
  // advance at the end of the group loop becomes a no-op for that buffer.
  size_t input_dtype_bytes = te_dtype_bytes(in_dtype);
  size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
  size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype);
  size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype);
  size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0;
  size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0;
  size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0;
  size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0;

  // Logically flatten the input to 2D (m, n) around flatten_axis.
  auto input_dims = inputs.dimensions();
  int64_t input_ndim = input_dims.size();
  if (flatten_axis < 0) flatten_axis += input_ndim;
  NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");
  auto m = product(input_dims, 0, flatten_axis);
  auto n = product(input_dims, flatten_axis, input_ndim);
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m * n};

  // These lists are to keep the TensorWrapper objects alive
  std::vector<TensorWrapper> input_holders;
  std::vector<TensorWrapper> output_holders;
  // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM
  std::vector<NVTETensor> input_list;
  std::vector<NVTETensor> output_list;

  // Group sizes live on device; copy them to host to drive the loop below.
  size_t num_groups = group_sizes.dimensions()[0];
  size_t dim_list_bytes = group_size_dtype_bytes * num_groups;
  std::vector<int32_t> dim_list_host(num_groups);
  auto *group_size_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
  cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
                  stream);
  // Note: This may break cudaGraph.
  // NOTE(review): the handler registers FFI_CudaGraph_Traits while doing a
  // host sync here — confirm compatibility with graph capture.
  cudaStreamSynchronize(stream);

  size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
  // NOTE(review): NVTE_CHECK concatenates its args (no printf formatting); the
  // "%zu" placeholders below will appear verbatim in the error message.
  NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes,
             "Unexpected group_sizes! Got %zu (M=%zu, input_dims[0] = %zu)", sum_group_sizes, m,
             input_dims[0]);

  if (is_delayed_scaling) {
    NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups,
               ", got ", amaxs->dimensions()[0]);
    NVTE_CHECK(amax_dtype == DType::kFloat32 && scale_dtype == DType::kFloat32);
    // Delayed scaling accumulates amax per group; clear it first.
    cudaMemsetAsync(amax_ptr, 0, sizeof(float) * num_groups, stream);
  }

  size_t sinv_size = 0;
  size_t colwise_sinv_size = 0;
  // Rows per unit of group size when leading dims beyond axis 0 exist.
  size_t non_group_m = flatten_axis > 1 ? product(input_dims, 1, flatten_axis) : 1;
  size_t num_non_empty_groups = 0;
  for (size_t i = 0; i < num_groups; i++) {
    size_t m_i = dim_list_host[i] * non_group_m;
    // Skip for zero-size input + shift the scale ptr
    if (m_i == 0) {
      if (is_tensor_scaling) scale_ptr += scale_dtype_bytes;
      continue;
    }
    num_non_empty_groups++;
    auto shape_i = std::vector<size_t>{m_i, n};
    auto shape_trans_i = std::vector<size_t>{n, m_i};

    auto inp_i = TensorWrapper(static_cast<void *>(input_ptr), shape_i, in_dtype);
    auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));

    if (has_rowwise) {
      out_i.set_rowwise_data(static_cast<void *>(output_ptr), out_dtype, shape_i);
      if (is_fp8_dtype(out_dtype)) {
        if (is_tensor_scaling) {
          // Tensor scaling: one scalar scale/amax/scale_inv per group.
          out_i.set_scale(static_cast<void *>(scale_ptr), DType::kFloat32, std::vector<size_t>{1});
          out_i.set_amax(static_cast<void *>(amax_ptr), DType::kFloat32, std::vector<size_t>{1});
          out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype,
                                      std::vector<size_t>{1});
          sinv_size = 1;
        } else {
          // Block scaling (MXFP8): per-block scale_inv with padded shape.
          const bool is_colwise = false;
          auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
          out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype, sinv_shape_i);
          sinv_size = product(sinv_shape_i);
        }
      }
    }

    if (has_colwise) {
      // Tensor scaling stores the colwise copy transposed; MXFP8 does not.
      auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i;
      out_i.set_columnwise_data(static_cast<void *>(colwise_output_ptr), out_dtype, tmp_shape);
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
      auto &tmp_sinv_ptr = is_tensor_scaling ? sinv_ptr : colwise_sinv_ptr;
      if (is_tensor_scaling) {
        out_i.set_columnwise_scale_inv(static_cast<void *>(tmp_sinv_ptr), sinv_dtype,
                                       std::vector<size_t>{1});
        colwise_sinv_size = 1;
      } else {
        const bool is_colwise = true;
        auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
        out_i.set_columnwise_scale_inv(static_cast<void *>(colwise_sinv_ptr), sinv_dtype,
                                       sinv_shape_i);
        colwise_sinv_size = product(sinv_shape_i);
      }
    }
    input_holders.push_back(std::move(inp_i));
    output_holders.push_back(std::move(out_i));
    input_list.push_back(input_holders.back().data());
    output_list.push_back(output_holders.back().data());

    // Advance every buffer pointer past this group's data.
    input_ptr += m_i * n * input_dtype_bytes;
    scale_ptr += scale_dtype_bytes;
    output_ptr += m_i * n * output_dtype_bytes;
    colwise_output_ptr += m_i * n * colwise_output_dtype_bytes;
    sinv_ptr += sinv_size * sinv_dtype_bytes;
    colwise_sinv_ptr += colwise_sinv_size * colwise_sinv_dtype_bytes;
    amax_ptr += amax_dtype_bytes;
  }

  QuantizationConfigWrapper quant_config;
  nvte_multi_tensor_quantize(input_list.data(), output_list.data(), quant_config,
                             num_non_empty_groups, stream);

  return ffi_with_cuda_error_check();
}
// Registers GroupedQuantizeFFI as the XLA FFI handler "GroupedQuantizeHandler".
// The Arg/Ret/Attr order below defines the custom-call signature and must match
// the Python-side lowering. NOTE(review): FFI_CudaGraph_Traits presumably marks
// the call as CUDA-graph capturable — confirm against the traits definition.
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // group_sizes
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("q_layout")
                                  .Attr<int64_t>("flatten_axis"),
                              FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
......@@ -7,24 +7,54 @@ Dequantization utilities for TE/JAX.
This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
"""
import math
from dataclasses import dataclass
from abc import ABC, abstractmethod
import jax
import jax.numpy as jnp
from .scaling_modes import ScalingMode
__all__ = ["Dequantizer"]
__all__ = ["ScalingModeToDequantizerMap"]
@dataclass
class Dequantizer(ABC):
"""
Base Dequantizer Class
"""
@staticmethod
@abstractmethod
def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
pass
class Dequantizer:
"""Encapsulation class for dequantization helpers.
@staticmethod
@abstractmethod
def dequantize(scaled_tensor):
"""Dequantizing given tensor to higher precision."""
class TensorScaleDequantizer(Dequantizer):
"""
TensorScaling Dequantizer Class
This class provides static methods for dequantizing tensors that have been
quantized using different scaling modes. It supports both delayed scaling
and block scaling modes.
quantized using different tensor scaling modes. It supports both delayed scaling
and current scaling modes.
"""
@staticmethod
def _dq_func_tensor_scaling(scaled_tensor):
def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
del kwargs
return jnp.asarray(
data.astype(jnp.float32) * scale_inv.astype(jnp.float32),
dq_dtype,
)
@staticmethod
def dequantize(scaled_tensor):
"""Dequantize a tensor using delayed scaling.
This function dequantizes a tensor that was quantized using delayed scaling
......@@ -36,36 +66,48 @@ class Dequantizer:
Returns:
The dequantized tensor in the specified data type
"""
return jnp.asarray(
scaled_tensor.data.astype(jnp.float32) * scaled_tensor.scale_inv.astype(jnp.float32),
scaled_tensor.dq_dtype,
return TensorScaleDequantizer._dequantize_func(
scaled_tensor.data, scaled_tensor.scale_inv, scaled_tensor.dq_dtype
)
class BlockScaleDequantizer(Dequantizer):
"""BlockScaling Dequantizer Class.
This class provides static methods for dequantizing tensors that have been
quantized using block scaling modes.
"""
@staticmethod
def _dq_func_block_scaling(scaled_tensor):
def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatten_axis):
"""Dequantize a tensor using block scaling.
This function dequantizes a tensor that was quantized using block scaling
by applying the inverse scaling factor to each block of data.
Args:
scaled_tensor: The quantized tensor to dequantize
data: The quantized tensor data
scale_inv: The inverse scaling factors
dq_dtype: The data type for dequantized values
scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D
Returns:
The dequantized tensor in the specified data type
The dequantized tensor
"""
data = scaled_tensor.data.astype(jnp.float32)
data = data.astype(jnp.float32)
scale_inv = scale_inv.view(jnp.uint8).astype(jnp.float32)
data_shape = data.shape
scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
flatten_axis = scaled_tensor.flatten_axis
flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis
scale_shape = scaling_mode.get_scale_shape(
data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding
scale_inv = jax.lax.slice(
scale_inv, [0] * len(scale_shape), scale_shape
) # slice out the padding
data = data.reshape(
*data_shape[: flatten_axis - 1],
......@@ -76,25 +118,17 @@ class Dequantizer:
int(data_shape[-1] / scale_shape[-1]),
)
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
data_shape
)
scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1))
funcs = {
ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.CURRENT_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling,
}
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return jnp.asarray(data * jnp.power(2, scale_inv - 127), dq_dtype).reshape(data_shape)
@staticmethod
def dequantize(scaled_tensor):
"""Dequantize a scaled tensor using the appropriate scaling mode.
"""Dequantize a tensor using block scaling.
This method selects the appropriate dequantization function based on the
scaling mode used for quantization and applies it to the tensor.
This function dequantizes a tensor that was quantized using block scaling
by applying the inverse scaling factor to each block of data.
Args:
scaled_tensor: The quantized tensor to dequantize
......@@ -102,5 +136,86 @@ class Dequantizer:
Returns:
The dequantized tensor in the specified data type
"""
dq_func = Dequantizer.funcs[scaled_tensor.scaling_mode]
return dq_func(scaled_tensor)
return BlockScaleDequantizer._dequantize_func(
scaled_tensor.data,
scaled_tensor.scale_inv,
scaled_tensor.dq_dtype,
scaled_tensor.scaling_mode,
scaled_tensor.is_colwise,
scaled_tensor.flatten_axis,
)
# Maps each ScalingMode to the Dequantizer subclass that knows how to undo
# quantization performed under that mode. Both tensor-scaling modes share
# TensorScaleDequantizer since they store a single per-tensor scale.
ScalingModeToDequantizerMap = {
    ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
}
@staticmethod
def _grouped_dequantize(grouped_scaled_tensor):
    """Dequantize a grouped tensor, one group at a time.

    The grouped tensor stores every group's quantized data flattened into a
    single 1-D buffer, with the per-group scale-inv values packed back to back
    in the same order. This walks both buffers in lockstep, reshapes each
    group's slice back to its logical shape, and dequantizes it with the
    Dequantizer matching the tensor's scaling mode.

    Args:
        grouped_scaled_tensor: The grouped scaled tensor to dequantize

    Returns:
        List of dequantized tensors for each group (an empty list entry is
        produced for zero-sized groups)
    """
    data = grouped_scaled_tensor.data
    scale_inv = grouped_scaled_tensor.scale_inv
    group_sizes = grouped_scaled_tensor.group_sizes
    flatten_axis = grouped_scaled_tensor.flatten_axis
    scaling_mode = grouped_scaled_tensor.scaling_mode
    original_shape = grouped_scaled_tensor.original_shape
    group_axis = grouped_scaled_tensor.group_axis

    # Normalize a negative flatten_axis against the logical (ungrouped) rank.
    flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis

    output = []
    non_group_shape = tuple(
        original_shape[i] for i in range(len(original_shape)) if i != group_axis
    )
    # Element count of each group's flattened data slice.
    matrix_sizes = group_sizes * math.prod(non_group_shape)
    data = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1])

    # Running offset into the packed 1-D scale_inv buffer.
    scale_inv_ptr = 0
    for i, data_i in enumerate(data):
        data_shape_i = (
            *original_shape[:group_axis],
            group_sizes[i],
            *original_shape[group_axis + 1 :],
        )
        assert math.prod(data_shape_i) == data_i.size, (
            f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to"
            f" {data_i.size}"
        )
        # NOTE(review): scales are sliced with is_padded=True — assumes the
        # producer packed each group's scales at their padded size; confirm
        # against GroupedQuantizer / the FFI packing.
        scale_shape_i = scaling_mode.get_scale_shape(
            data_shape_i,
            grouped_scaled_tensor.is_colwise,
            is_padded=True,
            flatten_axis=flatten_axis,
        )
        scale_shape_i_size = math.prod(scale_shape_i)
        scale_inv_i = scale_inv[scale_inv_ptr : scale_inv_ptr + scale_shape_i_size]
        # NOTE(review): loop-invariant lookup; could be hoisted out of the loop.
        dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode)
        if len(data_i) == 0:
            out_i = []
        else:
            out_i = dequantizer_type._dequantize_func(
                data_i.reshape(data_shape_i),
                scale_inv_i.reshape(scale_shape_i),
                grouped_scaled_tensor.dq_dtype,
                scaling_mode=grouped_scaled_tensor.scaling_mode,
                is_colwise=grouped_scaled_tensor.is_colwise,
                flatten_axis=grouped_scaled_tensor.flatten_axis,
            )
        output.append(out_i)
        scale_inv_ptr += scale_shape_i_size
    return output


# Attach grouped dequantization to the Dequantizer base so callers can use
# Dequantizer.grouped_dequantize(tensor) uniformly.
Dequantizer.grouped_dequantize = _grouped_dequantize
......@@ -9,7 +9,8 @@ This module provides classes and utilities for quantizing tensors in JAX.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import partial
from typing import Union, Optional
from typing import Union, Optional, Tuple
import warnings
import jax
import jax.numpy as jnp
......@@ -17,7 +18,7 @@ from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import (
QuantizeConfig,
AmaxComputeAlgo,
......@@ -30,6 +31,7 @@ __all__ = [
"CurrentScaleQuantizer",
"DelayedScaleQuantizer",
"BlockScaleQuantizer",
"GroupedQuantizer",
"QuantizerFactory",
"noop_quantizer_set",
"compute_scale_from_amax",
......@@ -74,6 +76,7 @@ class Quantizer(ABC):
q_dtype: jnp.dtype
scaling_mode: ScalingMode
q_layout: QuantizeLayout
data_layout: str
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
......@@ -82,7 +85,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations
"""
children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout)
return (children, aux_data)
@classmethod
......@@ -110,13 +113,22 @@ class Quantizer(ABC):
"""
return self.q_layout == QuantizeLayout.ROWWISE_COLWISE
@abstractmethod
def get_data_layout(self) -> str:
"""Get the data data_layout.
"""Get the data data_layout string.
Returns:
Data data_layout in string format
Raises:
ValueError: If quantization axis is invalid
"""
if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
return self.data_layout
if self.q_layout == QuantizeLayout.ROWWISE:
return self.data_layout[0]
if self.q_layout == QuantizeLayout.COLWISE:
return self.data_layout[1]
raise ValueError(f"Invalid q_layout: {self.q_layout}")
@abstractmethod
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
......@@ -132,7 +144,9 @@ class Quantizer(ABC):
A ScaledTensor1x containing the quantized data
"""
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1):
def quantize(
self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1, **kwargs
) -> ScaledTensor:
"""Quantize a tensor using the internal _quantize_func().
Args:
......@@ -145,6 +159,7 @@ class Quantizer(ABC):
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
del kwargs
if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func(
......@@ -159,7 +174,7 @@ class Quantizer(ABC):
return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1):
def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, **kwargs):
"""Get shapes for scale tensors.
Args:
......@@ -169,6 +184,7 @@ class Quantizer(ABC):
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
del kwargs
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)
def get_scale_dtype(self):
......@@ -194,24 +210,7 @@ class CurrentScaleQuantizer(Quantizer):
scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
def get_data_layout(self) -> str:
"""Get the data data_layout string.
Returns:
Data data_layout in string format
Raises:
ValueError: If quantization axis is invalid
"""
data_layout = "NT"
if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
return data_layout
if self.q_layout == QuantizeLayout.ROWWISE:
return data_layout[0]
if self.q_layout == QuantizeLayout.COLWISE:
return data_layout[1]
raise ValueError(f"Invalid q_layout: {self.q_layout}")
data_layout: str = "NT"
def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
......@@ -230,16 +229,11 @@ class CurrentScaleQuantizer(Quantizer):
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
amax = jnp.max(jnp.abs(x)).reshape((1,)).astype(compute_dtype)
amax = jnp.max(jnp.abs(x)).reshape((1,))
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
scaled_x = x.astype(compute_dtype) * scale
# quantize() in the old dot.py do this way, leave this code block here for future debugging
# compute_dtype = x.dtype
# dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
# scaled_x = x * self.scale.astype(compute_dtype)
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / scale
return ScaledTensorFactory.create_1x(
......@@ -295,6 +289,7 @@ class CurrentScaleQuantizer(Quantizer):
data_layout="T",
flatten_axis=flatten_axis,
)
if is_colwise and is_rowwise:
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise:
......@@ -332,7 +327,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout)
return (children, aux_data)
def _quantize_func(
......@@ -447,16 +442,7 @@ class BlockScaleQuantizer(Quantizer):
scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
def get_data_layout(self) -> str:
"""Get the data data_layout string.
Returns:
Data data_layout in string format
"""
if self.is_2x2x():
return "NN"
return "N"
data_layout: str = "NN"
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8.
......@@ -591,6 +577,189 @@ class QuantizerSet:
return cls(*aux_data, *children)
@register_pytree_node_class
@dataclass
class GroupedQuantizer(Quantizer):
    """Quantizer for grouped arrays.

    This class extends Quantizer to support quantization of arrays in grouped manner,
    where elements are grouped along a specified axis then quantized separately
    (one sub-quantizer per group), and the per-group results are packed into a
    single flattened GroupedScaledTensor.

    Attributes:
        data_layout: The data layout specification (taken from the sub-quantizers)
        n_groups: Number of groups for quantization
        quantizers: Tuple of quantizers, one per group
    """

    data_layout: str = None
    n_groups: int = 1
    quantizers: Tuple[Quantizer, ...] = field(default_factory=lambda: (None,))

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        The per-group quantizers are children (they may carry traced state,
        e.g. delayed-scaling amax history); everything else is static aux data.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.quantizers,)
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.n_groups)
        return (children, aux_data)

    def __post_init__(self):
        # Lazily build one sub-quantizer per group when none were supplied.
        # NOTE(review): QuantizerFactory.create returns a single Quantizer
        # (not a tuple) when n_groups == 1, in which case the subscripts below
        # would fail — confirm n_groups > 1 is always used here.
        if self.quantizers[0] is None:
            self.quantizers = QuantizerFactory.create(
                self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout
            )
        self.data_layout = self.quantizers[0].data_layout

    def _create_grouped_tensor_from_tensor_list(
        self, tensor_list, group_sizes, original_shape, group_axis, mode
    ):
        # Pack the per-group ScaledTensor1x results into one flattened
        # GroupedScaledTensor1x.
        # mode 0 = concatenate, mode 1 = add
        # TODO(Ming Huang): Consider to apply Enum for mode.
        assert mode in [0, 1]

        grouped_data = (
            [] if mode == 0 else jnp.zeros(tensor_list[0].data.shape, tensor_list[0].data.dtype)
        )
        grouped_scale_inv = []

        for tensor in tensor_list:
            if mode == 0:
                grouped_data.append(tensor.data.flatten())
            else:
                # Summation works because each masked tensor is zero outside
                # its own group's rows (see quantize() below).
                grouped_data += tensor.data
            grouped_scale_inv.append(tensor.scale_inv.flatten())

        grouped_data = jnp.concatenate(grouped_data) if mode == 0 else grouped_data.flatten()
        grouped_scale_inv = jnp.concatenate(grouped_scale_inv)

        return ScaledTensorFactory.create_1x(
            grouped_data,
            grouped_scale_inv,
            self.scaling_mode,
            tensor_list[0].dq_dtype,
            tensor_list[0].is_colwise,
            tensor_list[0].data_layout,
            tensor_list[0].flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )

    def _quantize_func(self, *args, **kwargs):
        # Not used: grouped quantization is implemented directly in quantize().
        pass

    def quantize(
        self,
        x,
        is_rowwise: bool = None,
        is_colwise: bool = None,
        dq_dtype=None,
        flatten_axis=-1,
        group_sizes=None,
        group_axis=0,
    ):
        """Quantize a tensor in grouped manner.

        Expected input shape: [M, K] or [G, K, N]
        Split to x.shape[group_axis] number of groups if group_sizes is not given

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array of ints containing the size of each group (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        assert group_axis == 0, "Only group_axis == 0 is supported now!"
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        if flatten_axis < 0:
            flatten_axis += x.ndim
        assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"

        # Resolve layout flags from q_layout when the caller did not pin them.
        is_rowwise = (
            is_rowwise
            if is_rowwise is not None
            else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
        )
        is_colwise = (
            is_colwise
            if is_colwise is not None
            else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
        )
        assert is_rowwise or is_colwise, "No quantization layout is specified"

        original_shape = x.shape

        if group_sizes is not None:
            assert not is_colwise, "Not yet implememted!"
            assert group_sizes.ndim == 1, (
                "GroupedQuantizer only support 1D group_sizes, got group_sizes.ndim ="
                f" {group_sizes.ndim}"
            )
            # Dynamic group sizes: quantize a full-size copy of x per group
            # with all rows outside the group masked to zero. Zeros do not
            # raise the per-group amax, so each group gets its own scale; the
            # masked results are then summed back together (combine mode 1).
            _zeros = partial(jax.lax.full_like, fill_value=0)
            x_iota = jax.lax.broadcasted_iota(group_sizes.dtype, x.shape, 0)
            group_ends = jnp.cumulative_sum(group_sizes)
            group_starts = jax.lax.concatenate(
                [_zeros(group_sizes)[:1], group_ends[:-1]],
                dimension=0,
            )
            x_zero = _zeros(x)

            tensor_list = []
            for i in range(len(group_sizes)):
                mask = jax.lax.bitwise_and(group_starts[i] <= x_iota, x_iota < group_ends[i])
                x_selected = jax.lax.select(mask, x, x_zero)
                tensor = self.quantizers[i].quantize(
                    x_selected, is_rowwise, is_colwise, dq_dtype, flatten_axis
                )
                tensor_list.append(tensor)
            combine_mode = 1  # Add
        else:
            # Static grouping: every slice along group_axis is its own group.
            group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)
            x = jnp.split(x, x.shape[group_axis], axis=group_axis)
            tensor_list = []
            for i in range(len(group_sizes)):
                tensor = self.quantizers[i].quantize(
                    x[i], is_rowwise, is_colwise, dq_dtype, flatten_axis
                )
                tensor_list.append(tensor)
            combine_mode = 0  # Concatenate

        grouped_rowwise_tensor = grouped_colwise_tensor = None
        if is_rowwise:
            rowwise_tensor_list = [tensor.get_rowwise_tensor() for tensor in tensor_list]
            grouped_rowwise_tensor = self._create_grouped_tensor_from_tensor_list(
                rowwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )
        if is_colwise:
            colwise_tensor_list = [tensor.get_colwise_tensor() for tensor in tensor_list]
            grouped_colwise_tensor = self._create_grouped_tensor_from_tensor_list(
                colwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )

        if is_colwise and is_rowwise:
            return ScaledTensor2x(grouped_rowwise_tensor, grouped_colwise_tensor)
        if is_colwise:
            return grouped_colwise_tensor
        return grouped_rowwise_tensor

    def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, group_sizes=None):
        """Get (rowwise, colwise) scale shapes for a grouped tensor.

        Args:
            data_shape: Shape of the (ungrouped) data tensor
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Sizes of each group; must be non-empty

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        assert group_sizes, "Empty group_sizes was given!"
        # NOTE(review): get_grouped_scale_shape_2x's signature is
        # (data_shape, n_groups, group_axis, is_padded, flatten_axis) — here
        # group_sizes lands in n_groups and is_padded in group_axis. Confirm
        # whether callers pass an int count here, or whether the positional
        # arguments are misaligned.
        return self.scaling_mode.get_grouped_scale_shape_2x(
            data_shape, group_sizes, is_padded, flatten_axis
        )
@dataclass
class QuantizerFactory:
"""Factory class for creating quantizers.
......@@ -611,6 +780,7 @@ class QuantizerFactory:
scaling_mode: ScalingMode = None,
q_dtype: jnp.dtype = None,
q_layout: QuantizeLayout = None,
n_groups: int = None,
**kwargs,
) -> Quantizer:
"""Create one or more quantizers with specified parameters.
......@@ -621,6 +791,7 @@ class QuantizerFactory:
q_dtype: Quantization data type
q_layout: Quantization axis
flatten_axis: The quantization axis for the tensor
n_groups: Number of quantizers if GroupedQuantizer
**kwargs: Additional arguments for quantizer initialization
Returns:
......@@ -628,13 +799,21 @@ class QuantizerFactory:
"""
        # (Phuong): add this assert back when NVTE_NO_SCALING is fully implemented
assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
if n_groups:
if n_quantizers != 1:
warnings.warn(
"Using more than one GroupedQuantizer for a grouped input is not recommended"
)
quantizer_type = GroupedQuantizer
kwargs["n_groups"] = n_groups
else:
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
if scaling_mode == ScalingMode.NO_SCALING:
quantizers = [None] * n_quantizers
else:
quantizers = []
for _ in range(n_quantizers):
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
quantizers.append(
quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
......@@ -643,7 +822,9 @@ class QuantizerFactory:
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
@staticmethod
def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> QuantizerSet:
def _create_set(
scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
) -> QuantizerSet:
"""Create a set of quantizers for forward and backward passes.
Args:
......@@ -651,6 +832,7 @@ class QuantizerFactory:
fwd_dtype: Data type for forward pass
bwd_dtype: Data type for backward pass
is_2x2x: Whether to use 2x2x quantization
n_groups
**kwargs: Additional arguments for quantizer initialization
Returns:
......@@ -680,11 +862,13 @@ class QuantizerFactory:
else:
args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x)
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x)
q_kernel = QuantizerFactory.create(
1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel
1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel
)
q_dgrad = QuantizerFactory.create(
1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad
)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
@staticmethod
......@@ -694,6 +878,7 @@ class QuantizerFactory:
fwd_dtype: jnp.dtype = None,
bwd_dtype: jnp.dtype = None,
is_2x2x: bool = None,
n_groups: int = None,
**kwargs,
) -> tuple[Union[tuple[Quantizer], None]]:
"""Create one or more sets of quantizers.
......@@ -704,6 +889,7 @@ class QuantizerFactory:
fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
n_groups:
**kwargs: Additional arguments for quantizer initialization
Returns:
......@@ -717,7 +903,9 @@ class QuantizerFactory:
q_set = []
for _ in range(n_quantizer_sets):
q_set.append(
QuantizerFactory._create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs)
QuantizerFactory._create_set(
scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
)
)
return q_set[0] if len(q_set) == 1 else tuple(q_set)
......
......@@ -15,6 +15,7 @@ from enum import Enum
from typing import Tuple, Dict
from functools import reduce
import operator
import numpy as np
from jax.experimental.custom_partitioning import CompoundFactor
from jax.tree_util import register_pytree_node_class
......@@ -26,6 +27,11 @@ from transformer_engine_jax import JAXX_Scaling_Mode
__all__ = ["QuantizeShardyRules", "ScalingMode"]
def DIVUP(a, b):
    """Ceiling division: divide ``a`` by ``b`` and round the result up."""
    # Floor-divide by the negated divisor, then negate: integer ceil(a / b)
    # without any float round-trip.
    neg_quotient = a // -b
    return -neg_quotient
@dataclass
class QuantizeShardyRules:
"""Information necessary to shard scale tensors with Shardy.
......@@ -74,7 +80,26 @@ class ScalingModeMetadataImpl(ABC):
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
Returns:
The shape for scale tensors
"""
    @abstractmethod
    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for grouped scale tensors in this mode.

        Implementations return a flattened 1-D shape large enough to hold all
        groups' scale values packed back to back.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors
        """
......@@ -127,9 +152,29 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
Returns:
The shape for scale tensors - (1,)
"""
del data_shape, is_colwise
del is_colwise
if np.prod(data_shape) == 0:
return (0,)
return (1,)
def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
"""Get the shape for scale tensors in this mode.
Args:
data_shape: Original shape of the data tensor
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
"""
del data_shape, group_axis, is_colwise
assert isinstance(n_groups, int)
return (n_groups,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
......@@ -276,6 +321,77 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (*first_dim_scale_shape, *last_dim_scale_shape)
def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
"""Get the shape for grouped scale tensors in this mode.
If padded: The estimiated maximal possible shape for grouped scale tensor is return instead.
Args:
data_shape: Original shape of the data tensor
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
"""
assert isinstance(n_groups, int)
block_alignment = self._block_alignment if is_padded else (1, 1)
if is_colwise:
block_y, block_x = self._block_dims
alignment_y, alignment_x = block_alignment
else:
block_x, block_y = self._block_dims
alignment_x, alignment_y = block_alignment
if flatten_axis < 0:
flatten_axis = len(data_shape) + flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
assert data_shape[flatten_axis - 1] % block_x == 0, (
f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
f" {flatten_axis - 1}"
)
assert (
data_shape[-1] % block_y == 0
), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"
flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)
assert flattened_first_dim % block_x == 0, (
f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
f" {data_shape} - should be divisible by block_x {block_x}"
)
assert flattened_last_dim % block_y == 0, (
"Flattened last dim - mutiplication of"
f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
f" divisible by block_y {block_y}"
)
n_block_x = int(flattened_first_dim // block_x)
n_block_y = int(flattened_last_dim // block_y)
"""
Given the scale shape of [M, N], and G groups, and padding alignment (128, 4),
The worst scenario is when we have (G-1) groups with 1 rows and 1 group with (M-G+1) rows.
Then:
max_padded_rows = (G-1) * 128 + DIVUP(M-G+1, 128) * 128
max_padded_cols = DIVUP(N, 4) * 4
max_scale_size = max_padded_rows * max_padded_cols
"""
if is_padded:
n_block_x = (n_groups - 1) * alignment_x + DIVUP(
n_block_x - n_groups + 1, alignment_x
) * alignment_x
n_block_y = DIVUP(n_block_y, alignment_y) * alignment_y
return (n_block_x * n_block_y,)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
) -> QuantizeShardyRules:
......@@ -404,6 +520,61 @@ class ScalingMode(Enum):
"""
return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
def get_grouped_scale_shape_2x(
self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling.
Args:
data_shape: Shape of the data tensor
n_groups: Number of groups for grouped quantization
group_axis: The axis along which grouping is performed
is_padded: Whether to use padded shapes
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
rowwise_scale_shape = self.get_grouped_scale_shape(
data_shape,
n_groups,
group_axis,
is_colwise=False,
is_padded=is_padded,
flatten_axis=flatten_axis,
)
colwise_scale_shape = self.get_grouped_scale_shape(
data_shape,
n_groups,
group_axis,
is_colwise=True,
is_padded=is_padded,
flatten_axis=flatten_axis,
)
return (rowwise_scale_shape, colwise_scale_shape)
def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[Tuple[int]]:
"""Get shapes for both row-wise and column-wise scaling.
Args:
data_shape: Shape of the data tensor
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
return self._get_impl().get_grouped_scale_shape(
data_shape,
n_groups,
group_axis,
is_colwise=is_colwise,
is_padded=is_padded,
flatten_axis=flatten_axis,
)
def is_tensor_scaling(self) -> bool:
"""Check if this scaling mode is per-tensor scaling.
......
......@@ -18,7 +18,7 @@ from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode
from .dequantizer import Dequantizer
from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)
......@@ -27,6 +27,7 @@ __all__ = [
"ScaledTensor",
"ScaledTensor1x",
"ScaledTensor2x",
"GroupedScaledTensor1x",
"ScaledTensorFactory",
"with_sharding_constraint_by_logical_axes",
]
......@@ -122,7 +123,7 @@ class ScaledTensor1x(ScaledTensor):
_dq_func: Callable
is_colwise: bool
data_layout: str
flatten_axis: int = -1
flatten_axis: int
def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization.
......@@ -229,8 +230,12 @@ class ScaledTensor1x(ScaledTensor):
# axis_names were given for N layout, so needs to be transpose for T layout
if self.data_layout == "T":
assert self.flatten_axis > 0
flatten_axis = -self.flatten_axis
axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis])
assert len(logical_axis_names) == self.data.ndim
flatten_axis = self.data.ndim - self.flatten_axis
axis_names = (
*logical_axis_names[flatten_axis:],
*logical_axis_names[:flatten_axis],
)
else:
axis_names = logical_axis_names
......@@ -254,6 +259,116 @@ class ScaledTensor1x(ScaledTensor):
)
@register_pytree_node_class
@dataclass
class GroupedScaledTensor1x(ScaledTensor1x):
    """Grouped Quantizer for an array.

    This class extends ScaledTensor1x to support quantization of an array in grouped manner,
    where elements are grouped along a specified axis.

    Both ``data`` and ``scale_inv`` are stored flattened (1D); the logical layout is
    recovered from ``original_shape``, ``group_sizes`` and ``group_axis``.

    Attributes:
        group_sizes: Array containing the size of each group
        original_shape: The original shape of the tensor before grouping
        group_axis: The axis along which grouping is performed (default: 0)
    """

    group_sizes: jnp.ndarray
    original_shape: Tuple
    group_axis: int

    # NOTE: a hand-written __init__ is kept (dataclass does not overwrite an
    # __init__ defined in the class body) so that the grouping fields can be
    # assigned BEFORE super().__init__ runs — the parent dataclass __init__
    # invokes self.__post_init__(), which reads those fields.
    def __init__(
        self,
        data,
        scale_inv,
        group_sizes,
        scaling_mode,
        dq_dtype,
        _dq_func,
        is_colwise,
        data_layout,
        flatten_axis,
        original_shape,
        group_axis=0,
    ):
        self.group_sizes = group_sizes
        self.original_shape = original_shape
        self.group_axis = group_axis
        super().__init__(
            data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis
        )

    def __post_init__(self):
        """Validate flattened storage, normalize axes, and check scale_inv shape.

        Overrides (and does not call) ScaledTensor1x.__post_init__, since the
        grouped tensor keeps data/scale_inv flattened and validates against
        ``original_shape`` instead.
        """
        # Grouped tensors are always stored flattened.
        assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
        assert self.data.ndim == 1, "Only support flattened data"

        # Normalize flatten_axis to a positive index into original_shape.
        data_ndim = len(self.original_shape)
        flatten_axis = data_ndim + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
        assert (
            0 < flatten_axis < data_ndim
        ), f"flatten_axis {flatten_axis} is out of bounds for data.ndim = {data_ndim}"

        # Normalize group_axis to a positive index as well.
        group_axis = (
            len(self.original_shape) + self.group_axis if self.group_axis < 0 else self.group_axis
        )
        assert (
            0 <= group_axis < data_ndim
        ), f"group_axis {group_axis} is out of bounds for shape {self.original_shape}"

        if self.data_layout == "T":
            # For the transposed layout, original_shape (given for the "N" layout)
            # is permuted to describe the transposed data, and flatten_axis is
            # remapped accordingly.
            if self.original_shape[0] == self.group_sizes.size:
                # Presumably the leading dim is the group dim; it stays in front and
                # only the remaining dims are rotated — TODO(review): confirm this
                # heuristic cannot misfire when a non-group leading dim happens to
                # equal the number of groups.
                self.original_shape = (
                    self.original_shape[0],
                    *self.original_shape[flatten_axis:],
                    *self.original_shape[1:flatten_axis],
                )
                flatten_axis = len(self.original_shape) - flatten_axis + 1
            else:
                # Group dim is not leading: rotate all dims and move group_axis to
                # the position where the grouped dims now start.
                self.original_shape = (
                    *self.original_shape[flatten_axis:],
                    *self.original_shape[:flatten_axis],
                )
                self.group_axis = flatten_axis
                flatten_axis = len(self.original_shape) - flatten_axis
            self.flatten_axis = flatten_axis

        # Validate against the padded scale shape the scaling mode expects for
        # this (possibly permuted) original_shape.
        expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
            self.original_shape,
            self.group_sizes.size,
            self.group_axis,
            self.is_colwise,
            is_padded=True,
            flatten_axis=flatten_axis,
        )
        assert self.scale_inv.shape == expected_scale_shape, (
            f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
            f" scale_inv, got {self.scale_inv.shape}"
        )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        # Arrays (traced leaves) go into children; static metadata into aux_data.
        children = (self.data, self.scale_inv, self.group_sizes)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.original_shape,
            self.group_axis,
        )
        return (children, aux_data)

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Sharding constraints are not supported for grouped tensors yet."""
        raise NotImplementedError
@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
......@@ -342,6 +457,9 @@ class ScaledTensorFactory:
is_colwise=False,
data_layout="N",
flatten_axis=-1,
group_sizes=None,
original_shape=None,
group_axis=0,
):
"""Creates a single-scale quantized tensor.
......@@ -353,13 +471,41 @@ class ScaledTensorFactory:
is_colwise: Whether to use column-wise quantization (default: False)
data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor
group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns:
A ScaledTensor1x instance
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
"""
dq_func = Dequantizer.funcs.get(scaling_mode)
dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
if group_sizes is not None:
assert (
original_shape is not None
), "original_shape is not given for GroupedScaledTensor1x"
return GroupedScaledTensor1x(
data=data,
scale_inv=scale_inv,
scaling_mode=scaling_mode,
dq_dtype=dq_dtype,
_dq_func=dequantizer.grouped_dequantize,
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
)
return ScaledTensor1x(
data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis
data,
scale_inv,
scaling_mode,
dq_dtype,
dequantizer.dequantize,
is_colwise,
data_layout,
flatten_axis,
)
@staticmethod
......@@ -372,6 +518,9 @@ class ScaledTensorFactory:
dq_dtype=jnp.bfloat16,
data_layout="NN",
flatten_axis=-1,
group_sizes=None,
original_shape=None,
group_axis=0,
):
"""Creates a double-scale quantized tensor.
......@@ -384,30 +533,37 @@ class ScaledTensorFactory:
dq_dtype: The data type for dequantized values (default: bfloat16)
data_layout: The data_layout specification (default: "NN")
flatten_axis: The quantization axis for the tensor
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns:
A ScaledTensor2x instance
"""
dq_func = Dequantizer.funcs.get(scaling_mode)
rowwise_tensor = ScaledTensor1x(
assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
rowwise_tensor = ScaledTensorFactory.create_1x(
data,
scale_inv,
scaling_mode,
dq_dtype,
dq_func,
is_colwise=False,
data_layout=data_layout[0],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
)
colwise_tensor = ScaledTensor1x(
colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
scaling_mode,
dq_dtype,
dq_func,
is_colwise=True,
data_layout=data_layout[1],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
......@@ -422,6 +578,9 @@ class ScaledTensorFactory:
data_layout: str = "NN",
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
flatten_axis: int = -1,
group_sizes: jnp.ndarray = None,
original_shape: Tuple[int] = None,
group_axis: int = 0,
):
"""Creates a scaled tensor based on the quantization axis.
......@@ -434,6 +593,10 @@ class ScaledTensorFactory:
dq_dtype: The data type for dequantized values (default: bfloat16)
data_layout: The data_layout specification (default: "NN")
q_layout: The quantization axis (default: ROWWISE)
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
......@@ -448,9 +611,26 @@ class ScaledTensorFactory:
dq_dtype,
data_layout=data_layout,
flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
)
is_colwise = q_layout == QuantizeLayout.COLWISE
if is_colwise:
return ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
data_layout=data_layout[0],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
)
return ScaledTensorFactory.create_1x(
data,
scale_inv,
......@@ -459,6 +639,9 @@ class ScaledTensorFactory:
is_colwise=is_colwise,
data_layout=data_layout[0],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
)
......@@ -472,6 +655,9 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
Returns:
The tensor with applied sharding constraints
"""
if isinstance(x, GroupedScaledTensor1x):
raise NotImplementedError
if isinstance(x, ScaledTensor):
return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment