Unverified Commit ecaf3e21 authored by Phuong Nguyen, committed by GitHub

Fixes for JIT-able grouped_gemm (#1872)



* fixes for jittable grouped_quantize

* fixes for jittable grouped_gemm

* fix contracting_dim for wgrad gemm

* exclude jitted grouped_gemm from the unit test as it does not work with cudaGraph

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 227961e6
......@@ -4,7 +4,6 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, value_and_grad
from functools import reduce
......@@ -13,7 +12,6 @@ import operator
from utils import (
assert_allclose,
assert_tree_like_allclose,
pytest_parametrize_wrapper,
)
from transformer_engine.jax.layernorm import layernorm
......@@ -682,6 +680,10 @@ class TestGroupedQuantize:
n_groups=n_groups,
)
# grouped_quantize does not work with cudaGraph yet, so jitting this call would break.
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then jit the call below.
scaled_tensor = tex.grouped_quantize(
x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
)
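For local testing with cudaGraph disabled, a minimal sketch (not part of the diff) of the jitted variant the comment refers to could look like the following; the wrapper name _jitted_grouped_quantize is hypothetical, and flatten_axis / grouped_quantizer are captured from the enclosing test scope so that jit only traces the array arguments:
@jax.jit
def _jitted_grouped_quantize(x, group_sizes):
    # flatten_axis and grouped_quantizer come from the enclosing test scope,
    # so they are baked in at trace time
    return tex.grouped_quantize(
        x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
    )
scaled_tensor = _jitted_grouped_quantize(x, group_sizes)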
......@@ -1281,6 +1283,16 @@ class TestGroupedDense:
dtype, input_shape, layout
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# grouped_gemm does not work with cudaGraph yet, so jitting this call would break.
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then use the commented-out jitted call below.
# jitting grouped_gemm
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
# lhs, rhs, group_sizes, contracting_dims,
# )
prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
......@@ -1312,6 +1324,12 @@ class TestGroupedDense:
out_dtype, input_shape, layout
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# jitting grouped_gemm
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
# lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
# )
prim_out = tex.grouped_gemm(
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
)
......@@ -1346,6 +1364,9 @@ class TestGroupedDense:
)
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
# jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
# static_argnums=(4,))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
......@@ -1386,6 +1407,10 @@ class TestGroupedDense:
n_groups=group_sizes.size,
)
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
# jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
# static_argnums=(4,))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
......
......@@ -9,7 +9,7 @@ import operator
import math
import jax
import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability
from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams
from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
......@@ -30,7 +30,7 @@ from ..quantize import (
__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"]
num_cublas_streams = 4
num_cublas_streams = get_num_compute_streams()
def get_cublas_workspace_size_bytes() -> None:
......@@ -103,10 +103,15 @@ class GroupedGemmPrimitive(BasePrimitive):
"""
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
del lhs_scale_inv_aval, rhs_scale_inv_aval
# TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size
# JAX buffer pointers are only guaranteed to be 128-byte aligned.
# 255 extra bytes are added to the workspace size so the workspace ptr can be rounded up to a 256-byte boundary.
workspace_size += 255
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
# TODO(phuong): allocate separate tmp buffers for the swizzled scales to avoid the 256-byte workspace-pointer alignment workaround
out_shape = (M, N)
if is_grouped_dense_wgrad:
out_shape = (group_sizes_aval.size, M, N)
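For reference, a small stand-alone illustration (not from the diff) of why the 255-byte padding added above is sufficient for the 256-byte round-up performed on the C++ side:
def round_up_to_256(addr: int) -> int:
    # mirror of the C++ pointer fix-up: add 255, then clear the low 8 bits
    return (addr + 255) & ~255

assert round_up_to_256(0x1000) == 0x1000  # already 256-byte aligned: unchanged
assert round_up_to_256(0x1080) == 0x1100  # 128-byte aligned: bumped forward by 128 bytes
assert max(round_up_to_256(a) - a for a in range(0x1000, 0x1100)) == 255  # worst-case shift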
......
......@@ -839,9 +839,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
scale_inv,
colwise_scale_inv,
updated_amax,
_dbias,
_wkspace,
) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax
@staticmethod
......@@ -975,7 +973,9 @@ def grouped_quantize(
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
segment_ids = jnp.repeat(jnp.arange(n_groups), group_sizes)
segment_ids = jnp.repeat(
jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
)
grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
for i in range(n_groups):
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
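An illustrative, self-contained sketch (names and sizes are made up) of why total_repeat_length is needed once grouped_quantize is jitted: with traced group_sizes, jnp.repeat cannot infer the output length, so the statically known row count must be supplied explicitly:
import jax
import jax.numpy as jnp

n_groups, total_rows = 4, 16

@jax.jit
def make_segment_ids(group_sizes):
    # total_rows is a Python int captured at trace time, so the output shape
    # (total_rows,) is static even though group_sizes itself is traced
    return jnp.repeat(jnp.arange(n_groups), group_sizes, total_repeat_length=total_rows)

segment_ids = make_segment_ids(jnp.array([3, 5, 2, 6]))
assert segment_ids.shape == (total_rows,)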
......@@ -1048,7 +1048,9 @@ def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
assert grad.ndim == 2, "Input grad must be a 2D tensor."
assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."
segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes)
segment_ids = jnp.repeat(
jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
)
grad_fp32 = grad.astype(jnp.float32)
dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
dbias = dbias_fp32.astype(grad.dtype)
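The same pattern with jax.ops.segment_sum, as a concrete shape check (made-up shapes, not from the diff) of the per-group bias gradient computed above:
import jax
import jax.numpy as jnp

grad = jnp.ones((6, 3), dtype=jnp.bfloat16)   # 6 rows, hidden size 3
group_sizes = jnp.array([2, 1, 3])            # rows per group, sums to 6

segment_ids = jnp.repeat(
    jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
)
dbias = jax.ops.segment_sum(
    grad.astype(jnp.float32), segment_ids, num_segments=group_sizes.size
).astype(grad.dtype)
assert dbias.shape == (group_sizes.size, grad.shape[1])  # one bias-grad row per group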
......
......@@ -30,6 +30,7 @@
#include "extensions/misc.h"
#include "extensions/utils.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/multi_stream.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
......
......@@ -10,7 +10,6 @@
#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "transformer_engine/multi_stream.h"
#include "xla/ffi/api/c_api.h"
#define MXFP8_BLOCK_SIZE 32
......@@ -58,14 +57,12 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
auto workspace_total_size = product(workspace->dimensions());
auto lhs_sinv_size = product(lhs_sinv.dimensions());
auto rhs_sinv_size = product(rhs_sinv.dimensions());
auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams;
auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
// Round the workspace address up to the next 256-byte boundary: add 255, then clear the lower 8 bits
auto workspace_ptr =
reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(workspace->untyped_data()) + 255) &
~static_cast<uintptr_t>(255));
auto workspace_total_size = product(workspace->dimensions()) - 255;
auto workspace_size = workspace_total_size / num_streams;
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
......
......@@ -69,6 +69,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_cudnn_version", &GetCudnnRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_num_compute_streams", &nvte_get_num_compute_streams);
m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes);
m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes);
......
......@@ -389,7 +389,7 @@ def _grouped_dense_bwd_rule(
# after the extra transpose for FP8 in grouped_gemm
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
g_contracting_dim = (0,)
x_contracting_dim = (1,)
x_contracting_dim = (0,)
wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
wgrad_x_T = ctx_x
wgrad_grad = casted_grad.get_colwise_tensor()
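As a sanity check of the contraction that the corrected x_contracting_dim = (0,) expresses (plain jnp only, ignoring the FP8 transpose bookkeeping mentioned in the comment above): for out = x @ w with x of shape (M, K) and grad of shape (M, N), the weight gradient contracts x and grad over the shared M axis, i.e. axis 0 of both operands:
import jax
import jax.numpy as jnp

M, K, N = 8, 4, 5
x = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K)
grad = jnp.ones((M, N), dtype=jnp.float32)

# contract axis 0 of x with axis 0 of grad -> weight gradient of shape (K, N)
wgrad = jax.lax.dot_general(x, grad, dimension_numbers=(((0,), (0,)), ((), ())))
assert wgrad.shape == (K, N)
assert jnp.allclose(wgrad, x.T @ grad)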
......
......@@ -131,22 +131,16 @@ class ScaledTensor1x(ScaledTensor):
Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary.
"""
flatten_axis = (
len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
)
assert self.flatten_axis > 0
assert (
0 < flatten_axis < len(self.data.shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}"
if self.data_layout == "T":
flatten_axis = self.data.ndim - flatten_axis
self.flatten_axis = flatten_axis
0 < self.flatten_axis < len(self.data.shape)
), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"
expected_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis
self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis
)
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis
self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis
)
if self.scale_inv.shape != expected_scale_shape:
assert self.scale_inv.shape == expected_unpadded_scale_shape, (
......@@ -291,6 +285,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
original_shape,
group_axis=0,
):
self.flatten_axis = flatten_axis
self.group_sizes = group_sizes
self.original_shape = original_shape
self.group_axis = group_axis
......@@ -301,44 +296,25 @@ class GroupedScaledTensor1x(ScaledTensor1x):
def __post_init__(self):
assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
assert self.data.ndim == 1, "Only support flattened data"
assert self.group_axis >= 0
assert self.flatten_axis > 0
data_ndim = len(self.original_shape)
flatten_axis = data_ndim + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
assert (
0 < flatten_axis < data_ndim
), f"flatten_axis {flatten_axis} is out of bounds for data.ndim = {data_ndim}"
0 < self.flatten_axis < data_ndim
), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}"
group_axis = (
len(self.original_shape) + self.group_axis if self.group_axis < 0 else self.group_axis
)
assert (
0 <= group_axis < data_ndim
), f"group_axis {group_axis} is out of bounds for shape {self.original_shape}"
if self.data_layout == "T":
if self.original_shape[0] == self.group_sizes.size:
self.original_shape = (
self.original_shape[0],
*self.original_shape[flatten_axis:],
*self.original_shape[1:flatten_axis],
)
flatten_axis = len(self.original_shape) - flatten_axis + 1
else:
self.original_shape = (
*self.original_shape[flatten_axis:],
*self.original_shape[:flatten_axis],
)
self.group_axis = flatten_axis
flatten_axis = len(self.original_shape) - flatten_axis
0 <= self.group_axis < data_ndim
), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}"
self.flatten_axis = flatten_axis
expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
self.original_shape,
self.group_sizes.size,
self.group_axis,
self.is_colwise,
is_padded=True,
flatten_axis=flatten_axis,
flatten_axis=self.flatten_axis,
)
assert self.scale_inv.shape == expected_scale_shape, (
......@@ -479,10 +455,31 @@ class ScaledTensorFactory:
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
"""
dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
if group_sizes is not None:
flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
original_shape is not None
), "original_shape is not given for GroupedScaledTensor1x"
# Handling attrs of transposed tensors
group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis
if data_layout == "T":
if original_shape[0] == group_sizes.size:
original_shape = (
original_shape[0],
*original_shape[flatten_axis:],
*original_shape[1:flatten_axis],
)
flatten_axis = len(original_shape) - flatten_axis + 1
else:
original_shape = (
*original_shape[flatten_axis:],
*original_shape[:flatten_axis],
)
group_axis = flatten_axis
flatten_axis = len(original_shape) - flatten_axis
return GroupedScaledTensor1x(
data=data,
scale_inv=scale_inv,
......@@ -497,6 +494,11 @@ class ScaledTensorFactory:
group_axis=group_axis,
)
# Handling attrs of transposed tensors
flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
if data_layout == "T":
flatten_axis = data.ndim - flatten_axis
return ScaledTensor1x(
data,
scale_inv,
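A small stand-alone illustration (made-up values) of the flatten_axis bookkeeping the factory now performs for transposed non-grouped tensors: the axis is first normalized to a positive index, then mirrored for the "T" layout because the dimensions before and after the flatten point swap places:
ndim = 3                 # e.g. data of shape (B, S, H)
flatten_axis = -1        # caller-supplied negative axis
data_layout = "T"

flatten_axis = ndim + flatten_axis if flatten_axis < 0 else flatten_axis  # -> 2
if data_layout == "T":
    flatten_axis = ndim - flatten_axis                                    # -> 1
assert flatten_axis == 1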
......