Unverified Commit 237b4930 authored by Hua Huang's avatar Hua Huang Committed by GitHub
Browse files

[TE/JAX] XLA FFI calls for Softmax and FusedAttnBackward (#1319)



* FFI for all softmax functions
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* FFI for FusedAttnBackward and Dequantize

FusedAttnBackward passed all tests in test_fused_attn.py.
Dequantize is not used currently; finish it for completeness.
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* Fix FusedAttnBackward FFI pybind & simplify
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert changes to tests/jax/test_fused_attn.py
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

---------
Signed-off-by: default avatarHua Huang <huah@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 68adf451
...@@ -380,8 +380,6 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -380,8 +380,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape) bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled(): if is_ffi_enabled():
name = "te_fused_attn_forward_ffi" name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)( out = ffi.ffi_lowering(name)(
...@@ -433,6 +431,8 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -433,6 +431,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
] ]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, input_batch,
bias_batch, bias_batch,
...@@ -725,28 +725,6 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -725,28 +725,6 @@ class FusedAttnBwdPrimitive(BasePrimitive):
""" """
Fused attention bwd lowering rules Fused attention bwd lowering rules
""" """
operands = [
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
...@@ -761,33 +739,90 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -761,33 +739,90 @@ class FusedAttnBwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape) bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1] if is_ffi_enabled():
name = "te_fused_attn_backward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, input_batch,
bias_batch, bias_batch,
q_max_seqlen, q_max_seqlen,
kv_max_seqlen, kv_max_seqlen,
attn_heads, attn_heads,
num_gqa_groups, num_gqa_groups,
bias_heads, bias_heads,
head_dim, head_dim,
config.max_segments_per_seq, config.max_segments_per_seq,
wkspace_aval.size, wkspace_aval.size,
config.scaling_factor, config.scaling_factor,
config.dropout_probability, config.dropout_probability,
config.attn_bias_type, config.attn_bias_type,
config.attn_mask_type, config.attn_mask_type,
config.qkv_layout, config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training, config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(), not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0], config.window_size[0],
config.window_size[1], config.window_size[1],
) )
out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
return out return out
......
...@@ -12,12 +12,13 @@ import jax.numpy as jnp ...@@ -12,12 +12,13 @@ import jax.numpy as jnp
from jax import core, dtypes from jax import core, dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax from transformer_engine import transformer_engine_jax
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
...@@ -133,32 +134,36 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -133,32 +134,36 @@ class SoftmaxPrimitive(BasePrimitive):
""" """
softmax_forward lowering rules softmax_forward lowering rules
""" """
(i_aval,) = ctx.avals_in if is_ffi_enabled():
i_type = ir.RankedTensorType(logits.type) ffi_name = name + "_ffi"
i_shape = i_type.shape out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen] else:
batch = reduce(operator.mul, i_shape[:-3]) (i_aval,) = ctx.avals_in
pad_batch = batch i_type = ir.RankedTensorType(logits.type)
heads = i_shape[-3] i_shape = i_type.shape
q_seqlen = i_shape[-2] # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
k_seqlen = i_shape[-1] batch = reduce(operator.mul, i_shape[:-3])
pad_batch = batch
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] heads = i_shape[-3]
operands = [logits] q_seqlen = i_shape[-2]
operand_shapes = [i_shape] k_seqlen = i_shape[-1]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
opaque = transformer_engine_jax.pack_softmax_descriptor( operands = [logits]
batch, operand_shapes = [i_shape]
pad_batch, args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
heads,
q_seqlen, opaque = transformer_engine_jax.pack_softmax_descriptor(
k_seqlen, batch,
jax_dtype_to_te_dtype(i_aval.dtype), pad_batch,
scale_factor, heads,
) q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(i_aval.dtype),
scale_factor,
)
out = custom_caller(name, args, opaque, False) out = custom_caller(name, args, opaque, False)
return out return out
...@@ -240,37 +245,41 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -240,37 +245,41 @@ class SoftmaxPrimitive(BasePrimitive):
""" """
softmax_backward lowering rules softmax_backward lowering rules
""" """
dz_aval, _ = ctx.avals_in if is_ffi_enabled():
ffi_name = name + "_ffi"
dz_type = ir.RankedTensorType(dz.type) out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
dz_shape = dz_type.shape else:
dz_aval, _ = ctx.avals_in
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, dz_shape[:-3]) dz_type = ir.RankedTensorType(dz.type)
pad_batch = batch # unused dz_shape = dz_type.shape
heads = dz_shape[-3]
q_seqlen = dz_shape[-2] # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
k_seqlen = dz_shape[-1] batch = reduce(operator.mul, dz_shape[:-3])
pad_batch = batch # unused
softmax_out_type = ir.RankedTensorType(softmax_out.type) heads = dz_shape[-3]
softmax_out_shape = softmax_out_type.shape q_seqlen = dz_shape[-2]
k_seqlen = dz_shape[-1]
out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
operands = [dz, softmax_out] softmax_out_type = ir.RankedTensorType(softmax_out.type)
operand_shapes = [dz_shape, softmax_out_shape] softmax_out_shape = softmax_out_type.shape
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
opaque = transformer_engine_jax.pack_softmax_descriptor( operands = [dz, softmax_out]
batch, operand_shapes = [dz_shape, softmax_out_shape]
pad_batch, args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
heads,
q_seqlen, opaque = transformer_engine_jax.pack_softmax_descriptor(
k_seqlen, batch,
jax_dtype_to_te_dtype(dz_aval.dtype), pad_batch,
scale_factor, heads,
) q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(dz_aval.dtype),
scale_factor,
)
out = custom_caller(name, args, opaque, False) out = custom_caller(name, args, opaque, False)
return out return out
...@@ -577,36 +586,39 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -577,36 +586,39 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
""" """
te_scaled_masked_softmax_forward lowering rules te_scaled_masked_softmax_forward lowering rules
""" """
if is_ffi_enabled():
ffi_name = "te_scaled_masked_softmax_forward_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
else:
logits_aval, _ = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, i_shape[:-3])
heads = i_shape[-3]
q_seqlen = i_shape[-2]
k_seqlen = i_shape[-1]
mask_type = ir.RankedTensorType(mask.type)
mask_shape = mask_type.shape
pad_batch = reduce(operator.mul, mask_shape[:-3])
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [logits, mask]
operand_shapes = [i_shape, mask_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(logits_aval.dtype),
scale_factor,
)
logits_aval, _ = ctx.avals_in out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
# Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
batch = reduce(operator.mul, i_shape[:-3])
heads = i_shape[-3]
q_seqlen = i_shape[-2]
k_seqlen = i_shape[-1]
mask_type = ir.RankedTensorType(mask.type)
mask_shape = mask_type.shape
pad_batch = reduce(operator.mul, mask_shape[:-3])
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [logits, mask]
operand_shapes = [i_shape, mask_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(
batch,
pad_batch,
heads,
q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(logits_aval.dtype),
scale_factor,
)
out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)
return out return out
......
...@@ -238,6 +238,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler); ...@@ -238,6 +238,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
// Softmax // Softmax
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque, void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
...@@ -258,8 +260,23 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, ...@@ -258,8 +260,23 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len); std::size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Attention // Attention
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Mask_Type mask_type, float dropout_probability,
...@@ -289,8 +306,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -289,8 +306,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -68,6 +68,19 @@ pybind11::dict Registrations() { ...@@ -68,6 +68,19 @@ pybind11::dict Registrations() {
// Quantization // Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
// Softmax
dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler);
dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler);
dict["te_scaled_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler);
dict["te_scaled_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler);
dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler);
dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Normalization // Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler); dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
...@@ -83,6 +96,11 @@ pybind11::dict Registrations() { ...@@ -83,6 +96,11 @@ pybind11::dict Registrations() {
fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler); fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler);
dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi; dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi;
pybind11::dict fused_attn_backward_ffi;
fused_attn_backward_ffi["prepare"] = EncapsulateFFI(CudnnHandleInitHandler);
fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler);
dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi;
return dict; return dict;
} }
......
...@@ -74,11 +74,41 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t ...@@ -74,11 +74,41 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t
auto shape = desc.shape.to_vector(); auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
} }
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto input_dims = input_buf.dimensions();
std::vector<size_t> shape(input_dims.begin(), input_dims.end());
auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, out_dtype);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>(), // output
FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "transformer_engine/softmax.h" #include "transformer_engine/softmax.h"
#include "extensions.h" #include "extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -108,5 +109,146 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, ...@@ -108,5 +109,146 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
dgrad_tensor.data(), desc.scale_factor, stream); dgrad_tensor.data(), desc.scale_factor, stream);
} }
#define SOFTMAX_COMMON_BLOCK(tensor_buf) \
auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type()); \
auto tensor_dims = (tensor_buf).dimensions(); \
auto tensor_ranks = tensor_dims.size(); \
auto batch_size = product(tensor_dims, 0, tensor_ranks - 3); \
auto head_dim = product(tensor_dims, tensor_ranks - 3, tensor_ranks - 2); \
auto q_seqlen = product(tensor_dims, tensor_ranks - 2, tensor_ranks - 1); \
auto k_seqlen = product(tensor_dims, tensor_ranks - 1, tensor_ranks); \
float scale_factor = static_cast<float>(scale_factor_);
#define SOFTMAX_FORWARD_COMMON_BLOCK \
auto *input = input_buf.untyped_data(); \
auto *output = output_buf->untyped_data(); \
auto input_tensor = TensorWrapper(input, shape, dtype); \
auto output_tensor = TensorWrapper(output, shape, dtype);
Error_Type ScaledSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type output_buf, double scale_factor_) {
SOFTMAX_COMMON_BLOCK(input_buf);
auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
SOFTMAX_FORWARD_COMMON_BLOCK;
nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), scale_factor, stream);
return ffi_with_cuda_error_check();
}
Error_Type ScaledMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type mask_buf, Result_Type output_buf,
double scale_factor_) {
SOFTMAX_COMMON_BLOCK(input_buf);
// Mask would be casted to uint8_t
auto *mask = mask_buf.untyped_data();
auto mask_dims = mask_buf.dimensions();
auto padding_size = product(mask_dims, mask_dims.size() - 3);
auto mask_shape = std::vector<size_t>{padding_size, 1, q_seqlen, k_seqlen};
auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
SOFTMAX_FORWARD_COMMON_BLOCK;
nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), output_tensor.data(),
scale_factor, stream);
return ffi_with_cuda_error_check();
}
Error_Type ScaledUpperTriangMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type output_buf, double scale_factor_) {
SOFTMAX_COMMON_BLOCK(input_buf);
auto shape = std::vector<size_t>{batch_size * head_dim, q_seqlen, k_seqlen};
SOFTMAX_FORWARD_COMMON_BLOCK;
nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
scale_factor, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler, ScaledSoftmaxForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler, ScaledMaskedSoftmaxForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // mask
.Ret<Buffer_Type>() // output
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler,
ScaledUpperTriangMaskedSoftmaxForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
#define SOFTMAX_BACKWARD_COMMON_BLOCK \
auto *grad_output = grad_output_buf.untyped_data(); \
auto *softmax_output = softmax_output_buf.untyped_data(); \
auto *dgrad = dgrad_buf->untyped_data(); \
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); \
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); \
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
Error_Type ScaledSoftmaxBackwardFFI(cudaStream_t stream, Buffer_Type grad_output_buf,
Buffer_Type softmax_output_buf, Result_Type dgrad_buf,
double scale_factor_) {
SOFTMAX_COMMON_BLOCK(grad_output_buf);
auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
SOFTMAX_BACKWARD_COMMON_BLOCK;
nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
dgrad_tensor.data(), scale_factor, stream);
return ffi_with_cuda_error_check();
}
Error_Type ScaledUpperTriangMaskedSoftmaxBackwardFFI(cudaStream_t stream,
Buffer_Type grad_output_buf,
Buffer_Type softmax_output_buf,
Result_Type dgrad_buf, double scale_factor_) {
SOFTMAX_COMMON_BLOCK(grad_output_buf);
auto shape = std::vector<size_t>{batch_size * head_dim, q_seqlen, k_seqlen};
SOFTMAX_BACKWARD_COMMON_BLOCK;
nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(),
softmax_output_tensor.data(),
dgrad_tensor.data(), scale_factor, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // grad_output
.Arg<Buffer_Type>() // softmax_output
.Ret<Buffer_Type>() // dgrad
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // grad_output
.Arg<Buffer_Type>() // softmax_output
.Ret<Buffer_Type>() // dgrad
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler,
ScaledUpperTriangMaskedSoftmaxBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // grad_output
.Arg<Buffer_Type>() // softmax_output
.Ret<Buffer_Type>() // dgrad
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment