"transformer_engine/pytorch/csrc/common.cpp" did not exist on "a5ba71f3f7379acad9c2292a289aa58ab8a489a8"
Unverified Commit 5f2b8310 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Scale swizzling via JAX transpose op (#2163)



* add swizzle in jax
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* added outer_impl
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* clean up FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent a26a7f1f
......@@ -134,6 +134,13 @@ class BasePrimitive(metaclass=ABCMeta):
"""
return NotImplemented
@classmethod
def outer_impl(cls, *args, **kwargs):
"""
to describe implementation for outer primitive
"""
return cls.impl(*args, **kwargs)
@staticmethod
@abstractmethod
def batcher():
......@@ -196,7 +203,7 @@ def register_primitive(cls):
outer_p = core.Primitive(name_of_wrapper_p())
dispatch.prim_requires_devices_during_lowering.add(outer_p)
outer_p.multiple_results = cls.multiple_results
outer_p.def_impl(cls.impl)
outer_p.def_impl(cls.outer_impl)
outer_p.def_abstract_eval(cls.outer_abstract)
batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
......
......@@ -152,6 +152,21 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
return lhs_q, rhs_q
@partial(jax.jit, static_argnums=(1, 2))
def swizzled_scale(scale_inv, flatten_axis, is_colwise):
"Swizzle scale_inv via JAX transpose ops"
original_shape = scale_inv.shape
shape_2d = (math.prod(original_shape[:flatten_axis]), math.prod(original_shape[flatten_axis:]))
if is_colwise:
scale_inv = jnp.transpose(scale_inv.reshape(shape_2d))
cols, rows = shape_2d
else:
rows, cols = shape_2d
reshape = scale_inv.reshape(rows // 128, 4, 32, cols // 4, 4)
swizzled = jnp.transpose(reshape, (0, 3, 2, 1, 4))
return swizzled.reshape(original_shape)
class GemmPrimitive(BasePrimitive):
"""
Primitive for cuBLAS GEMM
......@@ -286,28 +301,18 @@ class GemmPrimitive(BasePrimitive):
)
pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
# Need extra workspace for swizzled scale factors
lhs_swizzle_size = 0
rhs_swizzle_size = 0
swizzle_dtype = jnp.uint8
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_swizzle_size = lhs_scale_inv.size
rhs_swizzle_size = rhs_scale_inv.size
lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype)
rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype)
# Declare cuBLAS workspace
# cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
# necessarily 256 bytes aligned, we add some padding to ensure alignment.
workspace_size = get_cublas_workspace_size_bytes() + 256
workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace
return output, bias_grad, pre_gelu_out, workspace
@staticmethod
def outer_abstract(*args, **kwargs):
outputs = GemmPrimitive.abstract(*args, **kwargs)
return outputs[:-3] # discard workspace arrays
return outputs[:-1] # discard workspace array
@staticmethod
def lowering(
......@@ -374,24 +379,22 @@ class GemmPrimitive(BasePrimitive):
grad,
use_split_accumulator,
):
if scaling_mode.is_1d_block_scaling():
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
)
lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims)
rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1
lhs_scale_inv = apply_padding_to_scale_inv(
lhs_scale_inv,
scaling_mode,
lhs.shape,
is_colwise=lhs_transposed,
flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis
)
rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv,
scaling_mode,
rhs.shape,
is_colwise=not rhs_transposed,
flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
)
lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
outputs = GemmPrimitive.inner_primitive.bind(
lhs,
......@@ -408,7 +411,39 @@ class GemmPrimitive(BasePrimitive):
grad=grad,
use_split_accumulator=use_split_accumulator,
)
return outputs[:-3] # discard workspace arrays
return outputs[:-1] # discard workspace array
@staticmethod
def outer_impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
return GemmPrimitive.impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
)
@staticmethod
def batcher(
......
......@@ -28,8 +28,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
}
std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv,
JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode,
size_t axis_boundary, bool rowwise) {
// Set tensor data with collapsed 2D shape
auto buffer_dims = buffer.dimensions();
std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
......@@ -61,40 +61,6 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
// Swizzle scaling factors for MXFP8
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Get the swizzle buffer
NVTE_CHECK(swizzled_scale_inv->element_count() > 0,
"Missing swizzled inverse scale buffer in the JAX primitive.");
auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
auto swizzled_scale_inv_dtype =
convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type());
NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
// Create tensor to hold swizzled scale factor
TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
}
}
return std::make_tuple(std::move(input), input_shape);
......@@ -103,21 +69,19 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
// Operands (this includes swizzling MXFP8 scaling factors)
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
(is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode,
lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
rhs_axis_boundary, make_rhs_rowwise);
// Output tensor
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
......@@ -188,8 +152,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // lhs_swizzled
.Ret<Buffer_Type>() // rhs_swizzled
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment