Unverified commit 5f2b8310, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Scale swizzling via JAX transpose op (#2163)



* add swizzle in jax
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added outer_impl
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* clean up FFI
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent a26a7f1f
...@@ -134,6 +134,13 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -134,6 +134,13 @@ class BasePrimitive(metaclass=ABCMeta):
""" """
return NotImplemented return NotImplemented
@classmethod
def outer_impl(cls, *args, **kwargs):
    """
    Implementation for the outer primitive.

    By default this forwards directly to ``cls.impl``; subclasses may
    override it when the outer primitive needs behavior that differs
    from the inner implementation (e.g. extra pre-processing of
    operands before binding the inner primitive).
    """
    return cls.impl(*args, **kwargs)
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def batcher(): def batcher():
...@@ -196,7 +203,7 @@ def register_primitive(cls): ...@@ -196,7 +203,7 @@ def register_primitive(cls):
outer_p = core.Primitive(name_of_wrapper_p()) outer_p = core.Primitive(name_of_wrapper_p())
dispatch.prim_requires_devices_during_lowering.add(outer_p) dispatch.prim_requires_devices_during_lowering.add(outer_p)
outer_p.multiple_results = cls.multiple_results outer_p.multiple_results = cls.multiple_results
outer_p.def_impl(cls.impl) outer_p.def_impl(cls.outer_impl)
outer_p.def_abstract_eval(cls.outer_abstract) outer_p.def_abstract_eval(cls.outer_abstract)
batching.primitive_batchers[outer_p] = cls.batcher batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
......
...@@ -152,6 +152,21 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ...@@ -152,6 +152,21 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
return lhs_q, rhs_q return lhs_q, rhs_q
@partial(jax.jit, static_argnums=(1, 2))
def swizzled_scale(scale_inv, flatten_axis, is_colwise):
    """Swizzle ``scale_inv`` into the blocked layout via JAX transpose ops.

    The input is viewed as a 2D matrix split at ``flatten_axis``; column-wise
    scales are first transposed into row-major order. The 2D view is then
    tiled as (rows/128, 4, 32, cols/4, 4) and its middle axes permuted,
    producing the swizzled layout while keeping the original array shape.
    """
    out_shape = scale_inv.shape
    lead = math.prod(out_shape[:flatten_axis])
    trail = math.prod(out_shape[flatten_axis:])
    if is_colwise:
        # Column-wise scales: transpose the 2D view so rows/cols swap roles.
        scale_inv = scale_inv.reshape(lead, trail).T
        n_rows, n_cols = trail, lead
    else:
        n_rows, n_cols = lead, trail
    # Tile into 128x4 blocks (rows/128, 4, 32, cols/4, 4) and permute the
    # inner axes to obtain the swizzled ordering.
    tiled = scale_inv.reshape(n_rows // 128, 4, 32, n_cols // 4, 4)
    return jnp.transpose(tiled, (0, 3, 2, 1, 4)).reshape(out_shape)
class GemmPrimitive(BasePrimitive): class GemmPrimitive(BasePrimitive):
""" """
Primitive for cuBLAS GEMM Primitive for cuBLAS GEMM
...@@ -286,28 +301,18 @@ class GemmPrimitive(BasePrimitive): ...@@ -286,28 +301,18 @@ class GemmPrimitive(BasePrimitive):
) )
pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
# Need extra workspace for swizzled scale factors
lhs_swizzle_size = 0
rhs_swizzle_size = 0
swizzle_dtype = jnp.uint8
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_swizzle_size = lhs_scale_inv.size
rhs_swizzle_size = rhs_scale_inv.size
lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype)
rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype)
# Declare cuBLAS workspace # Declare cuBLAS workspace
# cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
# necessarily 256 bytes aligned, we add some padding to ensure alignment. # necessarily 256 bytes aligned, we add some padding to ensure alignment.
workspace_size = get_cublas_workspace_size_bytes() + 256 workspace_size = get_cublas_workspace_size_bytes() + 256
workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace return output, bias_grad, pre_gelu_out, workspace
@staticmethod @staticmethod
def outer_abstract(*args, **kwargs): def outer_abstract(*args, **kwargs):
outputs = GemmPrimitive.abstract(*args, **kwargs) outputs = GemmPrimitive.abstract(*args, **kwargs)
return outputs[:-3] # discard workspace arrays return outputs[:-1] # discard workspace array
@staticmethod @staticmethod
def lowering( def lowering(
...@@ -374,24 +379,22 @@ class GemmPrimitive(BasePrimitive): ...@@ -374,24 +379,22 @@ class GemmPrimitive(BasePrimitive):
grad, grad,
use_split_accumulator, use_split_accumulator,
): ):
if scaling_mode.is_1d_block_scaling():
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout( lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
) )
lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims)
rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1
lhs_scale_inv = apply_padding_to_scale_inv( lhs_scale_inv = apply_padding_to_scale_inv(
lhs_scale_inv, lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis
scaling_mode,
lhs.shape,
is_colwise=lhs_transposed,
flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
) )
rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv, rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis
scaling_mode,
rhs.shape,
is_colwise=not rhs_transposed,
flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
) )
lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
outputs = GemmPrimitive.inner_primitive.bind( outputs = GemmPrimitive.inner_primitive.bind(
lhs, lhs,
...@@ -408,7 +411,39 @@ class GemmPrimitive(BasePrimitive): ...@@ -408,7 +411,39 @@ class GemmPrimitive(BasePrimitive):
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
) )
return outputs[:-3] # discard workspace arrays return outputs[:-1] # discard workspace array
@staticmethod
def outer_impl(
    lhs,
    lhs_scale_inv,
    rhs,
    rhs_scale_inv,
    bias,
    gelu_input,
    out_dtype,
    contracting_dims,
    scaling_mode,
    fuse_bias,
    fuse_gelu,
    grad,
    use_split_accumulator,
):
    """Outer-primitive implementation for GEMM.

    Forwards every operand and configuration argument, unchanged and in
    order, to ``GemmPrimitive.impl``.
    """
    operands = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input)
    config = (
        out_dtype,
        contracting_dims,
        scaling_mode,
        fuse_bias,
        fuse_gelu,
        grad,
        use_split_accumulator,
    )
    return GemmPrimitive.impl(*operands, *config)
@staticmethod @staticmethod
def batcher( def batcher(
......
...@@ -28,8 +28,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { ...@@ -28,8 +28,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
} }
std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand( std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode,
JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { size_t axis_boundary, bool rowwise) {
// Set tensor data with collapsed 2D shape // Set tensor data with collapsed 2D shape
auto buffer_dims = buffer.dimensions(); auto buffer_dims = buffer.dimensions();
std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary), std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
...@@ -61,40 +61,6 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand( ...@@ -61,40 +61,6 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
} else { } else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} }
// Swizzle scaling factors for MXFP8
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Get the swizzle buffer
NVTE_CHECK(swizzled_scale_inv->element_count() > 0,
"Missing swizzled inverse scale buffer in the JAX primitive.");
auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
auto swizzled_scale_inv_dtype =
convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type());
NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
// Create tensor to hold swizzled scale factor
TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
}
} }
return std::make_tuple(std::move(input), input_shape); return std::make_tuple(std::move(input), input_shape);
...@@ -103,21 +69,19 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand( ...@@ -103,21 +69,19 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
// Operands (this includes swizzling MXFP8 scaling factors)
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
(is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode,
stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); rhs_axis_boundary, make_rhs_rowwise);
// Output tensor // Output tensor
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
...@@ -188,8 +152,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, ...@@ -188,8 +152,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Ret<Buffer_Type>() // output .Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad .Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out .Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // lhs_swizzled
.Ret<Buffer_Type>() // rhs_swizzled
.Ret<Buffer_Type>() // workspace .Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary") .Attr<int64_t>("lhs_axis_boundary")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment