"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "7f22f90e8cb423fdaa35203d41badd734d9c2e86"
Unverified commit 237b4930, authored by Hua Huang and committed by GitHub
Browse files

[TE/JAX] XLA FFI calls for Softmax and FusedAttnBackward (#1319)



* FFI for all softmax functions
Signed-off-by: Hua Huang <huah@nvidia.com>

* FFI for FusedAttnBackward and Dequantize

FusedAttnBackward passed all tests in test_fused_attn.py.
Dequantize is not used currently; finish it for completeness.
Signed-off-by: Hua Huang <huah@nvidia.com>

* Fix FusedAttnBackward FFI pybind & simplify
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert changes to tests/jax/test_fused_attn.py
Signed-off-by: Hua Huang <huah@nvidia.com>

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 68adf451
...@@ -380,8 +380,6 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -380,8 +380,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape) bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled(): if is_ffi_enabled():
name = "te_fused_attn_forward_ffi" name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)( out = ffi.ffi_lowering(name)(
...@@ -433,6 +431,8 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -433,6 +431,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
] ]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch, input_batch,
bias_batch, bias_batch,
...@@ -725,6 +725,56 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -725,6 +725,56 @@ class FusedAttnBwdPrimitive(BasePrimitive):
""" """
Fused attention bwd lowering rules Fused attention bwd lowering rules
""" """
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if is_ffi_enabled():
name = "te_fused_attn_backward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [ operands = [
q, q,
k, k,
...@@ -744,23 +794,8 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -744,23 +794,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out for output in ctx.avals_out
] ]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
wkspace_aval = ctx.avals_out[-1] wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
......
...@@ -12,12 +12,13 @@ import jax.numpy as jnp ...@@ -12,12 +12,13 @@ import jax.numpy as jnp
from jax import core, dtypes from jax import core, dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax from transformer_engine import transformer_engine_jax
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
...@@ -133,6 +134,10 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -133,6 +134,10 @@ class SoftmaxPrimitive(BasePrimitive):
""" """
softmax_forward lowering rules softmax_forward lowering rules
""" """
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
else:
(i_aval,) = ctx.avals_in (i_aval,) = ctx.avals_in
i_type = ir.RankedTensorType(logits.type) i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape i_shape = i_type.shape
...@@ -240,6 +245,10 @@ class SoftmaxPrimitive(BasePrimitive): ...@@ -240,6 +245,10 @@ class SoftmaxPrimitive(BasePrimitive):
""" """
softmax_backward lowering rules softmax_backward lowering rules
""" """
if is_ffi_enabled():
ffi_name = name + "_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
else:
dz_aval, _ = ctx.avals_in dz_aval, _ = ctx.avals_in
dz_type = ir.RankedTensorType(dz.type) dz_type = ir.RankedTensorType(dz.type)
...@@ -577,7 +586,10 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -577,7 +586,10 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
""" """
te_scaled_masked_softmax_forward lowering rules te_scaled_masked_softmax_forward lowering rules
""" """
if is_ffi_enabled():
ffi_name = "te_scaled_masked_softmax_forward_ffi"
out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
else:
logits_aval, _ = ctx.avals_in logits_aval, _ = ctx.avals_in
i_type = ir.RankedTensorType(logits.type) i_type = ir.RankedTensorType(logits.type)
i_shape = i_type.shape i_shape = i_type.shape
......
...@@ -238,6 +238,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler); ...@@ -238,6 +238,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
// Softmax // Softmax
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque, void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
...@@ -258,8 +260,23 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, ...@@ -258,8 +260,23 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len); std::size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Attention // Attention
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability, NVTE_Mask_Type mask_type, float dropout_probability,
...@@ -289,8 +306,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -289,8 +306,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -68,6 +68,19 @@ pybind11::dict Registrations() { ...@@ -68,6 +68,19 @@ pybind11::dict Registrations() {
// Quantization // Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
// Softmax
dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler);
dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler);
dict["te_scaled_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler);
dict["te_scaled_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler);
dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler);
dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Normalization // Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler); dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
...@@ -83,6 +96,11 @@ pybind11::dict Registrations() { ...@@ -83,6 +96,11 @@ pybind11::dict Registrations() {
fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler); fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler);
dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi; dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi;
pybind11::dict fused_attn_backward_ffi;
fused_attn_backward_ffi["prepare"] = EncapsulateFFI(CudnnHandleInitHandler);
fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler);
dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi;
return dict; return dict;
} }
......
...@@ -74,11 +74,41 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t ...@@ -74,11 +74,41 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t
auto shape = desc.shape.to_vector(); auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
} }
// XLA FFI implementation of FP8 dequantization.
//
// Wraps the input buffer (together with its amax / scale / scale_inv buffers) and the
// output buffer in TensorWrapper objects that share the input's dimensions, then calls
// nvte_fp8_dequantize on `stream`.
//
// Parameters:
//   stream         CUDA stream passed through to nvte_fp8_dequantize.
//   input_buf      Data to dequantize; its element type and dimensions describe the input.
//   amax_buf       Storage handed to the input TensorWrapper as a float pointer.
//   scale_buf      Storage handed to the input TensorWrapper as a float pointer.
//   scale_inv_buf  Storage handed to the input TensorWrapper as a float pointer.
//   output_buf     Result buffer; its element type selects the output dtype.
//
// NOTE(review): the element types of amax/scale/scale_inv are not validated here; they
// are assumed to hold float data - confirm against the caller.
//
// Returns the result of ffi_with_cuda_error_check().
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                         Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                         Result_Type output_buf) {
  const auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  const auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  // untyped_data() hands back an untyped (void *) pointer, so static_cast is the
  // appropriate named cast; reinterpret_cast is not needed for void * -> float *.
  auto *amax = static_cast<float *>(amax_buf.untyped_data());
  auto *scale = static_cast<float *>(scale_buf.untyped_data());
  auto *scale_inv = static_cast<float *>(scale_inv_buf.untyped_data());

  // The output tensor reuses the input's shape.
  const auto input_dims = input_buf.dimensions();
  const std::vector<size_t> shape(input_dims.begin(), input_dims.end());

  auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv);
  auto output_tensor = TensorWrapper(output, shape, out_dtype);

  nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
  return ffi_with_cuda_error_check();
}
// Defines the FFI handler symbol `DequantizeHandler`, backed by DequantizeFFI.
// The binding mirrors DequantizeFFI's parameter list: the stream context, four
// argument buffers (input, amax, scale, scale_inv) and one result buffer (output).
// NOTE(review): FFI_CudaGraph_Traits presumably marks the handler as usable under
// CUDA graph capture - confirm against its definition.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>(), // output
FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "transformer_engine/softmax.h" #include "transformer_engine/softmax.h"
#include "extensions.h" #include "extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -108,5 +109,146 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, ...@@ -108,5 +109,146 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
dgrad_tensor.data(), desc.scale_factor, stream); dgrad_tensor.data(), desc.scale_factor, stream);
} }
// Declares the locals shared by the softmax FFI implementations, all derived from
// `tensor_buf`:
//   dtype         convert_ffi_datatype_to_te_dtype applied to the buffer's element type
//   tensor_dims   the buffer's dimensions
//   tensor_ranks  the number of dimensions
//   batch_size    product(tensor_dims, 0, tensor_ranks - 3)
//   head_dim      product(tensor_dims, tensor_ranks - 3, tensor_ranks - 2)
//   q_seqlen      product(tensor_dims, tensor_ranks - 2, tensor_ranks - 1)
//   k_seqlen      product(tensor_dims, tensor_ranks - 1, tensor_ranks)
//   scale_factor  the enclosing function's `scale_factor_` converted to float
// The expanding function must have a variable named `scale_factor_` in scope.
// product() is defined elsewhere; presumably it multiplies the dimensions in the
// half-open index range [start, end) - confirm.
// NOTE(review): nothing here checks that the buffer has at least 3 dimensions; if
// size() returns an unsigned type, `tensor_ranks - 3` wraps for smaller ranks.
// NOTE(review): `head_dim` is read from the third-from-last dimension; for a
// (..., heads, q_seqlen, k_seqlen) layout that would be the head count - confirm.
#define SOFTMAX_COMMON_BLOCK(tensor_buf) \
auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type()); \
auto tensor_dims = (tensor_buf).dimensions(); \
auto tensor_ranks = tensor_dims.size(); \
auto batch_size = product(tensor_dims, 0, tensor_ranks - 3); \
auto head_dim = product(tensor_dims, tensor_ranks - 3, tensor_ranks - 2); \
auto q_seqlen = product(tensor_dims, tensor_ranks - 2, tensor_ranks - 1); \
auto k_seqlen = product(tensor_dims, tensor_ranks - 1, tensor_ranks); \
float scale_factor = static_cast<float>(scale_factor_);
// Declares `input`, `output`, `input_tensor` and `output_tensor` for a softmax forward
// call. The expanding function must already have `input_buf`, `output_buf`, `shape`
// and `dtype` in scope; both TensorWrapper objects use the same `shape` and `dtype`.
#define SOFTMAX_FORWARD_COMMON_BLOCK \
auto *input = input_buf.untyped_data(); \
auto *output = output_buf->untyped_data(); \
auto input_tensor = TensorWrapper(input, shape, dtype); \
auto output_tensor = TensorWrapper(output, shape, dtype);
// XLA FFI implementation of the scaled softmax forward pass.
//
// Describes both input and output as a 4-D tensor
// {batch_size, head_dim, q_seqlen, k_seqlen} derived from the input buffer's
// dimensions, then calls nvte_scaled_softmax_forward on `stream`.
//
// Parameters:
//   stream         CUDA stream passed to the kernel call.
//   input_buf      Input buffer; supplies the dtype and dimensions.
//   output_buf     Result buffer, wrapped with the same shape and dtype as the input.
//   scale_factor_  Scale attribute; converted to float before use.
//
// Returns the result of ffi_with_cuda_error_check().
Error_Type ScaledSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type output_buf, double scale_factor_) {
// Introduces dtype, batch_size, head_dim, q_seqlen, k_seqlen and scale_factor.
SOFTMAX_COMMON_BLOCK(input_buf);
// `shape` must be declared before SOFTMAX_FORWARD_COMMON_BLOCK, which reads it.
auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
// Introduces input_tensor and output_tensor.
SOFTMAX_FORWARD_COMMON_BLOCK;
nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), scale_factor, stream);
return ffi_with_cuda_error_check();
}
// XLA FFI implementation of the scaled masked softmax forward pass.
//
// Describes input and output as a 4-D tensor {batch_size, head_dim, q_seqlen, k_seqlen}
// derived from the input buffer, wraps the mask buffer as a byte tensor of shape
// {padding_size, 1, q_seqlen, k_seqlen}, then calls nvte_scaled_masked_softmax_forward
// on `stream`.
//
// Parameters:
//   stream         CUDA stream passed to the kernel call.
//   input_buf      Input buffer; supplies the dtype and dimensions.
//   mask_buf       Mask buffer; its storage is interpreted as DType::kByte.
//   output_buf     Result buffer, wrapped with the same shape and dtype as the input.
//   scale_factor_  Scale attribute; converted to float before use.
//
// Returns the result of ffi_with_cuda_error_check().
Error_Type ScaledMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type mask_buf, Result_Type output_buf,
double scale_factor_) {
// Introduces dtype, batch_size, head_dim, q_seqlen, k_seqlen and scale_factor.
SOFTMAX_COMMON_BLOCK(input_buf);
// The mask is wrapped as a byte (DType::kByte) tensor regardless of the buffer's
// declared element type.
auto *mask = mask_buf.untyped_data();
auto mask_dims = mask_buf.dimensions();
// NOTE(review): this two-argument product() call differs from the three-argument
// (dims, start, end) form used in SOFTMAX_COMMON_BLOCK. If the second argument is a
// start index, this multiplies the last three mask dimensions rather than the leading
// ones - confirm against product()'s definition that the leading-dimension product is
// what ends up in `padding_size`.
auto padding_size = product(mask_dims, mask_dims.size() - 3);
auto mask_shape = std::vector<size_t>{padding_size, 1, q_seqlen, k_seqlen};
auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
// `shape` must be declared before SOFTMAX_FORWARD_COMMON_BLOCK, which reads it.
auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
SOFTMAX_FORWARD_COMMON_BLOCK;
nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), output_tensor.data(),
scale_factor, stream);
return ffi_with_cuda_error_check();
}
// XLA FFI implementation of the scaled upper-triangular-masked softmax forward pass.
//
// Unlike the other forward variants in this file, the tensors are described as 3-D:
// {batch_size * head_dim, q_seqlen, k_seqlen}. The function then calls
// nvte_scaled_upper_triang_masked_softmax_forward on `stream`.
//
// Parameters:
//   stream         CUDA stream passed to the kernel call.
//   input_buf      Input buffer; supplies the dtype and dimensions.
//   output_buf     Result buffer, wrapped with the same shape and dtype as the input.
//   scale_factor_  Scale attribute; converted to float before use.
//
// Returns the result of ffi_with_cuda_error_check().
Error_Type ScaledUpperTriangMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
Result_Type output_buf, double scale_factor_) {
// Introduces dtype, batch_size, head_dim, q_seqlen, k_seqlen and scale_factor.
SOFTMAX_COMMON_BLOCK(input_buf);
// The two leading extents are folded into a single dimension.
auto shape = std::vector<size_t>{batch_size * head_dim, q_seqlen, k_seqlen};
SOFTMAX_FORWARD_COMMON_BLOCK;
nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
scale_factor, stream);
return ffi_with_cuda_error_check();
}
// Defines the FFI handler symbol `ScaledSoftmaxForwardHandler`, backed by
// ScaledSoftmaxForwardFFI. The binding mirrors that function's parameters: stream
// context, one argument buffer, one result buffer, and a double attribute named
// "scale_factor".
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler, ScaledSoftmaxForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
// Defines the FFI handler symbol `ScaledMaskedSoftmaxForwardHandler`, backed by
// ScaledMaskedSoftmaxForwardFFI. The binding mirrors that function's parameters:
// stream context, two argument buffers (input, mask), one result buffer, and a double
// attribute named "scale_factor".
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler, ScaledMaskedSoftmaxForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // mask
.Ret<Buffer_Type>() // output
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
// Defines the FFI handler symbol `ScaledUpperTriangMaskedSoftmaxForwardHandler`, backed
// by ScaledUpperTriangMaskedSoftmaxForwardFFI. The binding mirrors that function's
// parameters: stream context, one argument buffer, one result buffer, and a double
// attribute named "scale_factor".
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler,
ScaledUpperTriangMaskedSoftmaxForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
// Declares the raw pointers and TensorWrapper objects for a softmax backward call:
// `grad_output_tensor`, `softmax_output_tensor` and `dgrad_tensor`. The expanding
// function must already have `grad_output_buf`, `softmax_output_buf`, `dgrad_buf`,
// `shape` and `dtype` in scope; all three wrappers use the same `shape` and `dtype`.
#define SOFTMAX_BACKWARD_COMMON_BLOCK \
auto *grad_output = grad_output_buf.untyped_data(); \
auto *softmax_output = softmax_output_buf.untyped_data(); \
auto *dgrad = dgrad_buf->untyped_data(); \
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); \
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); \
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
// XLA FFI implementation of the scaled softmax backward pass. It is also bound to
// ScaledMaskedSoftmaxBackwardHandler later in this file.
//
// Describes all three tensors as 4-D {batch_size, head_dim, q_seqlen, k_seqlen},
// derived from the grad_output buffer's dimensions, then calls
// nvte_scaled_softmax_backward on `stream`.
//
// Parameters:
//   stream              CUDA stream passed to the kernel call.
//   grad_output_buf     Incoming gradient; supplies the dtype and dimensions.
//   softmax_output_buf  Forward softmax output, wrapped with the same shape and dtype.
//   dgrad_buf           Result buffer, wrapped with the same shape and dtype.
//   scale_factor_       Scale attribute; converted to float before use.
//
// Returns the result of ffi_with_cuda_error_check().
Error_Type ScaledSoftmaxBackwardFFI(cudaStream_t stream, Buffer_Type grad_output_buf,
Buffer_Type softmax_output_buf, Result_Type dgrad_buf,
double scale_factor_) {
// Introduces dtype, batch_size, head_dim, q_seqlen, k_seqlen and scale_factor.
SOFTMAX_COMMON_BLOCK(grad_output_buf);
// `shape` must be declared before SOFTMAX_BACKWARD_COMMON_BLOCK, which reads it.
auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
SOFTMAX_BACKWARD_COMMON_BLOCK;
nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
dgrad_tensor.data(), scale_factor, stream);
return ffi_with_cuda_error_check();
}
// XLA FFI implementation of the scaled upper-triangular-masked softmax backward pass.
//
// Describes all three tensors as 3-D {batch_size * head_dim, q_seqlen, k_seqlen},
// derived from the grad_output buffer's dimensions, then calls
// nvte_scaled_upper_triang_masked_softmax_backward on `stream`.
//
// Parameters:
//   stream              CUDA stream passed to the kernel call.
//   grad_output_buf     Incoming gradient; supplies the dtype and dimensions.
//   softmax_output_buf  Forward softmax output, wrapped with the same shape and dtype.
//   dgrad_buf           Result buffer, wrapped with the same shape and dtype.
//   scale_factor_       Scale attribute; converted to float before use.
//
// Returns the result of ffi_with_cuda_error_check().
Error_Type ScaledUpperTriangMaskedSoftmaxBackwardFFI(cudaStream_t stream,
Buffer_Type grad_output_buf,
Buffer_Type softmax_output_buf,
Result_Type dgrad_buf, double scale_factor_) {
// Introduces dtype, batch_size, head_dim, q_seqlen, k_seqlen and scale_factor.
SOFTMAX_COMMON_BLOCK(grad_output_buf);
// The two leading extents are folded into a single dimension.
auto shape = std::vector<size_t>{batch_size * head_dim, q_seqlen, k_seqlen};
SOFTMAX_BACKWARD_COMMON_BLOCK;
nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(),
softmax_output_tensor.data(),
dgrad_tensor.data(), scale_factor, stream);
return ffi_with_cuda_error_check();
}
// Defines the FFI handler symbol `ScaledSoftmaxBackwardHandler`, backed by
// ScaledSoftmaxBackwardFFI. The binding mirrors that function's parameters: stream
// context, two argument buffers (grad_output, softmax_output), one result buffer
// (dgrad), and a double attribute named "scale_factor".
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // grad_output
.Arg<Buffer_Type>() // softmax_output
.Ret<Buffer_Type>() // dgrad
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax, so this handler
// symbol is deliberately bound to ScaledSoftmaxBackwardFFI rather than to a dedicated
// masked implementation. The binding takes no mask buffer: stream context, two
// argument buffers (grad_output, softmax_output), one result buffer (dgrad), and a
// double attribute named "scale_factor".
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // grad_output
.Arg<Buffer_Type>() // softmax_output
.Ret<Buffer_Type>() // dgrad
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
// Defines the FFI handler symbol `ScaledUpperTriangMaskedSoftmaxBackwardHandler`,
// backed by ScaledUpperTriangMaskedSoftmaxBackwardFFI. The binding mirrors that
// function's parameters: stream context, two argument buffers (grad_output,
// softmax_output), one result buffer (dgrad), and a double attribute named
// "scale_factor".
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler,
ScaledUpperTriangMaskedSoftmaxBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // grad_output
.Arg<Buffer_Type>() // softmax_output
.Ret<Buffer_Type>() // dgrad
.Attr<double>("scale_factor"),
FFI_CudaGraph_Traits);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment