Unverified Commit ecaf3e21 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

Fixes for JIT-able grouped_gemm (#1872)



* fixes for jittable grouped_quantize

* fixes for jittable grouped_gemm

* fix contracting_dim for wgrad gemm

* exclude jitted grouped_gemm from the unit test as it does not work with cudaGraph

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 227961e6
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
import pytest import pytest
from jax import jit, value_and_grad from jax import jit, value_and_grad
from functools import reduce from functools import reduce
...@@ -13,7 +12,6 @@ import operator ...@@ -13,7 +12,6 @@ import operator
from utils import ( from utils import (
assert_allclose, assert_allclose,
assert_tree_like_allclose,
pytest_parametrize_wrapper, pytest_parametrize_wrapper,
) )
from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm import layernorm
...@@ -682,6 +680,10 @@ class TestGroupedQuantize: ...@@ -682,6 +680,10 @@ class TestGroupedQuantize:
n_groups=n_groups, n_groups=n_groups,
) )
# grouped_quantize does not work with cudaGraph yet, so jitting will break
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then use the following jitted function
scaled_tensor = tex.grouped_quantize( scaled_tensor = tex.grouped_quantize(
x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
) )
...@@ -1281,6 +1283,16 @@ class TestGroupedDense: ...@@ -1281,6 +1283,16 @@ class TestGroupedDense:
dtype, input_shape, layout dtype, input_shape, layout
) )
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# grouped_gemm does not work with cudaGraph yet, so jitting will break
# To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to
# disable cudaGraph, then use the following jitted function
# jitting grouped_gemm
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
# lhs, rhs, group_sizes, contracting_dims,
# )
prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
...@@ -1312,6 +1324,12 @@ class TestGroupedDense: ...@@ -1312,6 +1324,12 @@ class TestGroupedDense:
out_dtype, input_shape, layout out_dtype, input_shape, layout
) )
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# jitting grouped_gemm
# prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))(
# lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
# )
prim_out = tex.grouped_gemm( prim_out = tex.grouped_gemm(
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
) )
...@@ -1346,6 +1364,9 @@ class TestGroupedDense: ...@@ -1346,6 +1364,9 @@ class TestGroupedDense:
) )
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
# jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
# static_argnums=(4,))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
...@@ -1386,6 +1407,10 @@ class TestGroupedDense: ...@@ -1386,6 +1407,10 @@ class TestGroupedDense:
n_groups=group_sizes.size, n_groups=group_sizes.size,
) )
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
# jitting the grouped_dense
# value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)),
# static_argnums=(4,))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
......
...@@ -9,7 +9,7 @@ import operator ...@@ -9,7 +9,7 @@ import operator
import math import math
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize from .quantization import grouped_quantize
...@@ -30,7 +30,7 @@ from ..quantize import ( ...@@ -30,7 +30,7 @@ from ..quantize import (
__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] __all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"]
num_cublas_streams = 4 num_cublas_streams = get_num_compute_streams()
def get_cublas_workspace_size_bytes() -> None: def get_cublas_workspace_size_bytes() -> None:
...@@ -103,10 +103,15 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -103,10 +103,15 @@ class GroupedGemmPrimitive(BasePrimitive):
""" """
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
del lhs_scale_inv_aval, rhs_scale_inv_aval
# TODO(Phuong): move some shape checks from Cpp to here # TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size # JAX buffer pointers are 128-aligned
# 255 is added to the workspace size to ensure workspace ptr is 256-aligned
workspace_size += 255
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
# TODO(phuong): We should make separate tmp buffers for swizzled scales to avoid unaligned-by-256 workspace ptr issue
out_shape = (M, N) out_shape = (M, N)
if is_grouped_dense_wgrad: if is_grouped_dense_wgrad:
out_shape = (group_sizes_aval.size, M, N) out_shape = (group_sizes_aval.size, M, N)
......
...@@ -839,9 +839,7 @@ class GroupedQuantizePrimitive(BasePrimitive): ...@@ -839,9 +839,7 @@ class GroupedQuantizePrimitive(BasePrimitive):
scale_inv, scale_inv,
colwise_scale_inv, colwise_scale_inv,
updated_amax, updated_amax,
_dbias, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
_wkspace,
) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax
@staticmethod @staticmethod
...@@ -975,7 +973,9 @@ def grouped_quantize( ...@@ -975,7 +973,9 @@ def grouped_quantize(
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
segment_ids = jnp.repeat(jnp.arange(n_groups), group_sizes) segment_ids = jnp.repeat(
jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
)
grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
for i in range(n_groups): for i in range(n_groups):
tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype) tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
...@@ -1048,7 +1048,9 @@ def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray: ...@@ -1048,7 +1048,9 @@ def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
assert grad.ndim == 2, "Input grad must be a 2D tensor." assert grad.ndim == 2, "Input grad must be a 2D tensor."
assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor." assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."
segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes) segment_ids = jnp.repeat(
jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
)
grad_fp32 = grad.astype(jnp.float32) grad_fp32 = grad.astype(jnp.float32)
dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0]) dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
dbias = dbias_fp32.astype(grad.dtype) dbias = dbias_fp32.astype(grad.dtype)
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "extensions/misc.h" #include "extensions/misc.h"
#include "extensions/utils.h" #include "extensions/utils.h"
#include "transformer_engine/activation.h" #include "transformer_engine/activation.h"
#include "transformer_engine/multi_stream.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
......
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
#include "../extensions.h" #include "../extensions.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#include "common/util/system.h" #include "common/util/system.h"
#include "transformer_engine/multi_stream.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
#define MXFP8_BLOCK_SIZE 32 #define MXFP8_BLOCK_SIZE 32
...@@ -58,14 +57,12 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type ...@@ -58,14 +57,12 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// Outputs // Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data()); auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data()); // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned
auto workspace_total_size = product(workspace->dimensions()); auto workspace_ptr =
reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(workspace->untyped_data()) + 255) &
auto lhs_sinv_size = product(lhs_sinv.dimensions()); ~static_cast<uintptr_t>(255));
auto rhs_sinv_size = product(rhs_sinv.dimensions()); auto workspace_total_size = product(workspace->dimensions()) - 255;
auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams; auto workspace_size = workspace_total_size / num_streams;
auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
......
...@@ -69,6 +69,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -69,6 +69,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cuda_version", &GetCudaRuntimeVersion);
m.def("get_cudnn_version", &GetCudnnRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion);
m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_device_compute_capability", &GetDeviceComputeCapability);
m.def("get_num_compute_streams", &nvte_get_num_compute_streams);
m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_cublasLt_version", &cublasLtGetVersion);
m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes); m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes);
m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes); m.def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes);
......
...@@ -389,7 +389,7 @@ def _grouped_dense_bwd_rule( ...@@ -389,7 +389,7 @@ def _grouped_dense_bwd_rule(
# after the extra transpose for FP8 in grouped_gemm # after the extra transpose for FP8 in grouped_gemm
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()? # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
g_contracting_dim = (0,) g_contracting_dim = (0,)
x_contracting_dim = (1,) x_contracting_dim = (0,)
wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim) wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
wgrad_x_T = ctx_x wgrad_x_T = ctx_x
wgrad_grad = casted_grad.get_colwise_tensor() wgrad_grad = casted_grad.get_colwise_tensor()
......
...@@ -131,22 +131,16 @@ class ScaledTensor1x(ScaledTensor): ...@@ -131,22 +131,16 @@ class ScaledTensor1x(ScaledTensor):
Ensures the scale_inv shape matches the expected shape based on the scaling mode Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary. and quantization direction. Pads the scale_inv if necessary.
""" """
flatten_axis = ( assert self.flatten_axis > 0
len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
)
assert ( assert (
0 < flatten_axis < len(self.data.shape) 0 < self.flatten_axis < len(self.data.shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}" ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"
if self.data_layout == "T":
flatten_axis = self.data.ndim - flatten_axis
self.flatten_axis = flatten_axis
expected_scale_shape = self.scaling_mode.get_scale_shape( expected_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis
) )
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis
) )
if self.scale_inv.shape != expected_scale_shape: if self.scale_inv.shape != expected_scale_shape:
assert self.scale_inv.shape == expected_unpadded_scale_shape, ( assert self.scale_inv.shape == expected_unpadded_scale_shape, (
...@@ -291,6 +285,7 @@ class GroupedScaledTensor1x(ScaledTensor1x): ...@@ -291,6 +285,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
original_shape, original_shape,
group_axis=0, group_axis=0,
): ):
self.flatten_axis = flatten_axis
self.group_sizes = group_sizes self.group_sizes = group_sizes
self.original_shape = original_shape self.original_shape = original_shape
self.group_axis = group_axis self.group_axis = group_axis
...@@ -301,44 +296,25 @@ class GroupedScaledTensor1x(ScaledTensor1x): ...@@ -301,44 +296,25 @@ class GroupedScaledTensor1x(ScaledTensor1x):
def __post_init__(self): def __post_init__(self):
assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
assert self.data.ndim == 1, "Only support flattened data" assert self.data.ndim == 1, "Only support flattened data"
assert self.group_axis >= 0
assert self.flatten_axis > 0
data_ndim = len(self.original_shape) data_ndim = len(self.original_shape)
flatten_axis = data_ndim + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
assert ( assert (
0 < flatten_axis < data_ndim 0 < self.flatten_axis < data_ndim
), f"flatten_axis {flatten_axis} is out of bounds for data.ndim = {data_ndim}" ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}"
group_axis = (
len(self.original_shape) + self.group_axis if self.group_axis < 0 else self.group_axis
)
assert ( assert (
0 <= group_axis < data_ndim 0 <= self.group_axis < data_ndim
), f"group_axis {group_axis} is out of bounds for shape {self.original_shape}" ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}"
if self.data_layout == "T":
if self.original_shape[0] == self.group_sizes.size:
self.original_shape = (
self.original_shape[0],
*self.original_shape[flatten_axis:],
*self.original_shape[1:flatten_axis],
)
flatten_axis = len(self.original_shape) - flatten_axis + 1
else:
self.original_shape = (
*self.original_shape[flatten_axis:],
*self.original_shape[:flatten_axis],
)
self.group_axis = flatten_axis
flatten_axis = len(self.original_shape) - flatten_axis
self.flatten_axis = flatten_axis
expected_scale_shape = self.scaling_mode.get_grouped_scale_shape( expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
self.original_shape, self.original_shape,
self.group_sizes.size, self.group_sizes.size,
self.group_axis, self.group_axis,
self.is_colwise, self.is_colwise,
is_padded=True, is_padded=True,
flatten_axis=flatten_axis, flatten_axis=self.flatten_axis,
) )
assert self.scale_inv.shape == expected_scale_shape, ( assert self.scale_inv.shape == expected_scale_shape, (
...@@ -479,10 +455,31 @@ class ScaledTensorFactory: ...@@ -479,10 +455,31 @@ class ScaledTensorFactory:
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
""" """
dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
if group_sizes is not None: if group_sizes is not None:
flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert ( assert (
original_shape is not None original_shape is not None
), "original_shape is not given for GroupedScaledTensor1x" ), "original_shape is not given for GroupedScaledTensor1x"
# Handling attrs of transposed tensors
group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis
if data_layout == "T":
if original_shape[0] == group_sizes.size:
original_shape = (
original_shape[0],
*original_shape[flatten_axis:],
*original_shape[1:flatten_axis],
)
flatten_axis = len(original_shape) - flatten_axis + 1
else:
original_shape = (
*original_shape[flatten_axis:],
*original_shape[:flatten_axis],
)
group_axis = flatten_axis
flatten_axis = len(original_shape) - flatten_axis
return GroupedScaledTensor1x( return GroupedScaledTensor1x(
data=data, data=data,
scale_inv=scale_inv, scale_inv=scale_inv,
...@@ -497,6 +494,11 @@ class ScaledTensorFactory: ...@@ -497,6 +494,11 @@ class ScaledTensorFactory:
group_axis=group_axis, group_axis=group_axis,
) )
# Handling attrs of transposed tensors
flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
if data_layout == "T":
flatten_axis = data.ndim - flatten_axis
return ScaledTensor1x( return ScaledTensor1x(
data, data,
scale_inv, scale_inv,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment