Unverified Commit cf9a7c2f authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Refactor + MXFP8 + GroupedGEMM (#1627)



* refactor + mxfp8

* added grouped gemm

* rename linear to dense

* added cublas init phase for groupedGemm

* relax the tol of test encoder multiprocessing mxfp8 by 0.001
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent be055eb0
...@@ -301,39 +301,6 @@ static void FusedAttnForwardImpl( ...@@ -301,39 +301,6 @@ static void FusedAttnForwardImpl(
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
} }
// XLA custom-call entry point for fused attention forward.
// Decodes the opaque descriptor, maps the XLA buffer table to named tensors
// (the positional layout is fixed by the Python-side lowering), and delegates
// to FusedAttnForwardImpl.
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  // Ragged (THD) layouts carry two extra per-sequence offset buffers.
  const bool is_ragged = nvte_get_qkv_format(desc.qkv_layout) == NVTE_QKV_Format::NVTE_THD;

  /* Input buffers handed over by XLA, in fixed positional order */
  void *q = buffers[0], *k = buffers[1], *v = buffers[2];
  void *bias = buffers[3], *seed = buffers[4];
  void *q_cu_seqlens = buffers[5], *kv_cu_seqlens = buffers[6];
  void *q_seq_offsets = is_ragged ? buffers[7] : nullptr;
  void *k_seq_offsets = is_ragged ? buffers[8] : nullptr;

  /* Output buffers XLA expects this call to populate */
  void *output = buffers[9], *softmax_aux = buffers[10];
  void *rng_state = buffers[11], *workspace = buffers[12];

  FusedAttnForwardImpl(
      stream, q, k, v, bias, seed, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets,
      output, softmax_aux, rng_state, workspace, desc.input_batch, desc.bias_batch,
      desc.q_max_seqlen, desc.kv_max_seqlen, desc.attn_heads, desc.num_gqa_groups,
      desc.bias_heads, desc.head_dim, desc.max_segments_per_seq, desc.wkspace_size,
      desc.scaling_factor, desc.dropout_probability, desc.bias_type, desc.mask_type,
      desc.qkv_layout, desc.dtype, desc.wkspace_dtype, desc.is_training, desc.deterministic,
      desc.window_size_left, desc.window_size_right);
}
#define FUSED_ATTN_FFI_GET_ATTRS \ #define FUSED_ATTN_FFI_GET_ATTRS \
size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch"); \ size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch"); \
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch"); \ size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch"); \
...@@ -608,45 +575,6 @@ static void FusedAttnBackwardImpl( ...@@ -608,45 +575,6 @@ static void FusedAttnBackwardImpl(
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
} }
// XLA custom-call entry point for fused attention backward.
// Decodes the opaque descriptor, maps the XLA buffer table to named tensors
// (the positional layout is fixed by the Python-side lowering), and delegates
// to FusedAttnBackwardImpl.
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  const auto &desc = *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  // Ragged (THD) layouts carry two extra per-sequence offset buffers.
  const bool is_ragged = nvte_get_qkv_format(desc.qkv_layout) == NVTE_QKV_Format::NVTE_THD;

  /* Input buffers handed over by XLA, in fixed positional order */
  void *q = buffers[0], *k = buffers[1], *v = buffers[2];
  void *bias = buffers[3];
  void *softmax_aux = buffers[4], *rng_state = buffers[5];
  void *output = buffers[6], *doutput = buffers[7];
  void *q_cu_seqlens = buffers[8], *kv_cu_seqlens = buffers[9];
  void *q_seq_offsets = is_ragged ? buffers[10] : nullptr;
  void *k_seq_offsets = is_ragged ? buffers[11] : nullptr;

  /* Output buffers XLA expects this call to populate */
  void *dq = buffers[12], *dk = buffers[13], *dv = buffers[14];
  void *dbias = buffers[15], *workspace = buffers[16];

  FusedAttnBackwardImpl(
      stream, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlens, kv_cu_seqlens,
      q_seq_offsets, k_seq_offsets, dq, dk, dv, dbias, workspace, desc.input_batch,
      desc.bias_batch, desc.q_max_seqlen, desc.kv_max_seqlen, desc.attn_heads,
      desc.num_gqa_groups, desc.bias_heads, desc.head_dim, desc.max_segments_per_seq,
      desc.wkspace_size, desc.scaling_factor, desc.dropout_probability, desc.bias_type,
      desc.mask_type, desc.qkv_layout, desc.dtype, desc.wkspace_dtype, desc.is_training,
      desc.deterministic, desc.window_size_left, desc.window_size_right);
}
Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type v_buf, Buffer_Type bias_buf,
Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf, Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "transformer_engine/gemm.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
// FFI handler that eagerly initializes the shared cuBLAS handle so that
// handle creation cost is not paid inside the first GEMM custom call.
// The variadic args/results and attribute dictionary are accepted but unused;
// this lets the handler be bound regardless of the call-site signature.
Error_Type CublasHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
                               Dictionary attrs) {
  // nvte_cublas_handle_init() performs the actual handle creation.
  nvte_cublas_handle_init();
  // Surface any pending CUDA error back to XLA as an FFI error.
  return ffi_with_cuda_error_check();
}

// Bound with FFI_Prepare, i.e. registered to run during XLA's prepare stage
// (before execution) — NOTE(review): confirmed only by the FFI_Prepare tag here;
// verify stage semantics against the XLA FFI docs if changing this.
XLA_FFI_DEFINE_HANDLER_SYMBOL(CublasHandleInitHandler, CublasHandleInitFFI,
                              FFI::Bind<FFI_Prepare>().RemainingArgs().RemainingRets().Attrs());
} // namespace jax
} // namespace transformer_engine
...@@ -13,8 +13,9 @@ namespace jax { ...@@ -13,8 +13,9 @@ namespace jax {
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186 // For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
switch (type) { switch (type) {
// Using this for E8M0
case xla::ffi::DataType::U8: case xla::ffi::DataType::U8:
return DType::kByte; return DType::kFloat8E8M0;
break; break;
case xla::ffi::DataType::S32: case xla::ffi::DataType::S32:
return DType::kInt32; return DType::kInt32;
...@@ -37,8 +38,12 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { ...@@ -37,8 +38,12 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
case xla::ffi::DataType::F8E4M3FN: case xla::ffi::DataType::F8E4M3FN:
return DType::kFloat8E4M3; return DType::kFloat8E4M3;
break; break;
// case xla::ffi::DataType::F8E8M0FNU:
// return DType::kFloat8E8M0;
// break;
default: default:
auto type_num = static_cast<XLA_FFI_DataType>(type); auto type_num = static_cast<XLA_FFI_DataType>(type);
if (type_num == 33) return DType::kFloat8E8M0;
NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
static_cast<int>(type_num)); static_cast<int>(type_num));
break; break;
......
...@@ -81,5 +81,30 @@ inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_id ...@@ -81,5 +81,30 @@ inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_id
std::multiplies<size_t>()); std::multiplies<size_t>());
} }
// Returns the size in bytes of one element of the given TE dtype.
// Cases are grouped by element width; any dtype not listed here aborts
// via NVTE_ERROR.
inline static size_t te_dtype_bytes(const DType &type) {
  switch (type) {
    // 1-byte types: raw bytes and the FP8 variants (including E8M0 scales).
    case DType::kByte:
    case DType::kFloat8E5M2:
    case DType::kFloat8E4M3:
    case DType::kFloat8E8M0:
      return 1;
    // 2-byte half-precision floats.
    case DType::kFloat16:
    case DType::kBFloat16:
      return 2;
    // 4-byte types.
    case DType::kInt32:
    case DType::kFloat32:
      return 4;
    // 8-byte types.
    case DType::kInt64:
      return 8;
    default:
      NVTE_ERROR("Unsupported DType: ", static_cast<int>(type));
  }
}
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
This diff is collapsed.
...@@ -34,5 +34,11 @@ inline size_t product(const std::vector<size_t> &shape) { ...@@ -34,5 +34,11 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret; return ret;
} }
// Selects which axis (or axes) quantization scales are produced along.
// NOTE(review): presumably consumed by the MXFP8/block-scaling paths added in
// this change — confirm the exact row/column semantics at the call sites.
enum class QuantizeAxis {
  ROWWISE,          // scales along rows only
  COLWISE,          // scales along columns only
  ROWWISE_COLWISE,  // both row-wise and column-wise scales
};
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Transformer Engine bindings for JAX""" """Transformer Engine bindings for JAX"""
from .module import DenseGeneral, LayerNorm from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase from .module import LayerNormDenseGeneral, LayerNormMLP
from .transformer import extend_logical_axis_rules from .transformer import extend_logical_axis_rules
from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType from .transformer import TransformerLayer, TransformerLayerType
...@@ -13,7 +13,6 @@ __all__ = [ ...@@ -13,7 +13,6 @@ __all__ = [
"LayerNorm", "LayerNorm",
"LayerNormDenseGeneral", "LayerNormDenseGeneral",
"LayerNormMLP", "LayerNormMLP",
"TransformerEngineBase",
"extend_logical_axis_rules", "extend_logical_axis_rules",
"DotProductAttention", "DotProductAttention",
"MultiHeadAttention", "MultiHeadAttention",
......
This diff is collapsed.
...@@ -638,7 +638,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -638,7 +638,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
else: else:
assert qkv_layout.is_separate() assert qkv_layout.is_separate()
assert sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray) assert sequence_descriptor is None or isinstance(
sequence_descriptor, (jnp.ndarray, np.ndarray)
)
x = _UnfusedDotProductAttention( x = _UnfusedDotProductAttention(
attention_dropout=self.attention_dropout, attention_dropout=self.attention_dropout,
...@@ -928,7 +930,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -928,7 +930,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation. The data type used to allocate the initial parameters.
fuse_qkv_params: bool, default = True fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for parameter for query-key-value for self-attention and key-value for
...@@ -1788,6 +1790,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1788,6 +1790,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray outputs: jax.numpy.ndarray
Output tensors. Output tensors.
""" """
input_dtype = inputs.dtype input_dtype = inputs.dtype
assert ( assert (
self.layer_type in TransformerLayerType self.layer_type in TransformerLayerType
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment