Commit 44740c6c authored by yuguo

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
@@ -46,6 +46,8 @@ std::string to_string(const DType type) {
return "Float8E8M0";
case DType::kFloat4E2M1:
return "Float4E2M1";
case DType::kInt16:
return "Int16";
case DType::kInt32:
return "Int32";
case DType::kInt64:
...
@@ -936,17 +936,20 @@ template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
- NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
- NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
- NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
- "Input shape[0] must be equal to output shape[0].");
- NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
- "Input shape[1] must be 2x larger than output shape[1].");
+ NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(),
+ "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(),
+ ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
+ NVTE_CHECK(input.flat_last_dim() % 2 == 0,
+ "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
+ input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2,
"Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2,
"], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
- input.data.dtype, IType,
+ input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
- output->data.dtype, OType,
+ output->dtype(), OType,
if (!is_fp8_dtype(output->data.dtype) ||
is_delayed_tensor_scaling(output->scaling_mode)) {
@@ -956,8 +959,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
- reinterpret_cast<fp32 *>(output->scale_inv.dptr), output->data.shape[0],
- output->data.shape[1], {}, stream);
+ reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
+ output->flat_last_dim(), {}, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
...
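Note: the new checks encode the gated-activation shape contract, where the flattened last dimension of the input holds an activation half and a gate half, so the output's last dimension is half of the input's. A minimal NumPy sketch of that relationship (illustrative only; the tanh-approximation GELU and the half/gate ordering are assumptions, not taken from this diff):

import numpy as np

def gated_gelu_reference(x):
    # x: [..., 2*H] after flattening; split into an activation half and a gate half
    a, b = np.split(x, 2, axis=-1)
    gelu_a = 0.5 * a * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (a + 0.044715 * a**3)))
    return gelu_a * b  # shape [..., H]: the last dimension halves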
@@ -123,8 +123,11 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id)
bool supports_multicast(int device_id) {
#if CUDART_VERSION >= 12010
- // NOTE: This needs to be guarded at compile time because the
+ // NOTE: This needs to be guarded at compile-time and run-time because the
// CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
if (cudart_version() < 12010) {
return false;
}
static std::vector<bool> cache(num_devices(), false);
static std::vector<std::once_flag> flags(num_devices());
if (device_id < 0) {
@@ -219,6 +222,16 @@ const std::string &include_directory(bool required) {
}
#endif // __HIP_PLATFORM_AMD__
int cudart_version() {
auto get_version = []() -> int {
int version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&version));
return version;
};
static int version = get_version();
return version;
}
} // namespace cuda
} // namespace transformer_engine
@@ -79,6 +79,12 @@ bool supports_multicast(int device_id = -1);
const std::string &include_directory(bool required = false);
#endif
/* \brief CUDA Runtime version number at run-time
*
* Versions may differ between compile-time and run-time.
*/
int cudart_version();
} // namespace cuda
} // namespace transformer_engine
...
@@ -113,4 +113,12 @@ int get_num_compute_streams() {
int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); }
cudaStream_t nvte_get_compute_stream(const int idx) {
return transformer_engine::detail::get_compute_stream(idx);
}
cudaEvent_t nvte_get_compute_stream_event(const int idx) {
return transformer_engine::detail::get_compute_stream_event(idx);
}
#endif // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
@@ -126,6 +126,83 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
}
}
template <int nvec, typename Type>
__global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(MultiPaddingArgs args) {
using Vec = Vec<Type, nvec>;
// Thread indices
// Note: Block is interpreted as a warp_size x num_warps grid
constexpr int bdimx = THREADS_PER_WARP;
constexpr int bdimy = n_warps_per_tile;
const int tid = threadIdx.x;
const int tidx = tid % bdimx;
const int tidy = tid / bdimx;
const int bid = blockIdx.x;
// Input tensors are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
constexpr int tile_dim_m = THREADS_PER_WARP * nvec;
constexpr int tile_dim_n = THREADS_PER_WARP * nvec;
// Number of nvec x nvec subtiles for each thread to
// load/store
constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
// Find tensor corresponding to block
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]);
Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
const int row_length = args.row_length_list[tensor_id];
// Find position of tile within tensor
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
const int tile_id = bid - args.block_range[tensor_id];
const int tile_id_m = tile_id / num_tiles_n;
const int tile_id_n = tile_id % num_tiles_n;
const int tile_row = tile_id_m * tile_dim_m;
const int tile_col = tile_id_n * tile_dim_n;
// Load input and store to registers
// Note: Each thread loads n_iterations subtiles, casts to output
// type, and transposes in registers.
Type local_zero = static_cast<Type>(0.f);
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
#pragma unroll
for (int i2 = 0; i2 < nvec; ++i2) {
const int row = tile_row + i1 * nvec + i2;
const int col = tile_col + j1 * nvec;
Vec local_input;
Vec local_output;
local_input.clear();
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2];
}
}
}
#pragma unroll
for (int j2 = 0; j2 < nvec; ++j2) {
local_output.data.elt[j2] = local_input.data.elt[j2];
}
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
output[row * row_length + col + j2] = local_output.data.elt[j2];
}
}
}
}
}
}
} // namespace
void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> output_list,
@@ -202,6 +279,78 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
}
}
void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*> output_list,
const std::vector<int> unpadded_num_rows_list, cudaStream_t stream) {
// Check that number of tensors is valid
NVTE_CHECK(output_list.size() == input_list.size(),
"Number of input and output tensors must match");
if (input_list.empty()) {
return;
}
// Check that tensor properties are valid
DType type = input_list[0]->data.dtype;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = *input_list[tensor_id];
const auto& output = *output_list[tensor_id];
CheckInputTensor(input, "multi_unpadding_input_" + std::to_string(tensor_id));
CheckInputTensor(output, "multi_unpadding_output_" + std::to_string(tensor_id));
NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match.");
NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match.");
NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
NVTE_CHECK(output.data.shape[0] == unpadded_num_rows_list[tensor_id],
"output tensor shape does not match padded input shape.");
}
// Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
// Add tensors to kernel argument struct
MultiPaddingArgs kernel_args;
kernel_args.num_tensors = 0;
kernel_args.block_range[0] = 0;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
// Launch kernel if argument struct is full
if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_unpadding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
kernel_args.num_tensors = 0;
}
// Calculate number of thread blocks needed for tensor
const int num_rows = unpadded_num_rows_list[tensor_id];
const int row_length = input_list[tensor_id]->data.shape[1];
const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m;
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
const int num_tiles = num_tiles_m * num_tiles_n;
// Add tensor to kernel argument struct
const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->data.dptr);
kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr;
kernel_args.num_rows_list[pos] = num_rows;
kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
kernel_args.num_tensors++;
}
// Launch kernel
if (kernel_args.num_tensors > 0) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_unpadding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
}
}
} // namespace transformer_engine
void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
@@ -217,3 +366,17 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe
}
multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
}
void nvte_multi_unpadding(size_t num_tensors, const NVTETensor* input_list, NVTETensor* output_list,
const int* unpadded_num_rows_list, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_unpadding);
using namespace transformer_engine;
std::vector<Tensor*> input_list_, output_list_;
std::vector<int> unpadded_num_rows_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(convertNVTETensorCheck(input_list[i]));
output_list_.push_back(convertNVTETensorCheck(output_list[i]));
unpadded_num_rows_list_.push_back(unpadded_num_rows_list[i]);
}
multi_unpadding(input_list_, output_list_, unpadded_num_rows_list_, stream);
}
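For intuition, each output tensor produced by nvte_multi_unpadding is the leading unpadded_num_rows rows of the corresponding 2-D padded input; a minimal NumPy sketch of the per-tensor behaviour (illustrative only):

import numpy as np

def unpad_reference(padded: np.ndarray, unpadded_num_rows: int) -> np.ndarray:
    # keep only the valid leading rows; the padded tail rows are discarded
    return padded[:unpadded_num_rows].copy()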
@@ -156,7 +156,7 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
#ifndef USE_ROCM
const int sm_arch_ = cuda::sm_arch(device_id);
const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch());
- const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch);
+ const bool compile_ptx = sm_arch_ != compile_sm_arch;
#endif // USE_ROCM
// Compilation flags
...
@@ -85,6 +85,13 @@ class _Buffer:
if self.modified[0] and not self.reduce_within_microbatch:
return
if (
tensor.numel() == 0
if hasattr(tensor, "numel")
else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors())
):
return
# save stats for tensor to tmp buffer
for stat_name in self.stats_to_compute:
fn, _ = STATS[stat_name]
...
@@ -17,6 +17,8 @@ def _compute_dynamic_range_top(tensor):
"""Computes the log2 of the amax of the tensor"""
tensor_abs = tensor.abs()
tensor_abs = tensor_abs[tensor_abs != 0]
if tensor_abs.numel() == 0:
return torch.inf
amax = tensor_abs.max().float()
if not amax.all():
amax = torch.tensor(1, device=tensor.device).to(torch.float)
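The new early return covers tensors whose non-zero mask is empty, which would otherwise make the max() reduction error out; for example (illustrative only):

import torch

x = torch.zeros(4)
nonzero_abs = x.abs()[x.abs() != 0]  # empty tensor
assert nonzero_abs.numel() == 0      # _compute_dynamic_range_top(x) now returns torch.inf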
@@ -125,7 +127,7 @@ STATS = {
lambda buffers: min(_get(buffers, "dynamic_range_bottom")),
),
"underflows_num": (
- lambda x: (x._data == 0).sum(),
+ lambda x: (x.get_data_tensors()[0] == 0).sum(),
lambda buffers: sum(_get(buffers, "underflows_num")),
),
"std": (
...
@@ -62,6 +62,12 @@ class DebugQuantizer(Quantizer):
self.tp_group = tp_group # used in inspect_tensor calls
self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count
# .internal = True is slightly faster, but results
# in errors when caching the weights.
# Setting .internal = False is safer.
if parent_quantizer is not None:
parent_quantizer.internal = False
self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]
# The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
...
@@ -415,37 +415,35 @@ class ActLuPrimitive(BasePrimitive):
result_types,
):
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
prefix = "ActLuPrimitive_"
x_rank = len(value_types[0].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
- x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2
+ x_rank - 1, unique_var=prefix + "x", flatten_axis=-2
)
- x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
+ x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",)
out = (*x_axes[:-2], x_axes[-1])
scale_inv = scale_rules.rowwise_rule
colwise_scale_inv = scale_rules.colwise_rule
colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "scale_inv_colwise",)
if is_2x:
colwise_scale_inv = scale_rules.colwise_rule
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(
multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
)
else:
colwise_out = out
else:
colwise_out = ("j",)
colwise_scale_inv = ("k",)
# amax is always a unit tensor.
- amax = ("l",)
+ amax = (prefix + "amax",)
return SdyShardingRule(
(
x_axes,
- "…1",
+ ("…1",),
),
(out, colwise_out, scale_inv, colwise_scale_inv, amax),
**scale_rules.factor_sizes,
)
@@ -890,28 +888,26 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
result_types,
):
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
prefix = "BaseDActLuDBiasQuantizePrimitive_"
x_rank = len(value_types[1].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
- x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2
+ len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2
)
x_axes = scale_rules.input_spec
dz_axes = (*x_axes[:-2], x_axes[-1])
out = x_axes
colwise_out = (prefix + "out_colwise",)
if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
else:
- colwise_out = tuple(x_axes)
+ colwise_out = out
else:
colwise_out = ("j",)
- dbias = x_axes[-2:] if is_dbias else ("k",)
- amax = ("…4",)
+ dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
+ amax = (prefix + "amax",)
return SdyShardingRule(
- (("…0",), tuple(x_axes), ("…2",)),
+ (dz_axes, x_axes, ("…2",)),
(out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
**scale_rules.factor_sizes,
)
@@ -985,6 +981,7 @@ def act_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization.
@@ -993,6 +990,7 @@ def act_lu(
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
If quantizer is None:
@@ -1037,6 +1035,10 @@ def act_lu(
is_outer=True,
)
out = out.reshape(output_shape)
if noop_scaled_tensor:
return ScaledTensorFactory.create_2x(
out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype
)
return out
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
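A minimal sketch of the intended call pattern for the new noop_scaled_tensor flag (illustrative only; the import path is assumed, and the ScaledTensor2x accessor mirrors its use elsewhere in this diff):

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions.activation import act_lu  # assumed import path

x = jnp.ones((8, 1, 16))  # (..., ACT_DIM=1, K) for a non-gated activation
out = act_lu(x, activation_type=("gelu",), quantizer=None, noop_scaled_tensor=True)
rowwise = out.get_rowwise_tensor()  # unquantized values wrapped with ScalingMode.NO_SCALING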
@@ -1090,6 +1092,7 @@ def quantize_dact_dbias(
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
is_dbias: bool = True,
quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization.
@@ -1100,6 +1103,7 @@ def quantize_dact_dbias(
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
@@ -1113,13 +1117,49 @@ def quantize_dact_dbias(
f" {x.shape} and act_len {act_len}"
)
scale = jnp.empty((), jnp.float32)
act_type_id = ActivationEnum[activation_type]
PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
- if not PrimitiveClass.enabled():
+ if not PrimitiveClass.enabled() or (
quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
):
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
- # TE/common does not support colwise-only quantization yet
- if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
- return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
+ if quantizer is None:
+ output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
+ dz,
x,
scale,
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
is_dbias=False,
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
)
output = output.astype(x.dtype)
dbias = None
if is_dbias:
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
if noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output,
None,
output,
None,
ScalingMode.NO_SCALING,
dq_dtype=output.dtype,
),
dbias,
)
return output, dbias
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
@@ -1145,31 +1185,6 @@ def quantize_dact_dbias(
if war_output is not None:
return war_output
scale = jnp.empty((), jnp.float32)
act_type_id = ActivationEnum[activation_type]
if quantizer is None:
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
dz,
x,
scale,
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
is_dbias=False,
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
)
dbias = None
if is_dbias:
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
return output.astype(x.dtype), dbias
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = dact_lu(
@@ -1183,7 +1198,7 @@ def quantize_dact_dbias(
)
return out, dbias
- if isinstance(quantizer, DelayedScaleQuantizer):
+ if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
# TE/common dact_dbias_quantize does not support gated act yet
@@ -1243,6 +1258,7 @@ def dact_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
noop_scale_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
"""
Backward pass for activation with optional quantization.
@@ -1252,6 +1268,7 @@ def dact_lu(
x: Input tensor that was used in forward pass.
activation_type: Type of activation function that was applied.
quantizer: Optional quantizer for FP8 quantization of the output gradient.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
The gradient of the activation with respect to the input.
@@ -1262,5 +1279,6 @@ def dact_lu(
activation_type=activation_type,
is_dbias=False,
quantizer=quantizer,
noop_scaled_tensor=noop_scale_tensor,
)
return output
@@ -3,19 +3,26 @@
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence, Union, Dict
from functools import partial, reduce
import operator
import math
import operator
from collections.abc import Iterable
from typing import Tuple, Sequence, Union
from functools import partial, reduce
import jax
import jax.numpy as jnp
- from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams
+ from jax import dtypes
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule
import transformer_engine_jax as tex
from transformer_engine_jax import get_num_compute_streams
from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
from ..quantize import (
ScaledTensor,
ScaledTensor2x,
GroupedScaledTensor1x,
ScalingMode,
Quantizer,
@@ -24,10 +31,20 @@ from ..quantize import (
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
)
from .misc import get_padded_spec
__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"] __all__ = [
"gemm",
"grouped_gemm",
"gemm_uses_jax_dot",
"sanitize_dims",
"get_non_contracting_dims",
"transpose_dims",
]
num_cublas_streams = get_num_compute_streams()
@@ -35,14 +52,924 @@ num_cublas_streams = get_num_compute_streams()
def get_cublas_workspace_size_bytes() -> None:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
- if get_device_compute_capability(0) >= 90:
+ if tex.get_device_compute_capability(0) >= 90:
return 33_554_432
return 4_194_304
- def is_gemm_with_all_layouts_supported() -> False:
- """Return True if using blackwell, False otherwise."""
- return get_device_compute_capability(0) >= 100
+ def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]:
+ """Convert relative (negative) indexes to absolute dimension numbers."""
+ dims_ = dims if isinstance(dims, Iterable) else (dims,)
if len(dims_) == 0:
return dims_
return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None)
def get_non_contracting_dims(ndim, contracting_dims):
"""Return a tuple of dimensions not included in the contracting dimensions."""
contracting_dims = sanitize_dims(ndim, contracting_dims)
return tuple(dim for dim in range(ndim) if dim not in contracting_dims)
def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1):
"""Compute the new dimension numbers after transpose."""
if len(dims_to_transpose) == 0:
return dims_to_transpose
flatten_axis = ndim - flatten_axis if flatten_axis > 0 else flatten_axis
transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis))
return tuple(transposed_dims.index(dim) for dim in dims_to_transpose)
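A small usage sketch for the new dimension helpers (illustrative only):

# For a 4-D operand contracted over its last dimension:
assert sanitize_dims(4, (-1,)) == (3,)
assert sanitize_dims(4, 2) == (2,)                       # plain integers are accepted too
assert get_non_contracting_dims(4, (-1,)) == (0, 1, 2)
# transpose_dims remaps dimension numbers after the flatten-and-transpose used for
# transposed ("T") data layouts, given the tensor's flatten_axis.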
def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool:
lhs, rhs, e4m3, e5m2 = map(
dtypes.canonicalize_dtype,
(
lhs_dtype,
rhs_dtype,
jnp.float8_e4m3fn,
jnp.float8_e5m2,
),
)
# FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3)
if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3):
return True
# Any other combination of data types is not supported
return False
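For example (illustrative only), the pairings this accepts and rejects are:

assert _compatible_fp8_gemm_dtypes(jnp.float8_e4m3fn, jnp.float8_e4m3fn)    # e4m3 x e4m3
assert _compatible_fp8_gemm_dtypes(jnp.float8_e4m3fn, jnp.float8_e5m2)      # e4m3 x e5m2
assert _compatible_fp8_gemm_dtypes(jnp.float8_e5m2, jnp.float8_e4m3fn)      # e5m2 x e4m3
assert not _compatible_fp8_gemm_dtypes(jnp.float8_e5m2, jnp.float8_e5m2)    # e5m2 x e5m2 rejected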
def _get_gemm_layout(
operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]]
) -> Tuple[bool, bool]:
lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims)
lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting
rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting
return lhs_is_transposed, rhs_is_transposed
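For example (illustrative only), with 2-D operands the contracting dimensions map onto the usual cuBLAS layout letters as follows:

assert _get_gemm_layout((2, 2), ((-1,), (0,))) == (False, False)   # NN
assert _get_gemm_layout((2, 2), ((-1,), (-1,))) == (False, True)   # NT
assert _get_gemm_layout((2, 2), ((0,), (0,))) == (True, False)     # TN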
def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims):
lhs_q = lhs
rhs_q = rhs
if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None:
lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0])
lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims
need_lhs_colwise = lhs_is_transposed and (
lhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported()
)
flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims)
lhs_q = lhs_quantizer.quantize(
lhs,
is_rowwise=not need_lhs_colwise,
is_colwise=need_lhs_colwise,
flatten_axis=flatten_axis,
)
if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None:
rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1])
rhs_is_transposed = rhs.ndim - 1 in rhs_cdims
need_rhs_colwise = not rhs_is_transposed and (
rhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported()
)
flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1
rhs_q = rhs_quantizer.quantize(
rhs,
is_rowwise=not need_rhs_colwise,
is_colwise=need_rhs_colwise,
flatten_axis=flatten_axis,
)
assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x)
return lhs_q, rhs_q
class GemmPrimitive(BasePrimitive):
"""
Primitive for cuBLAS GEMM
"""
name = "te_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator
def _dims_are_consecutive(dims):
if len(dims) <= 1:
return True
return sorted(dims) == list(range(min(dims), max(dims) + 1))
# Sanity-check operand layouts and types
operand_ndims = (lhs.ndim, rhs.ndim)
(
lhs_contracting_dims,
rhs_contracting_dims,
) = map(sanitize_dims, operand_ndims, contracting_dims)
assert _dims_are_consecutive(lhs_contracting_dims), (
"cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got "
f"{lhs_contracting_dims}."
)
assert _dims_are_consecutive(rhs_contracting_dims), (
"cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got "
f"{rhs_contracting_dims}."
)
(
lhs_batch_dims,
rhs_batch_dims,
) = map(sanitize_dims, operand_ndims, batched_dims)
assert _dims_are_consecutive(lhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got "
f"{lhs_batch_dims}."
)
assert _dims_are_consecutive(rhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got "
f"{rhs_batch_dims}."
)
if len(lhs_batch_dims) == 0:
assert (
len(rhs_batch_dims) == 0
), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched."
elif len(rhs_batch_dims) != 0:
assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all(
bdim in rhs_contracting_dims for bdim in rhs_batch_dims
), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched."
lhs_contracting_size, rhs_contracting_size = map(
lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
(lhs.shape, rhs.shape),
(lhs_contracting_dims, rhs_contracting_dims),
)
assert lhs_contracting_size == rhs_contracting_size, (
"cuBLAS GEMM operands have incompatible contracting dimensions: "
f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}."
)
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims)
if scaling_mode != ScalingMode.NO_SCALING:
assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), (
"cuBLAS GEMM quantized operands have incompatible data types: "
f"{lhs.dtype} x {rhs.dtype}."
)
assert (
lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0
), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands."
if (
scaling_mode != ScalingMode.MXFP8_1D_SCALING
and not tex.is_non_nt_fp8_gemm_supported()
):
assert not lhs_is_transposed and rhs_is_transposed, (
"cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) "
"require non-transposed LHS and transposed RHS operands "
"(`contracting_dims=((-1, ), (-1, ))`)."
)
# Determine output shape and dtype
assert (
dtypes.canonicalize_dtype(out_dtype).itemsize > 1
), "cuBLAS GEMM custom op does not support 8-bit quantized output types."
lhs_non_contracting_shape, rhs_non_contracting_shape = map(
lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims],
(lhs.shape, rhs.shape),
(lhs_contracting_dims, rhs_contracting_dims),
)
out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape)
output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
# Validate bias
bias_shape = (0,)
bias_dtype = out_dtype
if fuse_bias:
expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape)
if not grad:
assert bias.size == expected_bias_size, (
"cuBLAS GEMM bias tensor has incorrect shape, "
f"expected ({expected_bias_size}, ) but found {bias.shape}."
)
assert bias.dtype == out_dtype, (
"cuBLAS GEMM bias tensor has incorrect data type, "
f"expected {bias_dtype} but found {bias.dtype}."
)
bias_shape = bias.shape
else:
bias_shape = rhs_non_contracting_shape
bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype)
# Validate pre-GeLU
pre_gelu_shape = (0,)
pre_gelu_dtype = out_dtype
if fuse_gelu:
pre_gelu_shape = out_shape
if grad:
pre_gelu_ndim = len(pre_gelu_shape)
assert gelu_input.ndim == pre_gelu_shape and all(
gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim)
), (
"cuBLAS GEMM pre-GeLU tensor has incorrect shape, "
f"expected {pre_gelu_shape} but found {gelu_input.shape}."
)
assert gelu_input.dtype == out_dtype, (
"cuBLAS GEMM pre-GeLU tensor has incorrect data type, "
f"expected {pre_gelu_dtype} but found {gelu_input.dtype}."
)
pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
# Need extra workspace for swizzled scale factors
lhs_swizzle_size = 0
rhs_swizzle_size = 0
swizzle_dtype = jnp.uint8
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_swizzle_size = lhs_scale_inv.size
rhs_swizzle_size = rhs_scale_inv.size
lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype)
rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype)
# Declare cuBLAS workspace
# cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
# necessarily 256 bytes aligned, we add some padding to ensure alignment.
workspace_size = get_cublas_workspace_size_bytes() + 256
workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace
@staticmethod
def outer_abstract(*args, **kwargs):
outputs = GemmPrimitive.abstract(*args, **kwargs)
return outputs[:-3] # discard workspace arrays
@staticmethod
def lowering(
ctx,
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype
lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
)
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input)
kwargs = {
"scaling_mode": int(scaling_mode.value),
"lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
"rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
"lhs_transposed": lhs_transposed,
"rhs_transposed": rhs_transposed,
"fuse_bias": fuse_bias,
"fuse_gelu": fuse_gelu,
"grad": grad,
"use_split_accumulator": use_split_accumulator,
}
operand_output_aliases = {}
if fuse_bias and not grad:
operand_output_aliases.update({4: 1}) # bias <-> bias_grad
if fuse_gelu and grad:
operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out
return jax.ffi.ffi_lowering(
GemmPrimitive.name,
operand_output_aliases=operand_output_aliases,
)(ctx, *args, **kwargs)
@staticmethod
def impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
)
lhs_scale_inv = apply_padding_to_scale_inv(
lhs_scale_inv,
scaling_mode,
lhs.shape,
is_colwise=lhs_quantized_colwise,
flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
)
rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv,
scaling_mode,
rhs.shape,
is_colwise=rhs_quantized_colwise,
flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
)
outputs = GemmPrimitive.inner_primitive.bind(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
)
return outputs[:-3] # discard workspace arrays
@staticmethod
def batcher(
batched_args,
jax_batch_dims,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
assert GemmPrimitive.outer_primitive is not None
lhs, _, rhs, *_ = batched_args
lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims
arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
# Output is batched like the non-contracting batch dimensions of the LHS operand
lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims)
lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims)
out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims
# Bias gradient is never batched
bias_bdims = (None,)
# Pre-GeLU output, if exists, is batched like GEMM output
pre_gelu_bdims = (None,)
if fuse_gelu and not grad:
pre_gelu_bdims = out_bdims
return (
GemmPrimitive.outer_primitive.bind(
*batched_args,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
),
(out_bdims, bias_bdims, pre_gelu_bdims),
)
@staticmethod
def _decompose_operand_specs(specs, contracting_dims, batch_dims):
ndims = len(specs)
cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims))
# Batch specs
bspecs = tuple(specs[i] for i in bdims)
# Non-batch leading dimension specs
lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims)
# Non-batch contracting dimension specs
cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims)
return bspecs, lspecs, cspecs
@staticmethod
def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map(
sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims
)
(
(lhs_bspecs, lhs_lspecs, lhs_cspecs),
(rhs_bspecs, rhs_lspecs, rhs_cspecs),
) = map(
GemmPrimitive._decompose_operand_specs,
(lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
)
# Batched dimensions must have the same sharding
if len(lhs_bdims) > 0 and len(rhs_bdims) > 0:
assert all(
lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs)
), (
"cuBLAS GEMM operand batch dimensions must have the same sharding: "
f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}."
)
# Only one each of the non-batched leading dimensions and non-batched contracting
# dimensions can be sharded
lhs_ldims, rhs_ldims = map(
lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude),
(lhs_ndim, rhs_ndim),
(lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims),
)
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map(
lambda specs: tuple(spec for spec in specs if spec is not None),
(lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs),
)
assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched leading dimension: "
f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}."
)
assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: "
f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}."
)
# Extract single leading and contracting dimension specs
(lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map(
lambda specs: None if len(specs) == 0 else specs[0],
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none),
)
# Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts
# with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands.
# 1. K1 == K2 != None and N == None
# LHS: (B, M, K)
# RHS: (B, None, K)
# OUT: (B, M, None) --(AR)-> (B, M, None)
# 2. K1 == K2 != None and M == N != None
# LHS: (B, M, K)
# RHS: (B, N, K)--(AG)->(B, None, K)
# OUT: (B, M, None) --(RS)--> (B, M, N)
# 3. M == N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, M, K)--(AG)->(B, None, None)
# OUT: (B, M, None)
# 4. M != N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, N, K)--(AG)->(B, N, None)
# OUT: (B, M, N)
reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec
all_reduce_output = reduce_flag and rhs_lspec is None
reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec
all_reduce_spec = reduce_scatter_spec = scatter_dim = None
lhs_non_contracting_specs, rhs_non_contracting_specs = map(
lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims),
(lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims),
)
out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs)
if reduce_scatter_output:
# All-gather (if necessary) the non-batch non-contracting dimension of RHS
# (B, N, K) --(AG)-> (B, None, K)
# (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N)
rhs_specs = tuple(
rhs_specs[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim)
)
reduce_scatter_spec = lhs_cspec
scatter_dim = out_specs.index(rhs_lspec)
elif all_reduce_output:
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
all_reduce_spec = lhs_cspec
else:
# All-gather (if necessary) the non-batch contracting dimensions
# (B, M, K) --(AG)-> (B, M, None)
# (B, N, K) --(AG)-> (B, N, None)
# (B, M, None) x (B, N, None)^T = (B, M, N)
lhs_specs = tuple(
None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i]
for i in range(lhs_ndim)
)
rhs_specs = tuple(
None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i]
for i in range(rhs_ndim)
)
# Check if RHS non-contracting spec also appears in the LHS non-contracting specs
if rhs_lspec is not None and rhs_lspec in tuple(
lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims
):
# All-gather (if necessary) the non-batch non-contracting dimensions of RHS
# (B, N, None) --(AG)-> (B, None, None)
# (B, M, None) x (B, None, None)^T = (B, M, None)
rhs_specs = tuple(
None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i]
for i in range(rhs_ndim)
)
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
# Bias and Pre-GeLU sharding is based on GEMM output
bias_specs = out_specs[len(lhs_non_contracting_specs) :]
gelu_specs = out_specs
return (
(lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs),
all_reduce_spec,
reduce_scatter_spec,
scatter_dim,
)
@staticmethod
def infer_sharding_from_operands(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
mesh,
arg_infos,
result_infos,
):
del (
out_dtype,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
grad,
)
del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
# Discard bias gradient spec if there is no bias fusion
if not fuse_bias:
dbias_specs = (None,)
dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs))
# Discard pre-GeLU output spec if there is no GeLU fusion
if not fuse_gelu:
pre_gelu_specs = (None,)
pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))
return [out_sharding, dbias_sharding, pre_gelu_sharding]
@staticmethod
def partition(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
mesh,
arg_infos,
result_infos,
):
del result_infos
(
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs),
all_reduce_spec,
reduce_scatter_spec,
scatter_dim,
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
# Assemble argument shardings
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
none_sharding = NamedSharding(mesh, PartitionSpec(None))
lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs))
rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs))
arg_shardings = (
lhs_sharding,
lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding,
rhs_sharding,
rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding,
)
# Discard bias input spec if there is no bias fusion
if not fuse_bias:
bias_input_specs = (None,)
arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),)
# Discard pre-GeLU input spec if there is no GeLU fusion
if not fuse_gelu:
gelu_input_specs = (None,)
arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),)
# Assemble output shardings
out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))]
# Discard bias gradient spec if there is no bias fusion
if not fuse_bias:
dbias_specs = (None,)
out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs)))
# Discard pre-GeLU output spec if there is no GeLU fusion
if not fuse_gelu:
pre_gelu_specs = (None,)
out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)))
def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input):
outputs = GemmPrimitive.impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
)
# All-Reduce/Reduce-Scatter GEMM output
if all_reduce_spec is not None:
outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec)
elif reduce_scatter_spec is not None:
outputs[0] = jax.lax.psum_scatter(
outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum_scatter(
outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
return outputs
return mesh, _sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
mesh,
operand_types,
result_types,
):
del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator
del mesh, result_types
prefix = "GemmPrimitive_"
def _generate_operand_rules(name, ndim, cdims, bdims):
specs = []
ldims = tuple(i for i in range(ndim) if i not in bdims + cdims)
for i in range(ndim):
dim_name = None
if i in bdims:
dim_idx = bdims.index(i) if len(bdims) > 1 else ""
dim_name = f"b{dim_idx}"
elif i in cdims:
dim_idx = cdims.index(i) if len(cdims) > 1 else ""
dim_name = f"k{dim_idx}"
else:
dim_idx = ldims.index(i) if len(ldims) > 1 else ""
dim_name = f"{name}_l{dim_idx}"
specs.append(prefix + dim_name)
return specs
lhs, _, rhs, *_ = operand_types
operand_ndims = (len(lhs.shape), len(rhs.shape))
(lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map(
lambda dims: map(sanitize_dims, operand_ndims, dims),
(contracting_dims, batched_dims),
)
lhs_specs, rhs_specs = map(
_generate_operand_rules,
("lhs", "rhs"),
operand_ndims,
(lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
)
lhs_scale_specs = ("…1",)
rhs_scale_specs = ("…2",)
if scaling_mode.is_1d_block_scaling():
# Shardy rules for MXFP8 scales cannot be related to the operands because of the
# global-unpadding and local-padding workflow. This can potentially insert expensive
# re-shards in the partition call later if the scales are not already sharded correctly.
lhs_scale_specs, rhs_scale_specs = map(
lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs),
(lhs_specs, rhs_specs),
)
lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims)
rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
out_spec = (*lhs_non_cspec, *rhs_non_cspec)
bias_spec = rhs_non_cspec if fuse_bias else ("…4",)
gelu_spec = out_spec if fuse_gelu else ("…5",)
return SdyShardingRule(
operand_mappings=(
lhs_specs,
lhs_scale_specs,
rhs_specs,
rhs_scale_specs,
bias_spec,
gelu_spec,
),
result_mappings=(
out_spec,
bias_spec,
gelu_spec,
),
)
register_primitive(GemmPrimitive)
def gemm_uses_jax_dot() -> bool:
"""Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot."""
return not GemmPrimitive.enabled()
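Callers can branch on this to fall back to a native XLA contraction when the custom primitive is disabled; a minimal sketch (illustrative only):

lhs = jnp.ones((16, 32))
rhs = jnp.ones((32, 8))
if gemm_uses_jax_dot():
    # hypothetical fallback: let XLA handle the contraction directly
    out = jax.lax.dot_general(lhs, rhs, dimension_numbers=(((1,), (0,)), ((), ())))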
def _te_gemm(
lhs: Union[jax.Array, ScaledTensor],
rhs: Union[jax.Array, ScaledTensor],
bias: jax.Array = None,
gelu_input: jax.Array = None,
lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
fuse_bias: bool = False,
fuse_gelu: bool = False,
grad: bool = False,
use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands
lhs_data = lhs
rhs_data = rhs
lhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
rhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
scaling_mode = ScalingMode.NO_SCALING
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
# Quantize operands (if necessary)
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
# Extract GEMM custom op inputs from quantized operands
if isinstance(lhs_q, ScaledTensor):
assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, (
"cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid "
"`Quantizer` object to quantize the RHS operand."
)
if isinstance(lhs_q, ScaledTensor2x):
# Choose the quantization of the contracting dimension(s)
lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor()
scaling_mode = lhs_q.scaling_mode
lhs_data = lhs_q.data
lhs_scale_inv = lhs_q.scale_inv
if lhs_q.data_layout == "T":
lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis)
if isinstance(rhs_q, ScaledTensor):
assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
"cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid "
"`Quantizer` object to quantize the LHS operand."
)
if isinstance(rhs_q, ScaledTensor2x):
# Choose the quantization of the contracting dimension(s)
rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor()
assert rhs_q.scaling_mode == lhs_q.scaling_mode, (
"cuBLAS GEMM quantized operands have mismatched scaling types, "
f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}."
)
rhs_data = rhs_q.data
rhs_scale_inv = rhs_q.scale_inv
if rhs_q.data_layout == "T":
rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis)
# Dummy empties for bias and gelu
out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
if bias is None or not (fuse_bias and not grad):
bias = jnp.empty(0, dtype=out_dtype)
if gelu_input is None or not (fuse_gelu and grad):
gelu_input = jnp.empty(0, dtype=out_dtype)
return GemmPrimitive.outer_primitive.bind(
lhs_data,
lhs_scale_inv,
rhs_data,
rhs_scale_inv,
bias,
gelu_input,
out_dtype=out_dtype,
contracting_dims=(lhs_cdims, rhs_cdims),
batched_dims=(lhs_bdims, rhs_bdims),
lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False,
rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
)
class GroupedGemmPrimitive(BasePrimitive): class GroupedGemmPrimitive(BasePrimitive):
...@@ -102,15 +1029,28 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -102,15 +1029,28 @@ class GroupedGemmPrimitive(BasePrimitive):
A jnp.ndarray containing the result of the grouped GEMM operation A jnp.ndarray containing the result of the grouped GEMM operation
""" """
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias del K, lhs_is_trans, rhs_is_trans, has_bias
del lhs_scale_inv_aval, rhs_scale_inv_aval
# TODO(Phuong): move some shape checks from Cpp to here # TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
# JAX buffer pointers are 128-aligned workspace_alignment_padding = 256
# 255 is added to the workspace size to ensure workspace ptr is 256-aligned tensor_scaling_sinv_aligment = 16
workspace_size += 255 mxfp8_scaling_sinv_alignment_padding = 256
# The cuBLAS workspace ptr must be 256-byte aligned, but JAX buffers are not
# necessarily 256-byte aligned, so we add padding to ensure alignment.
workspace_size += workspace_alignment_padding
if scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING.value,
ScalingMode.CURRENT_TENSOR_SCALING.value,
):
# For tensor scaling, each matrix has a single scale value, but it
# needs to be aligned to 16 bytes for CUDA 12.9.1 and later.
workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment
workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
# We also pad the scale_inv swizzle buffer sizes for 256-byte alignment.
workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
# TODO(phuong): We should make separate tmp buffers for swizzled scales to avoid the unaligned-by-256 workspace ptr issue
out_shape = (M, N) out_shape = (M, N)
if is_grouped_dense_wgrad: if is_grouped_dense_wgrad:
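A minimal Python sketch of the workspace sizing above; the constants stand in for get_cublas_workspace_size_bytes() and num_cublas_streams, which are assumed here rather than imported:

# Placeholder values; the real numbers come from the TE extension.
CUBLAS_WORKSPACE_BYTES = 32 * 1024 * 1024
NUM_CUBLAS_STREAMS = 4

def sketch_grouped_workspace_bytes(lhs_sinv_size, rhs_sinv_size, is_tensor_scaling, is_mxfp8):
    size = CUBLAS_WORKSPACE_BYTES * NUM_CUBLAS_STREAMS
    size += 256  # cuBLAS needs a 256-byte-aligned workspace pointer
    if is_tensor_scaling:
        # One scale per matrix, each copied into its own 16-byte-aligned slot.
        size += (lhs_sinv_size + rhs_sinv_size) * 16
    elif is_mxfp8:
        # Room for swizzled scale_inv copies, each padded for 256-byte alignment.
        size += lhs_sinv_size + 256
        size += rhs_sinv_size + 256
    return size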
...@@ -221,11 +1161,8 @@ def _shape_normalization(x, dimension_numbers, already_transposed: bool = False) ...@@ -221,11 +1161,8 @@ def _shape_normalization(x, dimension_numbers, already_transposed: bool = False)
def _calculate_remaining_shape(shape, contracting_dims): def _calculate_remaining_shape(shape, contracting_dims):
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims) contracting_dims_ = sanitize_dims(len(shape), contracting_dims)
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims_)
def _transpose_contract_dims(ndim, contracting_dims):
return tuple(ndim - i - 1 for i in contracting_dims)[::-1]
# Apply jit to guarantee correctness of FP8 GEMM. # Apply jit to guarantee correctness of FP8 GEMM.
...@@ -233,9 +1170,11 @@ def _transpose_contract_dims(ndim, contracting_dims): ...@@ -233,9 +1170,11 @@ def _transpose_contract_dims(ndim, contracting_dims):
def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T": if lhs.data_layout == "T":
lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract) lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis)
lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis)
if rhs.data_layout == "T": if rhs.data_layout == "T":
rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract) rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis)
rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis)
dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
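A standalone sketch of the dimension remapping assumed from transpose_dims for a fully transposed ("T") layout; the flatten_axis handling for higher-rank tensors is omitted here:

def sketch_transpose_dims(ndim, dims):
    # In a plain transpose, original dimension i becomes dimension ndim - 1 - i,
    # which is all the removed _transpose_contract_dims helper did.
    return tuple(sorted(ndim - 1 - i for i in dims))

# A 3D operand contracted over its last axis, stored transposed:
assert sketch_transpose_dims(3, (2,)) == (0,)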
...@@ -280,10 +1219,6 @@ def _jax_gemm_mxfp8_1d( ...@@ -280,10 +1219,6 @@ def _jax_gemm_mxfp8_1d(
lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch)) lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch))
rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch)) rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch))
# Slice out the padding as scaled_matmul does not support padded scales yet
lhs_scale_3d = jnp.asarray(lhs_scale_3d[:, : lhs_3d.shape[1], : int(lhs_3d.shape[2] / 32)])
rhs_scale_3d = jnp.asarray(rhs_scale_3d[:, : rhs_3d.shape[1], : int(rhs_3d.shape[2] / 32)])
# JAX scaled_matmul only supports NT now (TN-gemm) # JAX scaled_matmul only supports NT now (TN-gemm)
# * Expected shape: # * Expected shape:
# * lhs_data (B, M, K) * rhs_data (B, N, K) # * lhs_data (B, M, K) * rhs_data (B, N, K)
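A tiny reference of that NT ("TN-GEMM") contraction in plain JAX, assuming unquantized inputs with the shapes listed above:

import jax.numpy as jnp

# out[b, m, n] = sum_k lhs[b, m, k] * rhs[b, n, k] -- both operands keep K as their last axis.
def sketch_nt_matmul(lhs_3d, rhs_3d):
    return jnp.einsum("bmk,bnk->bmn", lhs_3d, rhs_3d)

assert sketch_nt_matmul(jnp.ones((2, 8, 32)), jnp.ones((2, 16, 32))).shape == (2, 8, 16)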
...@@ -306,12 +1241,12 @@ def _jax_gemm( ...@@ -306,12 +1241,12 @@ def _jax_gemm(
lhs: Union[jnp.ndarray, ScaledTensor], lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
) -> jnp.ndarray: ) -> jnp.ndarray:
""" """
FP8 GEMM via JAX FP8 GEMM via JAX
""" """
dim_nums = (contracting_dims, ((), ())) dim_nums = (contracting_dims, ((), ()))
def _jax_gemm_fp8_impl(lhs, rhs): def _jax_gemm_fp8_impl(lhs, rhs):
...@@ -331,32 +1266,16 @@ def _jax_gemm( ...@@ -331,32 +1266,16 @@ def _jax_gemm(
raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}")
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
return _jax_gemm_fp8_impl(lhs, rhs)
if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor):
if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor): return _jax_gemm_fp8_impl(lhs_q, rhs_q)
if quantizer_set != noop_quantizer_set:
assert type(quantizer_set.x) is type(quantizer_set.kernel)
(((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums
lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
# Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm)
lhs_q = quantizer_set.x.quantize(
lhs,
is_rowwise=lhs_is_rowwise,
is_colwise=not lhs_is_rowwise,
)
rhs_q = quantizer_set.kernel.quantize(
rhs,
is_rowwise=rhs_is_rowwise,
is_colwise=not rhs_is_rowwise,
)
return _jax_gemm_fp8_impl(lhs_q, rhs_q)
if ( if (
isinstance(lhs, jnp.ndarray) isinstance(lhs, jnp.ndarray)
and isinstance(rhs, jnp.ndarray) and isinstance(rhs, jnp.ndarray)
and quantizer_set == noop_quantizer_set and lhs_quantizer is None
and rhs_quantizer is None
): ):
return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)
...@@ -366,30 +1285,109 @@ def _jax_gemm( ...@@ -366,30 +1285,109 @@ def _jax_gemm(
def gemm( def gemm(
lhs: Union[jnp.ndarray, ScaledTensor], lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
quantizer_set: QuantizerSet = noop_quantizer_set, batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
) -> jnp.ndarray: lhs_quantizer: Quantizer = None,
"""General matrix multiplication with optional quantization. rhs_quantizer: Quantizer = None,
**kwargs,
Args: ) -> Tuple[jnp.ndarray, ...]:
lhs: First input matrix. r"""General matrix multiplication with optional quantization.
rhs: Second input matrix.
contracting_dims: Tuple of two sequences representing the contracting dimensions. Parameters
The first sequence represents the contracting dimensions of the first matrix, ----------
and the second sequence represents the contracting dimensions of the second matrix. lhs: Union[jax.Array, ScaledTensor]
quantizer_set: Set of quantizers for FP8 quantization of the output. Left-hand side operand in the matrix multiplication.
If None, no quantization is applied and the output has the same dtype as the inputs. rhs: Union[jax.Array, ScaledTensor]
Right-hand side operand in the matrix multiplication.
Returns: lhs_quantizer: Quantizer, default = None
If quantizer_set is None: Object for down-casting the LHS operand for quantized GEMM.
The matrix multiplication result. rhs_quantizer: Quantizer, default = None
Shape: (M, N) Object for down-casting the RHS operand for quantized GEMM.
Dtype: Same as input dtype contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, ))
If quantizer_set is provided: Tuple of sequences representing the contracting dimensions of the operands.
A ScaledTensor containing the quantized matrix multiplication result. batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()),
Tuple of sequences representing the batched dimensions of the operands. This is *not* used
to perform a batched matrix multiplication, but it is required to avoid a potentially
undesirable reduction in any batched contracting dimensions when invoked with sharded
operands (e.g. when computing weight gradients in a Flax module).
bias: jax.Array, default = None
Optional additive bias term, required for forward GEMM with bias fusion. Only supported
with TE's custom call to cuBLAS GEMM.
gelu_input: jax.Array, default = None
Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only
supported with TE's custom call to cuBLAS GEMM.
fuse_bias: bool, default = False
Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with
TE's custom call to cuBLAS GEMM.
fuse_gelu: bool, default = False
Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported
with TE's custom call to cuBLAS GEMM.
grad: bool, default = False
Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with
TE's custom call to cuBLAS GEMM.
use_split_accumulator: bool, default = True
Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed.
Returns
-------
jax.Array:
Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the
GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution
when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and
`grad=False`.
Optional[jax.Array]:
Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call
to cuBLAS GEMM.
Optional[jax.Array]:
Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input
to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to
compute the GeLU contribution to the gradient. Only supported with TE's custom call to
cuBLAS GEMM.
""" """
# Try to get LHS and RHS quantizers from a quantizer set for backward compatibility
if lhs_quantizer is None or rhs_quantizer is None:
quantizer_set = kwargs.get("quantizer_set", None)
if quantizer_set is not None:
lhs_quantizer = quantizer_set.x
rhs_quantizer = quantizer_set.kernel
# Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled
fuse_bias = kwargs.get("fuse_bias", False)
fuse_gelu = kwargs.get("fuse_gelu", False)
if not GemmPrimitive.enabled():
assert kwargs.get("bias", None) is None and not fuse_gelu, (
"TE GEMM was invoked with bias fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
assert kwargs.get("gelu_input", None) is None and not fuse_bias, (
"TE GEMM was invoked with GeLU fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
outputs = _te_gemm(
lhs,
rhs,
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
**kwargs,
)
return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) # Discard empty outputs
grad = kwargs.get("grad", False)
clean_outputs = outputs[0] # first output is the final result and is never empty
if (fuse_bias and grad) or (fuse_gelu and not grad):
clean_outputs = (outputs[0],)
if fuse_bias and grad: # only return bias gradient if it exists
clean_outputs += (outputs[1],)
if fuse_gelu and not grad: # only return pre-GeLU output if it exists
clean_outputs += (outputs[2],)
return clean_outputs
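A minimal usage sketch of the wrapper above for the unquantized path; no quantizers or fusions are passed, so the call reduces to a plain matrix multiplication on either backend:

import jax.numpy as jnp

x = jnp.ones((128, 256), dtype=jnp.bfloat16)
w = jnp.ones((256, 512), dtype=jnp.bfloat16)
out = gemm(x, w, contracting_dims=((1,), (0,)))
assert out.shape == (128, 512)

# With fusions enabled (custom cuBLAS primitive only), extra outputs are returned:
#   fuse_gelu=True, grad=False  -> (out, pre_gelu_out)
#   fuse_bias=True, grad=True   -> (out, bias_grad)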
def grouped_gemm( def grouped_gemm(
...@@ -490,15 +1488,13 @@ def grouped_gemm( ...@@ -490,15 +1488,13 @@ def grouped_gemm(
assert type(quantizer_set.x) is type(quantizer_set.kernel) assert type(quantizer_set.x) is type(quantizer_set.kernel)
scaling_mode = quantizer_set.x.scaling_mode scaling_mode = quantizer_set.x.scaling_mode
if ( if (
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later quantizer_set.x.scaling_mode.is_tensor_scaling()
# scaling_mode.is_tensor_scaling() and is_fp8_gemm_with_all_layouts_supported()
# and is_gemm_with_all_layouts_supported()
scaling_mode.is_1d_block_scaling()
): ):
lhs_is_rowwise = rhs_is_rowwise = True lhs_is_rowwise = rhs_is_rowwise = True
else: else:
lhs_is_rowwise = not lhs_is_trans lhs_is_rowwise = not lhs_is_trans
rhs_is_rowwise = lhs_is_trans rhs_is_rowwise = rhs_is_trans
quantizer_set.x.q_layout = ( quantizer_set.x.q_layout = (
QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
) )
...@@ -513,6 +1509,8 @@ def grouped_gemm( ...@@ -513,6 +1509,8 @@ def grouped_gemm(
rhs_data = rhs_q.data rhs_data = rhs_q.data
lhs_scale_inv = lhs_q.scale_inv lhs_scale_inv = lhs_q.scale_inv
rhs_scale_inv = rhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv
lhs_shape = lhs_q.original_shape
rhs_shape = rhs_q.original_shape
assert not ( assert not (
lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2 lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
...@@ -520,24 +1518,35 @@ def grouped_gemm( ...@@ -520,24 +1518,35 @@ def grouped_gemm(
# Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs
# thus additional transpose is required # thus additional transpose is required
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported():
if scaling_mode.is_tensor_scaling(): # and not is_gemm_with_all_layouts_supported():
lhs_is_trans = False
rhs_is_trans = True
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
lhs_layout_is_T = lhs.data_layout == "T" lhs_layout_is_T = lhs.data_layout == "T"
rhs_layout_is_T = rhs.data_layout == "T" rhs_layout_is_T = rhs.data_layout == "T"
else: else:
lhs_layout_is_T = lhs_q.data_layout == "T" lhs_layout_is_T = lhs_q.data_layout == "T"
rhs_layout_is_T = rhs_q.data_layout == "T" rhs_layout_is_T = rhs_q.data_layout == "T"
# We can't apply _shape_normalization to the grouped input,
# so we need to ensure that lhs is in N layout and rhs is in T layout.
assert (
lhs_is_trans == lhs_layout_is_T
), "lhs input must be transposed before calling grouped_gemm"
assert (
not rhs_is_trans == rhs_layout_is_T
), "rhs input must be transposed before calling grouped_gemm"
lhs_is_trans = False
rhs_is_trans = True
lhs_ndim = len(lhs_shape) lhs_ndim = len(lhs_shape)
rhs_ndim = len(rhs_shape) rhs_ndim = len(rhs_shape)
if lhs_layout_is_T: if lhs_layout_is_T:
lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
if rhs_layout_is_T: if rhs_layout_is_T:
rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) # For rhs [G, K, N], need to exclude the G dim from contract_dim
lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T) if group_sizes.size == rhs_shape[0]:
rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T) rhs_contract_dim = tuple(
(rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim
)
else:
rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)
# Calling GroupedGEMM Custom Call # Calling GroupedGEMM Custom Call
K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
...@@ -557,9 +1566,6 @@ def grouped_gemm( ...@@ -557,9 +1566,6 @@ def grouped_gemm(
assert not has_bias or bias.shape == (group_sizes.size, N) assert not has_bias or bias.shape == (group_sizes.size, N)
bias = jnp.empty((), jnp.float32) if bias is None else bias bias = jnp.empty((), jnp.float32) if bias is None else bias
# TODO(Phuong): support MXFP8_1D_SCALING
assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported"
(out,) = GroupedGemmPrimitive.outer_primitive.bind( (out,) = GroupedGemmPrimitive.outer_primitive.bind(
lhs_data, lhs_data,
lhs_scale_inv, lhs_scale_inv,
......
...@@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant ...@@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
calculate dbias separately. This function checks if the workaround should be applied. calculate dbias separately. This function checks if the workaround should be applied.
""" """
if quantizer is None:
return False
arch_l_100 = False arch_l_100 = False
for local_gpu_id in range(len(jax.local_devices())): for local_gpu_id in range(len(jax.local_devices())):
if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100: if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
arch_l_100 = True arch_l_100 = True
break break
# _quantize_dbias_impl forces 1x quantization for tensor scaling, which switches q_layout to ROWWISE,
# but this fails when dbias fusion is enabled on arch < 100.
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
return ( return (
quantizer is not None (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
and quantizer.q_layout == QuantizeLayout.ROWWISE
and arch_l_100 and arch_l_100
and is_dbias and is_dbias
) )
......
...@@ -587,16 +587,17 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -587,16 +587,17 @@ class NormFwdPrimitive(BasePrimitive):
result_types, result_types,
) )
prefix = "NormFwdPrimitive_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1 len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
out = x_axes[:-1] + ("k",) out = x_axes
colwise_out = out if is_2x else ("…4",) colwise_out = out if is_2x else (prefix + "out_colwise",)
rsigma = x_axes[:-1] rsigma = x_axes[:-1]
mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = ("…6",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",), ("…2",), ("…3",)), (x_axes, ("…1",), ("…2",), ("…3",)),
...@@ -609,7 +610,6 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -609,7 +610,6 @@ class NormFwdPrimitive(BasePrimitive):
mu, mu,
rsigma, rsigma,
), ),
**scale_rules.factor_sizes,
) )
...@@ -1276,6 +1276,7 @@ def normalization_fwd( ...@@ -1276,6 +1276,7 @@ def normalization_fwd(
epsilon: float, epsilon: float,
norm_type: str, norm_type: str,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
noop_scaled_tensor: bool = False,
): ):
"""Common wrapper for normalization forward pass. """Common wrapper for normalization forward pass.
...@@ -1292,6 +1293,7 @@ def normalization_fwd( ...@@ -1292,6 +1293,7 @@ def normalization_fwd(
- 'layernorm': Layer normalization - 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization - 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -1319,6 +1321,15 @@ def normalization_fwd( ...@@ -1319,6 +1321,15 @@ def normalization_fwd(
else: else:
raise ValueError(f"{norm_type=} is not supported.") raise ValueError(f"{norm_type=} is not supported.")
if quantizer is None and noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype
),
mu,
rsigma,
)
return output, mu, rsigma return output, mu, rsigma
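A short sketch of how the new flag is meant to be consumed; the leading layernorm arguments (x, gamma, beta, zero_centered_gamma) and the ScaledTensor2x accessors are assumed from their usage elsewhere in this change:

# With quantizer=None and noop_scaled_tensor=True, the unquantized output is wrapped so that
# downstream code can call .get_rowwise_tensor()/.get_colwise_tensor() without branching on
# whether quantization is enabled.
out, mu, rsigma = normalization_fwd(
    x, gamma, beta, zero_centered_gamma=False, epsilon=1e-5,
    norm_type="layernorm", quantizer=None, noop_scaled_tensor=True,
)
rowwise = out.get_rowwise_tensor()  # same unquantized data, uniform interface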
......
...@@ -36,7 +36,6 @@ from ..quantize import ( ...@@ -36,7 +36,6 @@ from ..quantize import (
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
QuantizeLayout, QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode, ScalingMode,
compute_scale_from_amax, compute_scale_from_amax,
) )
...@@ -489,9 +488,10 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -489,9 +488,10 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
): ):
del out_dtype, scale_dtype, is_outer, mesh, result_types del out_dtype, scale_dtype, is_outer, mesh, result_types
prefix = "BaseDBiasQuantizePrimitive_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), len(value_types[0].shape),
unique_var="BaseDBiasQuantizePrimitive_i", unique_var=prefix + "x",
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
...@@ -499,22 +499,19 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -499,22 +499,19 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv = scale_rules.colwise_rule colwise_scale_inv = scale_rules.colwise_rule
out = x_axes out = x_axes
colwise_out = (prefix + "out_colwise",)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling(): if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
else: else:
colwise_out = x_axes colwise_out = x_axes
else:
colwise_out = ("j",)
colwise_scale_inv = ("k",)
dbias = x_axes[flatten_axis:] if is_dbias else ("l",) dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
amax = ("m",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",)), (x_axes, ("…1",)),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
) )
...@@ -538,11 +535,12 @@ def _jax_quantize( ...@@ -538,11 +535,12 @@ def _jax_quantize(
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
assert flatten_axis < 0 sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
dtype = dtype or dx.dtype dtype = dtype or dx.dtype
dbias = jnp.sum( dbias = jnp.sum(
dx.astype(jnp.float32), dx.astype(jnp.float32),
axis=tuple(range(dx.ndim + flatten_axis)), axis=tuple(range(sum_axis)),
keepdims=False, keepdims=False,
) )
return dbias.astype(dtype) return dbias.astype(dtype)
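For reference, the reduction above sums over every axis in front of flatten_axis; a quick standalone check of the axis arithmetic:

import jax.numpy as jnp

def sketch_dbias(dx, flatten_axis=-1):
    # Same axis arithmetic as _jax_dbias: sum over all axes before flatten_axis.
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    return jnp.sum(dx.astype(jnp.float32), axis=tuple(range(sum_axis)))

dx = jnp.ones((4, 16, 8))
assert sketch_dbias(dx, flatten_axis=-1).shape == (8,)     # reduce over the (4, 16) leading axes
assert sketch_dbias(dx, flatten_axis=-2).shape == (16, 8)  # reduce over the (4,) axis only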
...@@ -568,6 +566,7 @@ def _quantize_dbias_impl( ...@@ -568,6 +566,7 @@ def _quantize_dbias_impl(
is_dbias: bool = False, is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None, dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1, flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
""" """
Cast wrapper Cast wrapper
...@@ -577,24 +576,34 @@ def _quantize_dbias_impl( ...@@ -577,24 +576,34 @@ def _quantize_dbias_impl(
quantizer is not None quantizer is not None
), "quantizer must be provided if dq_dtype is provided" ), "quantizer must be provided if dq_dtype is provided"
# Early-exit for non-quantized call
dq_dtype = dq_dtype or x.dtype dq_dtype = dq_dtype or x.dtype
if quantizer is None:
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive dbias = None
if not PrimitiveClass.enabled():
if is_dbias: if is_dbias:
return _jax_quantize_dbias( dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
x, if noop_scaled_tensor:
quantizer=quantizer, # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor()
dq_dtype=dq_dtype, # always work.
flatten_axis=flatten_axis, return (
ScaledTensorFactory.create_2x(
x,
None,
x,
None,
ScalingMode.NO_SCALING,
dq_dtype=x.dtype,
data_layout="NN",
flatten_axis=flatten_axis,
),
dbias,
) )
return ( return x, dbias
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
# TE/common doesn't support colwise only quantization yet # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: # fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
if is_dbias: if is_dbias:
return _jax_quantize_dbias( return _jax_quantize_dbias(
x, x,
...@@ -606,9 +615,8 @@ def _quantize_dbias_impl( ...@@ -606,9 +615,8 @@ def _quantize_dbias_impl(
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None, None,
) )
scale = jnp.empty((), jnp.float32)
# TE/common dbias_quantize does not support 1x on arch < 100 # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = _quantize_dbias_impl( out, _ = _quantize_dbias_impl(
x=x, x=x,
...@@ -620,29 +628,23 @@ def _quantize_dbias_impl( ...@@ -620,29 +628,23 @@ def _quantize_dbias_impl(
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias return out, dbias
if quantizer is None: scale = jnp.empty((), jnp.float32)
if is_dbias:
return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale. # Globally reduce amax across all devices for current scaling so we have a single global scale.
# This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
# until the tensor is dequantized (e.g. in the GEMM). # until the tensor is dequantized (e.g. in the GEMM).
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
scale = compute_scale_from_amax(amax, quantizer.q_dtype) scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale scale = quantizer.scale
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
# It is faster to use 1x quantization for tensor scaling # It is faster to use 1x quantization for tensor scaling
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
force_1x_quantization = ( force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling() quantizer.scaling_mode.is_tensor_scaling()
and quantizer.is_2x2x() and quantizer.is_2x2x()
and is_1x_kernel_supported and is_1x_kernel_supported
) )
q_layout = quantizer.q_layout q_layout = quantizer.q_layout
if force_1x_quantization: if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE q_layout = QuantizeLayout.ROWWISE
...@@ -698,6 +700,7 @@ def quantize( ...@@ -698,6 +700,7 @@ def quantize(
x: jnp.ndarray, x: jnp.ndarray,
quantizer: Quantizer, quantizer: Quantizer,
flatten_axis: int = -1, flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor]: ) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer. """Quantize input tensor according to the quantizer.
...@@ -707,6 +710,8 @@ def quantize( ...@@ -707,6 +710,8 @@ def quantize(
quantizer: Quantizer for FP8 quantization of the output. quantizer: Quantizer for FP8 quantization of the output.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1. Defaults to -1.
noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer
is None.
Returns: Returns:
A ScaledTensor containing the quantized input tensor. A ScaledTensor containing the quantized input tensor.
...@@ -715,6 +720,7 @@ def quantize( ...@@ -715,6 +720,7 @@ def quantize(
x, x,
quantizer=quantizer, quantizer=quantizer,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
) )
return out return out
...@@ -724,6 +730,7 @@ def quantize_dbias( ...@@ -724,6 +730,7 @@ def quantize_dbias(
quantizer: Quantizer, quantizer: Quantizer,
is_dbias: bool = True, is_dbias: bool = True,
flatten_axis: int = -1, flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient. """Quantize input tensor and compute bias gradient.
...@@ -734,6 +741,8 @@ def quantize_dbias( ...@@ -734,6 +741,8 @@ def quantize_dbias(
is_dbias: If True, compute bias gradient. Defaults to True. is_dbias: If True, compute bias gradient. Defaults to True.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1. Defaults to -1.
noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when
quantizer is None.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -743,7 +752,11 @@ def quantize_dbias( ...@@ -743,7 +752,11 @@ def quantize_dbias(
Shape: (K,) or empty if is_dbias is False. Shape: (K,) or empty if is_dbias is False.
""" """
return _quantize_dbias_impl( return _quantize_dbias_impl(
dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis dz,
quantizer=quantizer,
is_dbias=is_dbias,
flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
) )
......
...@@ -119,6 +119,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -119,6 +119,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right); int64_t window_size_right);
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
// Grouped GEMM // Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
......
...@@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { ...@@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
case xla::ffi::DataType::F8E4M3FN: case xla::ffi::DataType::F8E4M3FN:
return DType::kFloat8E4M3; return DType::kFloat8E4M3;
break; break;
// case xla::ffi::DataType::F8E8M0FNU: case xla::ffi::DataType::F8E8M0FNU:
// return DType::kFloat8E8M0; return DType::kFloat8E8M0;
// break; break;
default: default:
auto type_num = static_cast<XLA_FFI_DataType>(type); auto type_num = static_cast<XLA_FFI_DataType>(type);
if (type_num == 33) return DType::kFloat8E8M0;
NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
static_cast<int>(type_num)); static_cast<int>(type_num));
break; break;
......
...@@ -6,10 +6,14 @@ ...@@ -6,10 +6,14 @@
#include "transformer_engine/gemm.h" #include "transformer_engine/gemm.h"
#include <memory> #include <memory>
#include <string_view>
#include <tuple>
#include "../extensions.h" #include "../extensions.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#include "common/util/string.h"
#include "common/util/system.h" #include "common/util/system.h"
#include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
#define MXFP8_BLOCK_SIZE 32 #define MXFP8_BLOCK_SIZE 32
...@@ -17,6 +21,187 @@ ...@@ -17,6 +21,187 @@
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
// Move the pointer to the next 256B aligned address
return reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(ptr) + 255) &
~static_cast<uintptr_t>(255));
}
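The same round-up arithmetic written out in Python for clarity (a sketch of the bit trick, not part of the extension):

def sketch_align_up_256(addr: int) -> int:
    # Adding 255 first guarantees we never round down; masking the low 8 bits
    # then lands on the next multiple of 256.
    return (addr + 255) & ~255

assert sketch_align_up_256(0x1000) == 0x1000  # already aligned
assert sketch_align_up_256(0x1001) == 0x1100  # bumped to the next 256B boundary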
std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv,
JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
// Set tensor data with collapsed 2D shape
auto buffer_dims = buffer.dimensions();
std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
product(buffer_dims, axis_boundary, buffer_dims.size())};
auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type());
TensorWrapper input(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
} else {
input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
}
// Set scaling factor for quantized tensors
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands.");
NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
std::vector<size_t> scale_shape = {1};
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Block scaling also needs to be collapsed to match 2D data
scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary),
product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())};
}
auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
if (rowwise) {
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
// Swizzle scaling factors for MXFP8
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Get the swizzle buffer
NVTE_CHECK(swizzled_scale_inv->element_count() > 0,
"Missing swizzled inverse scale buffer in the JAX primitive.");
auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
auto swizzled_scale_inv_dtype =
convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type());
NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
// Create tensor to hold swizzled scale factor
TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
}
}
return std::make_tuple(std::move(input), input_shape);
}
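The collapse to a 2D GEMM shape above folds everything before axis_boundary into the first dimension and everything after it into the second; a small Python sketch of the same product:

import math

def sketch_collapse_2d(shape, axis_boundary):
    return (math.prod(shape[:axis_boundary]), math.prod(shape[axis_boundary:]))

# e.g. a (batch, seq, hidden) activation contracted over hidden:
assert sketch_collapse_2d((8, 512, 1024), axis_boundary=2) == (4096, 1024)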
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
// Operands (this includes swizzling MXFP8 scaling factors)
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
(is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
// Output tensor
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, "
"expected ",
out_.numel(), " elements ", to_string_like(out_shape), " but got ",
output->element_count(), " elements ", to_string_like(output->dimensions()));
// Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr;
std::vector<size_t> bias_shape = {0};
DType bias_dtype = out_dtype;
if (fuse_bias) {
if (!grad) {
NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad");
}
bias_ptr = bias_grad->untyped_data();
bias_shape.at(0) = bias_grad->dimensions().front();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type());
}
auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
// Pre-GeLU output from forward pass or input to backward pass
void *pre_gelu_ptr = nullptr;
std::vector<size_t> pre_gelu_shape = {0};
DType pre_gelu_dtype = out_dtype;
if (gelu_input.element_count() > 0) {
if (grad) {
NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out");
}
pre_gelu_ptr = pre_gelu_out->untyped_data();
pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1),
static_cast<size_t>(pre_gelu_out->dimensions().back())};
pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type());
}
auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype);
// cuBLAS workspace + 256 alignment enforcement
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
use_split_accumulator, num_math_sm, stream);
return ffi_with_cuda_error_check();
}
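In Python terms, the output shape selection above reads as follows (a sketch; lhs_shape and rhs_shape are the collapsed 2D shapes returned by xla_buffer_to_nvte_gemm_operand):

def sketch_out_shape(lhs_shape, rhs_shape, lhs_transposed, rhs_transposed):
    m = lhs_shape[1] if lhs_transposed else lhs_shape[0]
    n = rhs_shape[0] if rhs_transposed else rhs_shape[1]
    return (m, n)

assert sketch_out_shape((128, 256), (256, 512), False, False) == (128, 512)  # NN
assert sketch_out_shape((256, 128), (512, 256), True, True) == (128, 512)    # TT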
XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // lhs
.Arg<Buffer_Type>() // lhs_scale_inv
.Arg<Buffer_Type>() // rhs
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // lhs_swizzled
.Ret<Buffer_Type>() // rhs_swizzled
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
.Attr<int64_t>("rhs_axis_boundary")
.Attr<bool>("lhs_transposed")
.Attr<bool>("rhs_transposed")
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator"),
FFI_CudaGraph_Traits);
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
...@@ -54,15 +239,43 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type ...@@ -54,15 +239,43 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
NVTE_CHECK(group_sizes.dimensions().size() == 1); NVTE_CHECK(group_sizes.dimensions().size() == 1);
size_t num_gemms = group_sizes.dimensions()[0]; size_t num_gemms = group_sizes.dimensions()[0];
// It is weird that TE/Common GEMM only uses colwise for MXFP8
const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
// Outputs // Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data()); auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
// Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned
auto workspace_ptr = auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(workspace->untyped_data()) + 255) & workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
~static_cast<uintptr_t>(255)); auto workspace_total_size = product(workspace->dimensions());
auto workspace_total_size = product(workspace->dimensions()) - 255;
auto workspace_size = workspace_total_size / num_streams; auto lhs_sinv_size = product(lhs_sinv.dimensions());
auto rhs_sinv_size = product(rhs_sinv.dimensions());
const size_t workspace_alignment_padding = 256;
const size_t tensor_scaling_sinv_aligment = 16;
const size_t mxfp8_scaling_sinv_alignment_padding = 256;
auto workspace_size = workspace_total_size - workspace_alignment_padding;
if (is_mxfp8_scaling) {
// For MXFP8 swizzled scale_inv buffers, only the first pointer needs the 256B alignment padding.
// Later pointers are guaranteed to be 256B-aligned because the scale_inv shapes are padded to multiples of 128x4.
workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding);
} else if (is_tensor_scaling) {
// For tensor scaling, each matrix has a single scale value, and all scales need to be aligned
// by 16 bytes to meet the requirement of CUDA 12.9.1 and later.
workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size);
}
workspace_size = workspace_size / num_streams;
auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr);
auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr);
auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned
auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment;
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
...@@ -71,6 +284,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type ...@@ -71,6 +284,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
size_t out_dtype_bytes = te_dtype_bytes(out_dtype); size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
if (is_tensor_scaling) {
cudaStream_t stream_0 = nvte_get_compute_stream(0);
size_t dpitch = tensor_scaling_sinv_aligment;
size_t spitch = lhs_sinv_dtype_bytes;
size_t width = lhs_sinv_dtype_bytes;
size_t height = lhs_sinv_size;
cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height,
cudaMemcpyDeviceToDevice, stream_0);
spitch = rhs_sinv_dtype_bytes;
width = rhs_sinv_dtype_bytes;
height = rhs_sinv_size;
cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height,
cudaMemcpyDeviceToDevice, stream_0);
lhs_sinv_ptr = lhs_scatter_aligned_ptr;
rhs_sinv_ptr = rhs_scatter_aligned_ptr;
}
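A host-side sketch of what that 2D copy does: each fp32 scale value is scattered into its own 16-byte-aligned slot, matching the pitched-copy semantics of cudaMemcpy2DAsync (plain NumPy, for illustration only):

import numpy as np

def sketch_scatter_scales(scales, slot_bytes=16):
    # Equivalent of copying `height` rows of `width` bytes with dpitch=16 and spitch=4:
    # one 4-byte scale per row, landing at a 16-byte stride in the destination.
    out = np.zeros(scales.size * slot_bytes, dtype=np.uint8)
    src = scales.astype(np.float32).view(np.uint8).reshape(scales.size, 4)
    for i in range(scales.size):
        out[i * slot_bytes : i * slot_bytes + 4] = src[i]
    return out

assert sketch_scatter_scales(np.array([0.5, 2.0], dtype=np.float32)).size == 32  # two 16B slots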
NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
...@@ -120,12 +350,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type ...@@ -120,12 +350,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
auto bias_shape = std::vector<size_t>{has_bias ? n : 0}; auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
const int arch = cuda::sm_arch(); const int arch = cuda::sm_arch();
// It is weird that TE/Common GEMM only use colwise for MXFP8
const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
if (arch < 100 && is_fp8_gemm) { if (arch < 100 && is_fp8_gemm) {
NVTE_CHECK(!lhs_is_trans && rhs_is_trans, NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
"For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
...@@ -135,6 +359,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type ...@@ -135,6 +359,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// These lists are to keep the TensorWrapper objects alive // These lists are to keep the TensorWrapper objects alive
std::vector<TensorWrapper> lhs_wrapper_list; std::vector<TensorWrapper> lhs_wrapper_list;
std::vector<TensorWrapper> rhs_wrapper_list; std::vector<TensorWrapper> rhs_wrapper_list;
std::vector<TensorWrapper> lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling
std::vector<TensorWrapper> rhs_swizzle_wrapper_list;
std::vector<TensorWrapper> bias_wrapper_list; std::vector<TensorWrapper> bias_wrapper_list;
std::vector<TensorWrapper> pre_gelu_wrapper_list; std::vector<TensorWrapper> pre_gelu_wrapper_list;
std::vector<TensorWrapper> out_wrapper_list; std::vector<TensorWrapper> out_wrapper_list;
...@@ -143,66 +369,119 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type ...@@ -143,66 +369,119 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// These lists are the actual NVTETensor (void *) lists for multi-stream GEMM // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM
std::vector<NVTETensor> lhs_list; std::vector<NVTETensor> lhs_list;
std::vector<NVTETensor> rhs_list; std::vector<NVTETensor> rhs_list;
std::vector<NVTETensor> lhs_swizzle_list;
std::vector<NVTETensor> rhs_swizzle_list;
std::vector<NVTETensor> bias_list; std::vector<NVTETensor> bias_list;
std::vector<NVTETensor> pre_gelu_list; std::vector<NVTETensor> pre_gelu_list;
std::vector<NVTETensor> out_list; std::vector<NVTETensor> out_list;
std::vector<NVTETensor> workspace_list; std::vector<NVTETensor> workspace_list;
size_t lhs_sinv_total_size = 0;
size_t rhs_sinv_total_size = 0;
std::vector<void *> zero_out_dptr_list;
std::vector<size_t> zero_out_size_list;
for (size_t i = 0; i < num_gemms; i++) { for (size_t i = 0; i < num_gemms; i++) {
// Matrix data shapes // Matrix data shapes
size_t m_i = dim_list_host[i]; size_t m_i = dim_list_host[i];
auto lhs_shape = std::vector<size_t>{m_i, k}; auto lhs_shape_i = std::vector<size_t>{m_i, k};
auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; auto rhs_shape_i = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
auto out_shape = std::vector<size_t>{m_i, n}; auto out_shape_i = std::vector<size_t>{m_i, n};
if (is_grouped_dense_wgrad) { if (is_grouped_dense_wgrad) {
size_t k_i = dim_list_host[i]; size_t k_i = dim_list_host[i];
lhs_shape[0] = lhs_is_trans ? k_i : m; lhs_shape_i[0] = lhs_is_trans ? k_i : m;
lhs_shape[1] = lhs_is_trans ? m : k_i; lhs_shape_i[1] = lhs_is_trans ? m : k_i;
rhs_shape[0] = rhs_is_trans ? n : k_i; rhs_shape_i[0] = rhs_is_trans ? n : k_i;
rhs_shape[1] = rhs_is_trans ? k_i : n; rhs_shape_i[1] = rhs_is_trans ? k_i : n;
out_shape[0] = m; out_shape_i[0] = m;
out_shape[1] = n; out_shape_i[1] = n;
}
size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1];
size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1];
size_t out_size = out_shape_i[0] * out_shape_i[1];
bool is_empty_gemm = lhs_size == 0 || rhs_size == 0;
if (is_empty_gemm && out_size > 0) {
zero_out_dptr_list.push_back(out_ptr);
zero_out_size_list.push_back(out_size * out_dtype_bytes);
} }
// Set matrix data pointers // Set matrix data pointers
auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype); auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape_i, out_dtype);
void *lhs_vptr = static_cast<void *>(lhs_ptr); void *lhs_vptr = static_cast<void *>(lhs_ptr);
void *rhs_vptr = static_cast<void *>(rhs_ptr); void *rhs_vptr = static_cast<void *>(rhs_ptr);
if (rhs_use_colwise) // MatA to enter cuBLAS if (rhs_use_colwise) // MatA to enter cuBLAS
rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape); rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
else else
rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape); rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
if (lhs_use_colwise) // MatB to enter cuBLAS if (lhs_use_colwise) // MatB to enter cuBLAS
lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape); lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
else else
lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape); lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
// Scale_inv shapes
auto lhs_sinv_size = std::vector<size_t>{1};
auto rhs_sinv_size = std::vector<size_t>{1};
if (is_mxfp8_scaling) {
NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)",
MXFP8_BLOCK_SIZE, k);
size_t scale_k = k / MXFP8_BLOCK_SIZE;
lhs_sinv_size[0] = m_i * scale_k;
rhs_sinv_size[0] = n * scale_k;
// Need to add swizzle here
}
// Set scale_inv pointers // Set scale_inv shapes and pointers
void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr); void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr); void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
if (is_fp8_gemm) { size_t lhs_sinv_size_i = 0;
size_t rhs_sinv_size_i = 0;
if (is_tensor_scaling) {
auto tensor_scaling_sinv_shape = std::vector<size_t>{1};
// If is_empty_gemm, scale_inv has no corresponding value for this group, so do not advance the pointers
if (!is_empty_gemm) {
lhs_sinv_size_i = tensor_scaling_sinv_aligment / lhs_sinv_dtype_bytes;
rhs_sinv_size_i = tensor_scaling_sinv_aligment / rhs_sinv_dtype_bytes;
}
if (rhs_use_colwise) // MatA to enter cuBLAS if (rhs_use_colwise) // MatA to enter cuBLAS
rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape);
else else
rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size); rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape);
if (lhs_use_colwise) // MatB to enter cuBLAS if (lhs_use_colwise) // MatB to enter cuBLAS
lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape);
else else
lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size); lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape);
} else if (is_mxfp8_scaling) {
auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
void *swizzled_lhs_sinv_vptr = static_cast<void *>(swizzled_lhs_sinv_ptr);
void *swizzled_rhs_sinv_vptr = static_cast<void *>(swizzled_rhs_sinv_ptr);
// {lhs, rhs}_swizzle_i point to the unswizzled scale_inv data used as input, while {lhs, rhs}_i
// point to the swizzled scale_inv data (stored in the workspace, used only by the GEMM).
// Note: even if is_empty_gemm is true, the scale_inv buffers are non-empty, so still advance the pointers
auto lhs_sinv_shape_i =
get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
auto rhs_sinv_shape_i =
get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1];
rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1];
if (lhs_use_colwise) {
lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
} else {
lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
}
if (rhs_use_colwise) {
rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
} else {
rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
}
if (!is_empty_gemm) {
lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i));
rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i));
lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data());
rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data());
}
} else { } else {
NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode)); "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
...@@ -212,16 +491,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type ...@@ -212,16 +491,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype); auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);
// Update pointer for the next GEMM pair // Update pointer for the next GEMM pair
lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes; lhs_ptr += lhs_size * lhs_dtype_bytes;
rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes; rhs_ptr += rhs_size * rhs_dtype_bytes;
out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes; out_ptr += out_size * out_dtype_bytes;
if (is_fp8_gemm) { if (is_fp8_gemm) {
lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes; lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes; rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes;
lhs_sinv_total_size += lhs_sinv_size_i;
rhs_sinv_total_size += rhs_sinv_size_i;
if (is_mxfp8_scaling) {
swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes;
swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes;
}
} }
if (has_bias) bias_ptr += n * bias_dtype_bytes; if (has_bias) bias_ptr += n * bias_dtype_bytes;
// Move objects to the lists to keep them alive // Move objects to the lists to keep them alive
if (is_empty_gemm) continue;
lhs_wrapper_list.push_back(std::move(lhs_i)); lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i)); rhs_wrapper_list.push_back(std::move(rhs_i));
out_wrapper_list.push_back(std::move(out_i)); out_wrapper_list.push_back(std::move(out_i));
...@@ -244,10 +530,45 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type ...@@ -244,10 +530,45 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
workspace_ptr += workspace_size; workspace_ptr += workspace_size;
} }
if (is_fp8_gemm) {
if (is_tensor_scaling) {
lhs_sinv_size *= tensor_scaling_sinv_aligment;
rhs_sinv_size *= tensor_scaling_sinv_aligment;
}
NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ",
lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size);
NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ",
rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size);
}
size_t num_non_empty_gemms = lhs_list.size();
if (is_mxfp8_scaling) {
for (int i = 0; i < num_non_empty_gemms; i++) {
// The i-th GEMM runs on the (i % num_streams)-th compute stream, so swizzle its
// scaling factors on the same stream; this guarantees the swizzling finishes
// before the GEMM computation starts.
int stream_id = i % num_streams;
cudaStream_t stream_i = nvte_get_compute_stream(stream_id);
nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i);
nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i);
}
}
// Launch zero-out memsets before the GEMM calls so they are covered by the stream synchronization already performed by the multi-stream GEMM
size_t num_zero_outs = zero_out_dptr_list.size();
for (int i = 0; i < num_zero_outs; i++) {
int stream_id = i % num_streams;
cudaStream_t stream_i = nvte_get_compute_stream(stream_id);
void *dptr = zero_out_dptr_list[i];
size_t count = zero_out_size_list[i];
NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i));
}
nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad, pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans,
workspace_list.data(), accumulate, use_split_accumulator, lhs_is_trans, grad, workspace_list.data(), accumulate,
num_math_sm, stream); use_split_accumulator, num_math_sm, stream);
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
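For reference, a minimal sketch of the per-group MXFP8 scale_inv sizing that get_mxfp8_scale_shape is assumed to perform above: one FP8E8M0 scale factor per 32-element block along the scaled dimension, with the scale matrix padded so it fits the swizzled layout consumed by cuBLAS block-scaled GEMM. The helper name sketch_mxfp8_scale_shape, the 128/4 padding constants, and the block size of 32 are illustrative assumptions, not the library's definitions.

#include <cstddef>
#include <vector>

// Assumed granularities: one E8M0 scale per 32-element block, scale matrix
// padded to the tile size expected by the swizzled cuBLAS layout.
constexpr size_t kMxfp8BlockSize = 32;
constexpr size_t kScalePadOuter = 128;  // padding of the non-blocked dimension
constexpr size_t kScalePadInner = 4;    // padding of the blocked (scale) dimension

inline size_t round_up(size_t x, size_t multiple) {
  return (x + multiple - 1) / multiple * multiple;
}

// Mirrors what get_mxfp8_scale_shape(rows, cols, use_colwise) is assumed to
// compute: row-wise usage blocks the last dimension, column-wise usage blocks
// the first dimension, so the two results are mirrored.
inline std::vector<size_t> sketch_mxfp8_scale_shape(size_t rows, size_t cols,
                                                    bool use_colwise) {
  if (use_colwise) {
    return {round_up(rows / kMxfp8BlockSize, kScalePadInner),
            round_up(cols, kScalePadOuter)};
  }
  return {round_up(rows, kScalePadOuter),
          round_up(cols / kMxfp8BlockSize, kScalePadInner)};
}

Under these assumptions, a row-wise group with m_i = 96 and k = 4096 needs round_up(96, 128) * round_up(4096 / 32, 4) = 128 * 128 = 16384 E8M0 scales; the per-group products are exactly what lhs_sinv_total_size and rhs_sinv_total_size accumulate for the upper-bound checks against the allocated buffers.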
......
...@@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t { ...@@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t {
CURRENT_TENSOR_SCALING = 3, CURRENT_TENSOR_SCALING = 3,
}; };
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING ||
mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING);
}
inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING);
}
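A minimal usage sketch for the two helpers added above; the callers below are hypothetical and only illustrate the intended dispatch: tensor scaling keeps a single FP32 scale_inv per tensor, while block (MXFP8) scaling carries per-block E8M0 scales that the grouped GEMM path must swizzle before calling cuBLAS.

// Hypothetical callers, not part of the patch.
inline bool has_scale_inv(JAXX_Scaling_Mode mode) {
  return is_tensor_scaling(mode) || is_block_scaling(mode);
}

inline bool needs_swizzled_scales(JAXX_Scaling_Mode mode) {
  // Only block scaling stores per-block scale factors that cuBLAS expects in
  // the swizzled layout; tensor scaling uses one FP32 scale_inv per tensor.
  return is_block_scaling(mode);
}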
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
switch (mode) { switch (mode) {
case JAXX_Scaling_Mode::NO_SCALING: case JAXX_Scaling_Mode::NO_SCALING:
......
...@@ -55,6 +55,11 @@ pybind11::dict Registrations() { ...@@ -55,6 +55,11 @@ pybind11::dict Registrations() {
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler)); pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler));
// GEMM
dict["te_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM // Grouped GEMM
dict["te_grouped_gemm_ffi"] = dict["te_grouped_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
...@@ -78,6 +83,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -78,6 +83,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
pybind11::enum_<DType>(m, "DType", pybind11::module_local()) pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte) .value("kByte", DType::kByte)
......
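The is_non_nt_fp8_gemm_supported binding above exposes nvte_is_non_tn_fp8_gemm_supported to Python. A hedged C++ sketch of the kind of decision it enables; the helper below is illustrative and assumes the function returns a truthy value when FP8 GEMM layouts other than TN are usable on the current device.

// Sketch only; assumes the Transformer Engine header declaring
// nvte_is_non_tn_fp8_gemm_supported() has been included.
inline bool need_columnwise_fp8_copy() {
  // Where only TN FP8 GEMMs are supported, keep a transposed (column-wise) copy
  // of one operand so every GEMM can still be expressed as TN.
  return !nvte_is_non_tn_fp8_gemm_supported();
}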