Unverified Commit 237b4930 authored by Hua Huang's avatar Hua Huang Committed by GitHub
Browse files

[TE/JAX] XLA FFI calls for Softmax and FusedAttnBackward (#1319)



* FFI for all softmax functions
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* FFI for FusedAttnBackward and Dequantize

FusedAttnBackward passed all tests in test_fused_attn.py.
Dequantize is not used currently; finish it for completeness.
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* Fix FusedAttnBackward FFI pybind & simplify
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert changes to tests/jax/test_fused_attn.py
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

---------
Signed-off-by: default avatarHua Huang <huah@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 68adf451
......@@ -380,8 +380,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
......@@ -433,6 +431,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
......@@ -725,6 +725,56 @@ class FusedAttnBwdPrimitive(BasePrimitive):
"""
Fused attention bwd lowering rules
"""
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled():
name = "te_fused_attn_backward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
......@@ -744,23 +794,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
......
......@@ -12,12 +12,13 @@ import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType
......@@ -133,6 +134,10 @@ class SoftmaxPrimitive(BasePrimitive):
"""
softmax_forward lowering rules
"""
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
else:
(i_aval,) = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
......@@ -240,6 +245,10 @@ class SoftmaxPrimitive(BasePrimitive):
"""
softmax_backward lowering rules
"""
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
else:
dz_aval, _ = ctx.avals_in
dz_type = ir.RankedTensorType(dz.type)
......@@ -577,7 +586,10 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""
te_scaled_masked_softmax_forward lowering rules
"""
if is_ffi_enabled():
ffi_name = "te_scaled_masked_softmax_forward_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
else:
logits_aval, _ = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
......
......@@ -238,6 +238,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
// Softmax
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
......@@ -258,8 +260,23 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Attention
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
......@@ -289,8 +306,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
} // namespace jax
} // namespace transformer_engine
......
......@@ -68,6 +68,19 @@ pybind11::dict Registrations() {
// Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
// Softmax
dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler);
dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler);
dict["te_scaled_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler);
dict["te_scaled_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler);
dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler);
dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
......@@ -83,6 +96,11 @@ pybind11::dict Registrations() {
fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler);
dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi;
pybind11::dict fused_attn_backward_ffi;
fused_attn_backward_ffi["prepare"] = EncapsulateFFI(CudnnHandleInitHandler);
fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler);
dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi;
return dict;
}
......
......@@ -74,11 +74,41 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t
auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
}
// FFI entry point that dequantizes `input_buf` into `output_buf` via
// nvte_fp8_dequantize. The output tensor reuses the input's dimensions.
//
// NOTE(review): amax/scale/scale_inv are reinterpreted as float buffers, as
// bound in DequantizeHandler below — confirm the Python caller passes f32.
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                         Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) {
  // Input and output share the same logical shape.
  auto dims = input_buf.dimensions();
  std::vector<size_t> shape(dims.begin(), dims.end());

  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto input_tensor = TensorWrapper(input_buf.untyped_data(), shape, in_dtype,
                                    reinterpret_cast<float *>(amax_buf.untyped_data()),
                                    reinterpret_cast<float *>(scale_buf.untyped_data()),
                                    reinterpret_cast<float *>(scale_inv_buf.untyped_data()));
  auto output_tensor = TensorWrapper(output_buf->untyped_data(), shape, out_dtype);

  nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
  return ffi_with_cuda_error_check();
}
// Binds DequantizeFFI for XLA: one stream context, four input buffers, one
// output buffer. CUDA-graph traits allow capture into graphs.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>(),     // output
                              FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
......@@ -7,6 +7,7 @@
#include "transformer_engine/softmax.h"
#include "extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
......@@ -108,5 +109,146 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
dgrad_tensor.data(), desc.scale_factor, stream);
}
// Shared prologue for every softmax FFI handler. Derives the TE dtype from
// `tensor_buf` and collapses its dimensions: all leading axes fold into
// `batch_size`, and the last three axes become `head_dim` / `q_seqlen` /
// `k_seqlen` (the rank-3 axis is presumably the attention-heads axis — the
// name `head_dim` is historical). Also narrows the `scale_factor_` double
// FFI attribute (expected in scope) to the float the NVTE kernels take.
// Declares: dtype, tensor_dims, tensor_ranks, batch_size, head_dim,
// q_seqlen, k_seqlen, scale_factor.
#define SOFTMAX_COMMON_BLOCK(tensor_buf) \
auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type()); \
auto tensor_dims = (tensor_buf).dimensions(); \
auto tensor_ranks = tensor_dims.size(); \
auto batch_size = product(tensor_dims, 0, tensor_ranks - 3); \
auto head_dim = product(tensor_dims, tensor_ranks - 3, tensor_ranks - 2); \
auto q_seqlen = product(tensor_dims, tensor_ranks - 2, tensor_ranks - 1); \
auto k_seqlen = product(tensor_dims, tensor_ranks - 1, tensor_ranks); \
float scale_factor = static_cast<float>(scale_factor_);
// Shared body for forward softmax handlers: wraps the raw input/output
// buffers as TensorWrapper with the `shape` and `dtype` already in scope.
// Requires the enclosing function parameters to be named `input_buf` and
// `output_buf`. Declares: input, output, input_tensor, output_tensor.
#define SOFTMAX_FORWARD_COMMON_BLOCK \
auto *input = input_buf.untyped_data(); \
auto *output = output_buf->untyped_data(); \
auto input_tensor = TensorWrapper(input, shape, dtype); \
auto output_tensor = TensorWrapper(output, shape, dtype);
// FFI entry point for the scaled softmax forward kernel. The input is viewed
// as a 4-D tensor (batch, heads, q_seqlen, k_seqlen): leading axes collapse
// into the batch and the last three axes are taken positionally.
Error_Type ScaledSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
                                   Result_Type output_buf, double scale_factor_) {
  auto dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto dims = input_buf.dimensions();
  auto rank = dims.size();
  auto batch_size = product(dims, 0, rank - 3);
  auto heads = product(dims, rank - 3, rank - 2);
  auto q_seqlen = product(dims, rank - 2, rank - 1);
  auto k_seqlen = product(dims, rank - 1, rank);
  // The FFI attribute arrives as double; NVTE takes float.
  auto scale = static_cast<float>(scale_factor_);

  auto shape = std::vector<size_t>{batch_size, heads, q_seqlen, k_seqlen};
  auto input_tensor = TensorWrapper(input_buf.untyped_data(), shape, dtype);
  auto output_tensor = TensorWrapper(output_buf->untyped_data(), shape, dtype);

  nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), scale, stream);
  return ffi_with_cuda_error_check();
}
// FFI entry point for the scaled masked softmax forward kernel. The logits
// are viewed as (batch, heads, q_seqlen, k_seqlen); the mask is reinterpreted
// as bytes with shape (padding, 1, q_seqlen, k_seqlen).
Error_Type ScaledMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
                                         Buffer_Type mask_buf, Result_Type output_buf,
                                         double scale_factor_) {
  auto dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto dims = input_buf.dimensions();
  auto rank = dims.size();
  auto batch_size = product(dims, 0, rank - 3);
  auto heads = product(dims, rank - 3, rank - 2);
  auto q_seqlen = product(dims, rank - 2, rank - 1);
  auto k_seqlen = product(dims, rank - 1, rank);
  auto scale = static_cast<float>(scale_factor_);

  // Mask is cast to uint8_t regardless of its declared element type.
  // NOTE(review): two-argument `product` overload here — presumably folds the
  // leading mask dims; confirm against the helper's signature.
  auto mask_dims = mask_buf.dimensions();
  auto padding_size = product(mask_dims, mask_dims.size() - 3);
  auto mask_shape = std::vector<size_t>{padding_size, 1, q_seqlen, k_seqlen};
  auto mask_tensor = TensorWrapper(mask_buf.untyped_data(), mask_shape, DType::kByte);

  auto shape = std::vector<size_t>{batch_size, heads, q_seqlen, k_seqlen};
  auto input_tensor = TensorWrapper(input_buf.untyped_data(), shape, dtype);
  auto output_tensor = TensorWrapper(output_buf->untyped_data(), shape, dtype);

  nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(),
                                     output_tensor.data(), scale, stream);
  return ffi_with_cuda_error_check();
}
// FFI entry point for the scaled upper-triangular-masked softmax forward
// kernel. Unlike the other softmax variants this NVTE kernel takes a 3-D
// view, so batch and heads are fused into the leading axis.
Error_Type ScaledUpperTriangMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
                                                    Result_Type output_buf, double scale_factor_) {
  auto dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto dims = input_buf.dimensions();
  auto rank = dims.size();
  auto batch_size = product(dims, 0, rank - 3);
  auto heads = product(dims, rank - 3, rank - 2);
  auto q_seqlen = product(dims, rank - 2, rank - 1);
  auto k_seqlen = product(dims, rank - 1, rank);
  auto scale = static_cast<float>(scale_factor_);

  // 3-D view: (batch * heads, q_seqlen, k_seqlen).
  auto shape = std::vector<size_t>{batch_size * heads, q_seqlen, k_seqlen};
  auto input_tensor = TensorWrapper(input_buf.untyped_data(), shape, dtype);
  auto output_tensor = TensorWrapper(output_buf->untyped_data(), shape, dtype);

  nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
                                                  scale, stream);
  return ffi_with_cuda_error_check();
}
// XLA FFI bindings for the forward softmax handlers. Each binds the stream
// context, the positional input buffer(s), the output buffer, and the
// `scale_factor` double attribute; all are CUDA-graph capturable.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler, ScaledSoftmaxForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

// Masked variant: takes an extra mask buffer before the output.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler, ScaledMaskedSoftmaxForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // mask
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

// Upper-triangular (causal) masked variant: mask is implicit, so the binding
// matches the unmasked forward.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler,
                              ScaledUpperTriangMaskedSoftmaxForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);
// Shared body for backward softmax handlers: wraps the incoming gradient,
// the saved forward output, and the gradient result as TensorWrapper using
// the `shape` and `dtype` already in scope. Requires the enclosing function
// parameters to be named `grad_output_buf`, `softmax_output_buf`, and
// `dgrad_buf`. Declares: grad_output, softmax_output, dgrad and the three
// corresponding *_tensor wrappers.
#define SOFTMAX_BACKWARD_COMMON_BLOCK \
auto *grad_output = grad_output_buf.untyped_data(); \
auto *softmax_output = softmax_output_buf.untyped_data(); \
auto *dgrad = dgrad_buf->untyped_data(); \
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); \
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); \
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
// FFI entry point for the scaled softmax backward kernel. Shapes are derived
// from the incoming gradient, viewed as (batch, heads, q_seqlen, k_seqlen).
// Also serves as the backward of the masked variant (see handler below).
Error_Type ScaledSoftmaxBackwardFFI(cudaStream_t stream, Buffer_Type grad_output_buf,
                                    Buffer_Type softmax_output_buf, Result_Type dgrad_buf,
                                    double scale_factor_) {
  auto dtype = convert_ffi_datatype_to_te_dtype(grad_output_buf.element_type());
  auto dims = grad_output_buf.dimensions();
  auto rank = dims.size();
  auto batch_size = product(dims, 0, rank - 3);
  auto heads = product(dims, rank - 3, rank - 2);
  auto q_seqlen = product(dims, rank - 2, rank - 1);
  auto k_seqlen = product(dims, rank - 1, rank);
  auto scale = static_cast<float>(scale_factor_);

  auto shape = std::vector<size_t>{batch_size, heads, q_seqlen, k_seqlen};
  auto grad_output_tensor = TensorWrapper(grad_output_buf.untyped_data(), shape, dtype);
  auto softmax_output_tensor = TensorWrapper(softmax_output_buf.untyped_data(), shape, dtype);
  auto dgrad_tensor = TensorWrapper(dgrad_buf->untyped_data(), shape, dtype);

  nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
                               dgrad_tensor.data(), scale, stream);
  return ffi_with_cuda_error_check();
}
// FFI entry point for the scaled upper-triangular-masked softmax backward
// kernel. As in the forward pass, this NVTE kernel takes a 3-D view with
// batch and heads fused into the leading axis.
Error_Type ScaledUpperTriangMaskedSoftmaxBackwardFFI(cudaStream_t stream,
                                                     Buffer_Type grad_output_buf,
                                                     Buffer_Type softmax_output_buf,
                                                     Result_Type dgrad_buf, double scale_factor_) {
  auto dtype = convert_ffi_datatype_to_te_dtype(grad_output_buf.element_type());
  auto dims = grad_output_buf.dimensions();
  auto rank = dims.size();
  auto batch_size = product(dims, 0, rank - 3);
  auto heads = product(dims, rank - 3, rank - 2);
  auto q_seqlen = product(dims, rank - 2, rank - 1);
  auto k_seqlen = product(dims, rank - 1, rank);
  auto scale = static_cast<float>(scale_factor_);

  // 3-D view: (batch * heads, q_seqlen, k_seqlen).
  auto shape = std::vector<size_t>{batch_size * heads, q_seqlen, k_seqlen};
  auto grad_output_tensor = TensorWrapper(grad_output_buf.untyped_data(), shape, dtype);
  auto softmax_output_tensor = TensorWrapper(softmax_output_buf.untyped_data(), shape, dtype);
  auto dgrad_tensor = TensorWrapper(dgrad_buf->untyped_data(), shape, dtype);

  nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(),
                                                   softmax_output_tensor.data(),
                                                   dgrad_tensor.data(), scale, stream);
  return ffi_with_cuda_error_check();
}
// XLA FFI bindings for the backward softmax handlers: stream context,
// gradient and saved forward-output buffers, gradient result, and the
// `scale_factor` double attribute; all are CUDA-graph capturable.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // grad_output
                                  .Arg<Buffer_Type>()      // softmax_output
                                  .Ret<Buffer_Type>()      // dgrad
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax, so the
// masked backward handler reuses ScaledSoftmaxBackwardFFI.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // grad_output
                                  .Arg<Buffer_Type>()      // softmax_output
                                  .Ret<Buffer_Type>()      // dgrad
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler,
                              ScaledUpperTriangMaskedSoftmaxBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // grad_output
                                  .Arg<Buffer_Type>()      // softmax_output
                                  .Ret<Buffer_Type>()      // dgrad
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment