Unverified Commit cf9a7c2f authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Refactor + MXFP8 + GroupedGEMM (#1627)



* refactor + mxfp8

* added grouped gemm

* rename linear to dense

* added cublas init phase for groupedGemm

* relax the tol of test encoder multiprocessing mxfp8 by 0.001
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent be055eb0
......@@ -301,39 +301,6 @@ static void FusedAttnForwardImpl(
nvte_tensor_pack_destroy(&aux_output_tensors);
}
// XLA custom-call entry point for the fused-attention forward pass.
// Decodes the serialized descriptor from `opaque`, maps the XLA-provided
// buffer array onto named pointers, and dispatches to FusedAttnForwardImpl
// on `stream`. Buffer positions are fixed by the registration on the Python
// side; the ragged-offset slots (7, 8) are only meaningful for THD layouts.
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
  const bool is_ragged = nvte_get_qkv_format(desc.qkv_layout) == NVTE_QKV_Format::NVTE_THD;

  /* Input buffers from XLA */
  void *query = buffers[0];
  void *key = buffers[1];
  void *value = buffers[2];
  void *bias = buffers[3];
  void *seed = buffers[4];
  void *q_cu_seqlens = buffers[5];
  void *kv_cu_seqlens = buffers[6];
  // Sequence offsets exist only for ragged (THD) input; otherwise the slots
  // are not populated with usable data, so pass nullptr downstream.
  void *q_seq_offsets = is_ragged ? buffers[7] : nullptr;
  void *k_seq_offsets = is_ragged ? buffers[8] : nullptr;

  /* Output buffers from XLA */
  void *output = buffers[9];
  void *softmax_aux = buffers[10];
  void *rng_state = buffers[11];
  void *workspace = buffers[12];

  FusedAttnForwardImpl(stream, query, key, value, bias, seed, q_cu_seqlens, kv_cu_seqlens,
                       q_seq_offsets, k_seq_offsets, output, softmax_aux, rng_state, workspace,
                       desc.input_batch, desc.bias_batch, desc.q_max_seqlen, desc.kv_max_seqlen,
                       desc.attn_heads, desc.num_gqa_groups, desc.bias_heads, desc.head_dim,
                       desc.max_segments_per_seq, desc.wkspace_size, desc.scaling_factor,
                       desc.dropout_probability, desc.bias_type, desc.mask_type, desc.qkv_layout,
                       desc.dtype, desc.wkspace_dtype, desc.is_training, desc.deterministic,
                       desc.window_size_left, desc.window_size_right);
}
#define FUSED_ATTN_FFI_GET_ATTRS \
size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch"); \
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch"); \
......@@ -608,45 +575,6 @@ static void FusedAttnBackwardImpl(
nvte_tensor_pack_destroy(&aux_input_tensors);
}
// XLA custom-call entry point for the fused-attention backward pass.
// Mirrors FusedAttnForward: decode descriptor, name the XLA buffer slots,
// and forward everything to FusedAttnBackwardImpl on `stream`. The
// ragged-offset slots (10, 11) are only meaningful for THD layouts.
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
  const auto layout = desc.qkv_layout;
  const bool is_ragged = nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD;

  /* Input buffers from XLA */
  void *query = buffers[0];
  void *key = buffers[1];
  void *value = buffers[2];
  void *bias = buffers[3];
  void *softmax_aux = buffers[4];
  void *rng_state = buffers[5];
  void *output = buffers[6];
  void *doutput = buffers[7];
  void *q_cu_seqlens = buffers[8];
  void *kv_cu_seqlens = buffers[9];
  // Offsets are only populated for ragged (THD) input; otherwise pass nullptr.
  void *q_seq_offsets = is_ragged ? buffers[10] : nullptr;
  void *k_seq_offsets = is_ragged ? buffers[11] : nullptr;

  /* Output buffers from XLA */
  void *dq = buffers[12];
  void *dk = buffers[13];
  void *dv = buffers[14];
  void *dbias = buffers[15];
  void *workspace = buffers[16];

  FusedAttnBackwardImpl(stream, query, key, value, bias, softmax_aux, rng_state, output, doutput,
                        q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, dq, dk, dv,
                        dbias, workspace, desc.input_batch, desc.bias_batch, desc.q_max_seqlen,
                        desc.kv_max_seqlen, desc.attn_heads, desc.num_gqa_groups, desc.bias_heads,
                        desc.head_dim, desc.max_segments_per_seq, desc.wkspace_size,
                        desc.scaling_factor, desc.dropout_probability, desc.bias_type,
                        desc.mask_type, desc.qkv_layout, desc.dtype, desc.wkspace_dtype,
                        desc.is_training, desc.deterministic, desc.window_size_left,
                        desc.window_size_right);
}
Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
Buffer_Type v_buf, Buffer_Type bias_buf,
Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "transformer_engine/gemm.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
// FFI handler that eagerly initializes the cuBLAS handle so the (slow) handle
// creation happens ahead of time rather than inside the first GEMM call.
// All FFI arguments/results/attributes are accepted but unused.
Error_Type CublasHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
                               Dictionary attrs) {
  nvte_cublas_handle_init();
  // Surface any pending CUDA error back to XLA as an FFI error.
  return ffi_with_cuda_error_check();
}

// Bound to the FFI "prepare" stage (FFI_Prepare), so initialization runs
// before execution; RemainingArgs/RemainingRets/Attrs make the binding accept
// any call signature.
XLA_FFI_DEFINE_HANDLER_SYMBOL(CublasHandleInitHandler, CublasHandleInitFFI,
                              FFI::Bind<FFI_Prepare>().RemainingArgs().RemainingRets().Attrs());
} // namespace jax
} // namespace transformer_engine
......@@ -13,8 +13,9 @@ namespace jax {
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
switch (type) {
// Using this for E8M0
case xla::ffi::DataType::U8:
return DType::kByte;
return DType::kFloat8E8M0;
break;
case xla::ffi::DataType::S32:
return DType::kInt32;
......@@ -37,8 +38,12 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
case xla::ffi::DataType::F8E4M3FN:
return DType::kFloat8E4M3;
break;
// case xla::ffi::DataType::F8E8M0FNU:
// return DType::kFloat8E8M0;
// break;
default:
auto type_num = static_cast<XLA_FFI_DataType>(type);
if (type_num == 33) return DType::kFloat8E8M0;
NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
static_cast<int>(type_num));
break;
......
......@@ -81,5 +81,30 @@ inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_id
std::multiplies<size_t>());
}
// Returns the per-element size in bytes of a Transformer Engine dtype.
// Raises an NVTE error for dtypes without a defined byte width here.
inline static size_t te_dtype_bytes(const DType& type) {
  switch (type) {
    // 1-byte types: raw bytes and all FP8/E8M0 formats.
    case DType::kByte:
    case DType::kFloat8E5M2:
    case DType::kFloat8E4M3:
    case DType::kFloat8E8M0:
      return 1;
    // 2-byte types: half-precision floats.
    case DType::kFloat16:
    case DType::kBFloat16:
      return 2;
    // 4-byte types.
    case DType::kInt32:
    case DType::kFloat32:
      return 4;
    // 8-byte types.
    case DType::kInt64:
      return 8;
    default:
      NVTE_ERROR("Unsupported DType: ", static_cast<int>(type));
  }
}
} // namespace jax
} // namespace transformer_engine
This diff is collapsed.
......@@ -34,5 +34,11 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
// Selects which quantized representation(s) a kernel should produce.
// NOTE(review): semantics inferred from the enumerator names — presumably
// row-wise scaling, column-wise scaling, or both; confirm against the
// quantization call sites before relying on this description.
enum class QuantizeAxis {
  ROWWISE,          // row-wise quantization only
  COLWISE,          // column-wise quantization only
  ROWWISE_COLWISE,  // both row-wise and column-wise outputs
};
} // namespace jax
} // namespace transformer_engine
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .module import LayerNormDenseGeneral, LayerNormMLP
from .transformer import extend_logical_axis_rules
from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
......@@ -13,7 +13,6 @@ __all__ = [
"LayerNorm",
"LayerNormDenseGeneral",
"LayerNormMLP",
"TransformerEngineBase",
"extend_logical_axis_rules",
"DotProductAttention",
"MultiHeadAttention",
......
This diff is collapsed.
......@@ -638,7 +638,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
else:
assert qkv_layout.is_separate()
assert sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray)
assert sequence_descriptor is None or isinstance(
sequence_descriptor, (jnp.ndarray, np.ndarray)
)
x = _UnfusedDotProductAttention(
attention_dropout=self.attention_dropout,
......@@ -928,7 +930,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
The data type used to allocate the initial parameters.
fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
......@@ -1788,6 +1790,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray
Output tensors.
"""
input_dtype = inputs.dtype
assert (
self.layer_type in TransformerLayerType
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment