Unverified commit 237b4930 authored by Hua Huang, committed by GitHub

[TE/JAX] XLA FFI calls for Softmax and FusedAttnBackward (#1319)



* FFI for all softmax functions
Signed-off-by: Hua Huang <huah@nvidia.com>

* FFI for FusedAttnBackward and Dequantize

FusedAttnBackward passed all tests in test_fused_attn.py.
Dequantize is not used currently; finish it for completeness.
Signed-off-by: Hua Huang <huah@nvidia.com>

* Fix FusedAttnBackward FFI pybind & simplify
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert changes to tests/jax/test_fused_attn.py
Signed-off-by: Hua Huang <huah@nvidia.com>

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 68adf451
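For context: this change moves these kernels from opaque-descriptor custom calls to XLA FFI targets that the Python lowering rules emit through jax.extend.ffi. Below is a minimal, hedged sketch of that mechanism (not TE code); the target name "my_scaled_softmax_ffi" and the handler capsule are hypothetical, and TE wires its real handlers through transformer_engine_jax and the Registrations() dict shown further down in the diff.

from jax.extend import ffi as jex_ffi

# 1) Register a C++ handler (an XLA_FFI_DEFINE_HANDLER_SYMBOL capsule) under a
#    target name for the CUDA platform. 'handler_capsule' is assumed to come
#    from the compiled extension module.
# jex_ffi.register_ffi_target("my_scaled_softmax_ffi", handler_capsule, platform="CUDA")

# 2) Inside a primitive's MLIR lowering rule, emit the custom call with
#    ffi_lowering; keyword attributes arrive on the C++ side as the FFI
#    Dictionary / .Attr<double>("scale_factor") values.
def scaled_softmax_fwd_lowering(ctx, logits, *, scale_factor):
    return jex_ffi.ffi_lowering("my_scaled_softmax_ffi")(
        ctx, logits, scale_factor=float(scale_factor)
    )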
...@@ -380,8 +380,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
...@@ -433,6 +431,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
...@@ -725,6 +725,56 @@ class FusedAttnBwdPrimitive(BasePrimitive):
"""
Fused attention bwd lowering rules
"""
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled():
name = "te_fused_attn_backward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
...@@ -744,23 +794,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
...
...@@ -12,12 +12,13 @@ import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType
...@@ -133,6 +134,10 @@ class SoftmaxPrimitive(BasePrimitive):
"""
softmax_forward lowering rules
"""
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
else:
(i_aval,) = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
...@@ -240,6 +245,10 @@ class SoftmaxPrimitive(BasePrimitive):
"""
softmax_backward lowering rules
"""
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
else:
dz_aval, _ = ctx.avals_in
dz_type = ir.RankedTensorType(dz.type)
...@@ -577,7 +586,10 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""
te_scaled_masked_softmax_forward lowering rules
"""
if is_ffi_enabled():
ffi_name = "te_scaled_masked_softmax_forward_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
else:
logits_aval, _ = ctx.avals_in
i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape
...
...@@ -238,6 +238,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
// Softmax
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
...@@ -258,8 +260,23 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Attention
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
...@@ -289,8 +306,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
} // namespace jax
} // namespace transformer_engine
...
...@@ -185,6 +185,33 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}
#define FUSED_ATTN_IMPL_COMMON_BLOCK \
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \
size_t num_segments = input_batch; \
if (is_ragged) { \
auto cudnn_runtime_version = cudnnGetVersion(); \
if (cudnn_runtime_version >= 90300) { \
num_segments = input_batch * max_segments_per_seq; \
} else { \
size_t runtime_num_segments_q = \
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \
size_t runtime_num_segments_kv = \
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); \
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); \
num_segments = runtime_num_segments_q; \
} \
} \
std::vector<size_t> seq_shape{num_segments + 1}; \
auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, seq_shape, DType::kInt32); \
auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, seq_shape, DType::kInt32); \
auto q_seq_offsets_tensor = TensorWrapper(q_seq_offsets, seq_shape, DType::kInt32); \
auto k_seq_offsets_tensor = TensorWrapper(k_seq_offsets, seq_shape, DType::kInt32); \
auto workspace_tensor = \
TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype); \
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
static void FusedAttnForwardImpl(
cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens,
void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output,
...@@ -194,43 +221,16 @@ static void FusedAttnForwardImpl(
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments
if (is_ragged) {
auto cudnn_runtime_version = cudnnGetVersion();
if (cudnn_runtime_version >= 90300) {
num_segments = input_batch * max_segments_per_seq;
} else {
// workspace can be reused here as it is not used with cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
}
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto q_seq_offsets_tensor =
TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto k_seq_offsets_tensor =
TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -251,12 +251,7 @@ static void FusedAttnForwardImpl(
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type,
backend, softmax_aux);
/* cuDNN workspace */
auto workspace_tensor =
TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype);
/* Call the underlying NVTE API */
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
...@@ -304,7 +299,9 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */
/* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
...@@ -319,16 +316,43 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
void *workspace = buffers[12];
FusedAttnForwardImpl(
stream, buffers[0], buffers[1], buffers[2], bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets,
k_seq_offsets, seed, output, softmax_aux, rng_state, workspace, descriptor.input_batch,
descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen,
descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
stream, q, k, v, bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, seed,
output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch,
descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads,
descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor,
descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type,
descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training,
descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
}
#define FUSED_ATTN_FFI_GET_ATTRS \
size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch"); \
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch"); \
size_t q_max_seqlen = get_attr_value<int64_t>(attrs, "q_max_seqlen"); \
size_t kv_max_seqlen = get_attr_value<int64_t>(attrs, "kv_max_seqlen"); \
size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads"); \
size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups"); \
size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads"); \
size_t head_dim = get_attr_value<int64_t>(attrs, "head_dim"); \
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
float scaling_factor = get_attr_value<double>(attrs, "scaling_factor"); \
float dropout_probability = get_attr_value<double>(attrs, "dropout_probability"); \
NVTE_Bias_Type bias_type = \
static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type")); \
NVTE_Mask_Type mask_type = \
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type")); \
NVTE_QKV_Layout qkv_layout = \
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout")); \
bool is_training = get_attr_value<bool>(attrs, "is_training"); \
bool deterministic = get_attr_value<bool>(attrs, "deterministic"); \
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \
size_t wkspace_size = product(workspace_buf->dimensions()); \
DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type()); \
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
Buffer_Type v_buf, Buffer_Type bias_buf,
Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
...@@ -336,37 +360,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
Buffer_Type seed_buf, Result_Type output_buf,
Result_Type softmax_aux_buf, Result_Type rng_state_buf,
Result_Type workspace_buf, Dictionary attrs) {
FUSED_ATTN_FFI_GET_ATTRS;
/* Descriptor data type conversion */
size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch");
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch");
size_t q_max_seqlen = get_attr_value<int64_t>(attrs, "q_max_seqlen");
size_t kv_max_seqlen = get_attr_value<int64_t>(attrs, "kv_max_seqlen");
size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads");
size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups");
size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads");
size_t head_dim = get_attr_value<int64_t>(attrs, "head_dim");
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq");
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left");
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right");
float scaling_factor = get_attr_value<double>(attrs, "scaling_factor");
float dropout_probability = get_attr_value<double>(attrs, "dropout_probability");
NVTE_Bias_Type bias_type =
static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type"));
NVTE_Mask_Type mask_type =
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type"));
NVTE_QKV_Layout qkv_layout =
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout"));
bool is_training = get_attr_value<bool>(attrs, "is_training");
bool deterministic = get_attr_value<bool>(attrs, "deterministic");
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
size_t wkspace_size = product(workspace_buf->dimensions());
DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type());
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
FusedAttnForwardImpl(
stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
...@@ -503,81 +497,23 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
}
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto qkv_layout = descriptor.qkv_layout;
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */
/* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *bias = buffers[3];
void *softmax_aux = buffers[4];
void *rng_state = buffers[5];
void *output = buffers[6];
void *doutput = buffers[7];
void *q_cu_seqlens = buffers[8];
void *kv_cu_seqlens = buffers[9];
void *q_seq_offsets = is_ragged ? buffers[10] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[11] : nullptr;
/* Output buffer from XLA */
/* Buffers[12-14] are dq, dk, dv, which are parsed later for different qkv_layout */
void *dbias = buffers[15];
void *workspace = buffers[16];
/* Descriptor */
auto input_batch = descriptor.input_batch;
auto bias_batch = descriptor.bias_batch;
auto q_max_seqlen = descriptor.q_max_seqlen;
auto kv_max_seqlen = descriptor.kv_max_seqlen;
auto attn_heads = descriptor.attn_heads;
auto num_gqa_groups = descriptor.num_gqa_groups;
auto bias_heads = descriptor.bias_heads;
auto head_dim = descriptor.head_dim;
auto scaling_factor = descriptor.scaling_factor;
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto deterministic = descriptor.deterministic;
auto max_segments_per_seq = descriptor.max_segments_per_seq;
auto window_size_left = descriptor.window_size_left;
auto window_size_right = descriptor.window_size_right;
static void FusedAttnBackwardImpl(
cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_aux, void *rng_state,
void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets,
void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments
if (is_ragged) {
auto cudnn_runtime_version = cudnnGetVersion();
if (cudnn_runtime_version >= 90300) {
num_segments = input_batch * max_segments_per_seq;
} else {
// workspace can be reused here as it is not used with cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
}
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto q_seq_offsets_tensor =
TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto k_seq_offsets_tensor =
TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
...@@ -593,21 +529,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);
/* cuDNN workspace */
auto wkspace_size = std::vector<size_t>{descriptor.wkspace_size};
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
/* Call the underly NVTE API */
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto dqkv = buffers[12];
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dqkv, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dq, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
...@@ -618,19 +546,15 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv = buffers[1];
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto dq = buffers[12];
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv = buffers[13];
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dkv, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
...@@ -642,20 +566,14 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k = buffers[1];
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v = buffers[2];
auto v_shape = k_shape;
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto dq = buffers[12];
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk = buffers[13];
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv = buffers[14];
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
if (is_ragged) {
cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream);
...@@ -679,5 +597,93 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
nvte_tensor_pack_destroy(&aux_input_tensors);
}
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto qkv_layout = descriptor.qkv_layout;
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */
void *q = buffers[0];
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *softmax_aux = buffers[4];
void *rng_state = buffers[5];
void *output = buffers[6];
void *doutput = buffers[7];
void *q_cu_seqlens = buffers[8];
void *kv_cu_seqlens = buffers[9];
void *q_seq_offsets = is_ragged ? buffers[10] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[11] : nullptr;
/* Output buffer from XLA */
void *dq = buffers[12];
void *dk = buffers[13];
void *dv = buffers[14];
void *dbias = buffers[15];
void *workspace = buffers[16];
FusedAttnBackwardImpl(
stream, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlens, kv_cu_seqlens,
q_seq_offsets, k_seq_offsets, dq, dk, dv, dbias, workspace, descriptor.input_batch,
descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen,
descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor,
descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type,
descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training,
descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
}
Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
Buffer_Type v_buf, Buffer_Type bias_buf,
Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
Buffer_Type output_buf, Buffer_Type doutput_buf,
Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
Dictionary attrs) {
FUSED_ATTN_FFI_GET_ATTRS;
FusedAttnBackwardImpl(
stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
bias_buf.untyped_data(), softmax_aux_buf.untyped_data(), rng_state_buf.untyped_data(),
output_buf.untyped_data(), doutput_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(),
dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(),
workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype,
is_training, deterministic, window_size_left, window_size_right);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // q
.Arg<Buffer_Type>() // k
.Arg<Buffer_Type>() // v
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // softmax_aux
.Arg<Buffer_Type>() // rng_state
.Arg<Buffer_Type>() // output
.Arg<Buffer_Type>() // doutput
.Arg<Buffer_Type>() // q_cu_seqlens
.Arg<Buffer_Type>() // kv_cu_seqlens
.Arg<Buffer_Type>() // q_seq_offsets
.Arg<Buffer_Type>() // k_seq_offsets
.Ret<Buffer_Type>() // dq
.Ret<Buffer_Type>() // dk
.Ret<Buffer_Type>() // dv
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // workspace
.Attrs(),
FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
...@@ -68,6 +68,19 @@ pybind11::dict Registrations() {
// Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
// Softmax
dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler);
dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler);
dict["te_scaled_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler);
dict["te_scaled_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler);
dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler);
dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
...@@ -83,6 +96,11 @@ pybind11::dict Registrations() {
fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler);
dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi;
pybind11::dict fused_attn_backward_ffi;
fused_attn_backward_ffi["prepare"] = EncapsulateFFI(CudnnHandleInitHandler);
fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler);
dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi;
return dict;
}
...
...@@ -74,11 +74,41 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t
auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
}
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto *input = input_buf.untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto input_dims = input_buf.dimensions();
std::vector<size_t> shape(input_dims.begin(), input_dims.end());
auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, out_dtype);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>(), // output
FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
...@@ -7,6 +7,7 @@
#include "transformer_engine/softmax.h"
#include "extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
...@@ -108,5 +109,146 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
dgrad_tensor.data(), desc.scale_factor, stream);
}
#define SOFTMAX_COMMON_BLOCK(tensor_buf) \
auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type()); \
auto tensor_dims = (tensor_buf).dimensions(); \
auto tensor_ranks = tensor_dims.size(); \
auto batch_size = product(tensor_dims, 0, tensor_ranks - 3); \
auto head_dim = product(tensor_dims, tensor_ranks - 3, tensor_ranks - 2); \
auto q_seqlen = product(tensor_dims, tensor_ranks - 2, tensor_ranks - 1); \
auto k_seqlen = product(tensor_dims, tensor_ranks - 1, tensor_ranks); \
float scale_factor = static_cast<float>(scale_factor_);
#define SOFTMAX_FORWARD_COMMON_BLOCK \
auto *input = input_buf.untyped_data(); \
auto *output = output_buf->untyped_data(); \
auto input_tensor = TensorWrapper(input, shape, dtype); \
auto output_tensor = TensorWrapper(output, shape, dtype);
Error_Type ScaledSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type output_buf, double scale_factor_) {
SOFTMAX_COMMON_BLOCK(input_buf);
auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
SOFTMAX_FORWARD_COMMON_BLOCK;
nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), scale_factor, stream);
return ffi_with_cuda_error_check();
}
Error_Type ScaledMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type mask_buf, Result_Type output_buf,
double scale_factor_) {
SOFTMAX_COMMON_BLOCK(input_buf);
// Mask would be casted to uint8_t
auto *mask = mask_buf.untyped_data();
auto mask_dims = mask_buf.dimensions();
auto padding_size = product(mask_dims, mask_dims.size() - 3);
auto mask_shape = std::vector<size_t>{padding_size, 1, q_seqlen, k_seqlen};
auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
SOFTMAX_FORWARD_COMMON_BLOCK;
nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), output_tensor.data(),
scale_factor, stream);
return ffi_with_cuda_error_check();
}
Error_Type ScaledUpperTriangMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type output_buf, double scale_factor_) {
SOFTMAX_COMMON_BLOCK(input_buf);
auto shape = std::vector<size_t>{batch_size * head_dim, q_seqlen, k_seqlen};
SOFTMAX_FORWARD_COMMON_BLOCK;
nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
scale_factor, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler, ScaledSoftmaxForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler, ScaledMaskedSoftmaxForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // mask
.Ret<Buffer_Type>() // output
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler,
ScaledUpperTriangMaskedSoftmaxForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
#define SOFTMAX_BACKWARD_COMMON_BLOCK \
auto *grad_output = grad_output_buf.untyped_data(); \
auto *softmax_output = softmax_output_buf.untyped_data(); \
auto *dgrad = dgrad_buf->untyped_data(); \
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); \
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); \
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
Error_Type ScaledSoftmaxBackwardFFI(cudaStream_t stream, Buffer_Type grad_output_buf,
Buffer_Type softmax_output_buf, Result_Type dgrad_buf,
double scale_factor_) {
SOFTMAX_COMMON_BLOCK(grad_output_buf);
auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
SOFTMAX_BACKWARD_COMMON_BLOCK;
nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
dgrad_tensor.data(), scale_factor, stream);
return ffi_with_cuda_error_check();
}
Error_Type ScaledUpperTriangMaskedSoftmaxBackwardFFI(cudaStream_t stream,
Buffer_Type grad_output_buf,
Buffer_Type softmax_output_buf,
Result_Type dgrad_buf, double scale_factor_) {
SOFTMAX_COMMON_BLOCK(grad_output_buf);
auto shape = std::vector<size_t>{batch_size * head_dim, q_seqlen, k_seqlen};
SOFTMAX_BACKWARD_COMMON_BLOCK;
nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(),
softmax_output_tensor.data(),
dgrad_tensor.data(), scale_factor, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // grad_output
.Arg<Buffer_Type>() // softmax_output
.Ret<Buffer_Type>() // dgrad
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // grad_output
.Arg<Buffer_Type>() // softmax_output
.Ret<Buffer_Type>() // dgrad
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler,
ScaledUpperTriangMaskedSoftmaxBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // grad_output
.Arg<Buffer_Type>() // softmax_output
.Ret<Buffer_Type>() // dgrad
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
} // namespace jax
} // namespace transformer_engine
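On the Python side, the capsules returned by Registrations() still have to be registered with XLA under the "_ffi" target names used by the lowering rules above. A hedged sketch of how such registration might look follows; the registrations() accessor name and the platform filtering are assumptions for illustration, not the actual TE registration code.

from jax.extend import ffi as jex_ffi
from transformer_engine import transformer_engine_jax

# Each dict entry is either a single execute handler or a per-stage dict such
# as {"prepare": ..., "execute": ...}; register_ffi_target accepts both forms.
for target_name, handler in transformer_engine_jax.registrations().items():  # assumed accessor
    if target_name.endswith("_ffi"):
        jex_ffi.register_ffi_target(target_name, handler, platform="CUDA")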