Unverified Commit 0816583a authored by zlsh80826, committed by GitHub

Support dropout for the fused attention when max seqlen <= 512 (#227)



* Enable fused attention dropout
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Cast the uint32 key/counter to int64
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update dropout support in fused attention docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revise devPtrCuSeqlen* to align the naming
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Support different Jax PRNG impls
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revert CastAsync since it is not used
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Implement is_training for 16-bit fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add fused attn with dropout sanity unit tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance comment readability and the rng_state checker
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change the attention dropout shape to align with other frameworks
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Make encoder tests deterministic
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change the default seed for the jax encoder tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Maintain offset in TE
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the resource safety
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revert rng_state type to allow only i64
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Handle the corner case for elts_per_threads calculation
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Populate rng state by kernels
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename rng_state to seed in cpp_extensions
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update the attention dropout comment
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 16208b3b
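
The changes below enable dropout in the cuDNN fused-attention path for FP16/BF16 when the maximum sequence length is at most 512, seeding it from a JAX PRNG key that TE converts into a device-side seed/offset pair. As rough orientation before the diffs, here is a minimal sketch of the resulting low-level call. The module path and the defaults of any parameters not shown are assumptions; only the argument names and shapes below appear in the diff, the snippet requires a TE build with the fused kernel and a supported GPU, and the flax-level MultiHeadAttention path is additionally gated by the NVTE_FUSED_ATTN environment variable.

import math
import jax
import jax.numpy as jnp
# Module path is an assumption for illustration; the argument names match the diff.
from transformer_engine.jax.fused_attn import self_fused_attn, AttnBiasType, AttnMaskType

b, s, h, d = 4, 512, 16, 64                      # max seqlen <= 512, head_dim == 64
qkv = jnp.zeros((b, s, 3, h, d), jnp.bfloat16)   # packed QKV (QKV_INTERLEAVED layout)
mask = jnp.zeros((b, 1, s, s), jnp.uint8)        # (b, 1, s_q, s_kv); 0 marks valid positions
seed = jax.random.PRNGKey(0)                     # uint32 key; validated/cast by _check_seed

out = self_fused_attn(qkv, None, mask, seed,     # bias=None implies NO_BIAS
                      attn_bias_type=AttnBiasType.NO_BIAS,
                      attn_mask_type=AttnMaskType.PADDING_MASK,
                      scaling_factor=1.0 / math.sqrt(d),
                      dropout_probability=0.1,   # applied only when is_training=True
                      is_training=True)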
......@@ -377,7 +377,7 @@ def encoder_parser(args):
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
......
......@@ -359,7 +359,7 @@ def encoder_parser(args):
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
......
......@@ -459,7 +459,7 @@ def encoder_parser(args):
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
......
......@@ -294,7 +294,7 @@ def encoder_parser(args):
default=False,
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
parser.add_argument("--use-fp8",
action="store_true",
default=False,
......
......@@ -9,5 +9,10 @@ pytest -Wignore -v $TE_PATH/tests/jax
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
pytest -Wignore -v $TE_PATH/examples/jax --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -Wignore -v $TE_PATH/examples/jax/mnist
# Make encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="--xla_gpu_deterministic_ops"
pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
......@@ -54,6 +54,7 @@ def jax_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
value,
bias=bias,
mask=mask,
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=qkv.dtype)
......@@ -78,6 +79,7 @@ def jax_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
value,
bias=None,
mask=mask,
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=q.dtype)
......@@ -113,7 +115,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
reason="Fused attention kernel is not supported.")
class TestSelfFusedAttnMax512():
def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias):
def set_input(self, b, s, h, d, *, attn_bias_type, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
......@@ -125,6 +128,8 @@ class TestSelfFusedAttnMax512():
min_val, max_val = -1, 1
self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val)
with_bias = attn_bias_type != AttnBiasType.NO_BIAS
self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val,
max_val) if with_bias else None
......@@ -133,28 +138,81 @@ class TestSelfFusedAttnMax512():
self.kv_token = self.q_token
self.scaling_factor = 1. / math.sqrt(d)
self.dropout_probability = 0.
self.dropout_probability = dropout_probability
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = AttnBiasType.NO_BIAS if self.bias is None else AttnBiasType.POST_SCALE_BIAS
# deterministic = not is_training
self.deterministic = False
self.attn_bias_type = attn_bias_type
self.is_training = is_training
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
@pytest.mark.parametrize('with_bias', [True, False])
def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias):
def test_sanity(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
def grad_func(func, *args, **kwargs):
# Keep only the valid results for the gradient
# fused_attn_max_512 output has shape (b, s, h, d)
valid_ret, _ = jnp.split(func(*args, **kwargs), (self.valid_len,), axis=1)
return jnp.mean(valid_ret, dtype=jnp.float32).astype(dtype)
self.set_input(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': self.is_training
}
jitted_primitive = jit(
value_and_grad(
lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs
), (0, 1)))
primitive_out, (primitive_dqkv,
primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
self.kv_token, self.dropout_rng)
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
# dropout results are not bit-exact, so skip the comparison against the reference
if is_training and dropout_probability > 0.:
return
self.set_input(b,
s,
h,
d,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
pad_ratio=pad_ratio,
with_bias=with_bias)
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
primitive_out = customcall_self_fused_attn(self.qkv,
self.bias,
......@@ -165,7 +223,7 @@ class TestSelfFusedAttnMax512():
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=not self.deterministic)
is_training=self.is_training)
reference_out = jax_self_fused_attn(self.qkv,
self.bias,
......@@ -174,7 +232,8 @@ class TestSelfFusedAttnMax512():
self.dropout_rng,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability)
dropout_probability=self.dropout_probability,
is_training=self.is_training)
ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1)
......@@ -188,20 +247,25 @@ class TestSelfFusedAttnMax512():
jnp.zeros_like(pri_invalid, jnp.float32))
@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dropout_probability', [0.]) # dropout results are not bit-exact
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True]) # backward is only used when is_training
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
@pytest.mark.parametrize('with_bias', [True, False])
def test_forward_backward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias):
def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability,
dtype, is_training, pad_ratio):
self.set_input(b,
s,
h,
d,
dtype=dtype,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
pad_ratio=pad_ratio,
with_bias=with_bias)
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
def grad_func(fused_attn_max_512_func, *args, **kwargs):
# Gradients are small, so use a gradient multiplier to amplify them
......@@ -221,7 +285,7 @@ class TestSelfFusedAttnMax512():
'attn_mask_type': attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': not self.deterministic
'is_training': self.is_training
}
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
......@@ -300,7 +364,8 @@ class TestSelfFusedAttnMax512():
reason="Fused attention kernel is not supported.")
class TestCrossFusedAttnMax512():
def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio):
def set_input(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
......@@ -321,25 +386,32 @@ class TestCrossFusedAttnMax512():
(b, kv_pad_len))),
axis=-1)
self.scaling_factor = 1. / math.sqrt(d)
self.dropout_probability = 0.
self.dropout_rng = jax.random.PRNGKey(0)
self.dropout_probability = dropout_probability
self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
self.attn_bias_type = AttnBiasType.NO_BIAS
# deterministic = not is_training
self.deterministic = False
self.is_training = is_training
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio):
def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
# dropout results are not bit-exact, so skip the comparison against the reference
if is_training and dropout_probability > 0.:
return
self.set_input(b,
s_q,
s_kv,
h,
d,
dtype=dtype,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
primitive_out = customcall_cross_fused_attn(self.q,
......@@ -351,7 +423,7 @@ class TestCrossFusedAttnMax512():
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability,
is_training=not self.deterministic)
is_training=self.is_training)
reference_out = jax_cross_fused_attn(self.q,
self.kv,
......@@ -360,7 +432,8 @@ class TestCrossFusedAttnMax512():
self.dropout_rng,
attn_mask_type=attn_mask_type,
scaling_factor=self.scaling_factor,
dropout_probability=self.dropout_probability)
dropout_probability=self.dropout_probability,
is_training=self.is_training)
ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1)
pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1)
......@@ -375,16 +448,21 @@ class TestCrossFusedAttnMax512():
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0.]) # dropout results are not bit-exact
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True]) # backward is only used when is_training
@pytest.mark.parametrize('pad_ratio', PAD_RATIO)
def test_forward_backward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio):
def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
is_training, pad_ratio):
self.set_input(b,
s_q,
s_kv,
h,
d,
dtype=dtype,
attn_mask_type=attn_mask_type,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
def grad_func(fused_attn_max_512_func, *args, **kwargs):
......@@ -405,7 +483,7 @@ class TestCrossFusedAttnMax512():
'attn_mask_type': attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_probability,
'is_training': not self.deterministic
'is_training': self.is_training
}
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
......
......@@ -167,9 +167,7 @@ def dot_product_attention(query: Array,
# T5 broadcasts along the "length" dim, but unclear which one that
# corresponds to in positional dimensions here, assuming query dim.
dropout_shape = list(attn_weights.shape)
dropout_shape[-2] = 1
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
keep = jnp.broadcast_to(keep, attn_weights.shape)
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier
......
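
The hunk above removes the dropout_shape[-2] = 1 broadcast, so the reference implementation now samples an independent keep-mask element for every attention weight, matching the fused kernel and other frameworks. A small self-contained sketch of the behavioural difference (shapes are illustrative only):

import jax
import jax.numpy as jnp

attn_weights_shape = (2, 8, 128, 128)   # (b, h, s_q, s_kv), illustrative
keep_prob = 0.9
rng = jax.random.PRNGKey(0)

# Old behaviour: one mask per (b, h, s_kv), broadcast along the query dimension
old_shape = list(attn_weights_shape)
old_shape[-2] = 1
old_keep = jnp.broadcast_to(jax.random.bernoulli(rng, keep_prob, old_shape),
                            attn_weights_shape)

# New behaviour: an independent mask element per attention weight
new_keep = jax.random.bernoulli(rng, keep_prob, attn_weights_shape)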
......@@ -22,7 +22,7 @@
#define O_ID 4
#define S_ID 5
#define B_ID 6
#define D_CONST_ID 7
#define DROPOUT_CONST_ID 7
#define S_CONST_ID 8
#define Q_SEQLEN_ID 9
#define K_SEQLEN_ID 10
......@@ -33,6 +33,8 @@
#define MASK_VAL_ID 15
#define dS_ID 16
#define dBias_ID 17
#define DROPOUT_SEED_ID 18
#define DROPOUT_OFFSET_ID 19
#define VIRTUAL_ID 20
......@@ -333,8 +335,7 @@ static cudnn_frontend::Tensor createSoftmaxForward(
int64_t afterReduction_dim[4] = {b, h, s_q, 1};
int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1};
cudnnDataType_t softmaxOutputType =
(enable_dropout || softmax_output_virtual) ? CUDNN_DATA_FLOAT : tensorType;
cudnnDataType_t softmaxOutputType = enable_dropout ? CUDNN_DATA_FLOAT : tensorType;
uint64_t softmaxOutputName = softmax_output_virtual ? VIRTUAL_ID + 154 : S_ID;
// max (x)
......@@ -427,7 +428,7 @@ static cudnn_frontend::Tensor createSoftmaxForward(
}
static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
int64_t d, int64_t seed, double probability,
int64_t d, double probability,
cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
......@@ -460,8 +461,9 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i
.setReorderType(reorder_type)
.build();
// scale after dropout
auto scaleDropoutTensor = tensor_create(tensorType, D_CONST_ID, scale_dim, scale_stride, false,
true); // is by value
auto scaleDropoutTensor =
tensor_create(tensorType, DROPOUT_CONST_ID, scale_dim, scale_stride, false,
true); // is by value
// after Scale
auto afterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 201, afterBMM1_dim,
afterBMM1_stride, true, false); // is virtual
......@@ -472,10 +474,16 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i
.setBernoulliDistProbability(1.0 - probability)
.build();
auto dropoutSeed =
tensor_create(CUDNN_DATA_INT64, DROPOUT_SEED_ID, scale_dim, scale_stride, false, false);
auto dropoutOffset =
tensor_create(CUDNN_DATA_INT64, DROPOUT_OFFSET_ID, scale_dim, scale_stride, false, false);
// Create a rng Node.
auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
.setyDesc(dropoutMaskTensor)
.setSeed(seed)
.setSeedDesc(dropoutSeed)
.setOffsetDesc(dropoutOffset)
.setRngDesc(rngDesc)
.build();
......@@ -624,16 +632,14 @@ static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_
return dxTensor;
}
void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrS, void *devPtrO, void *devPtrBias,
void *devCuSeqlenQ, void *devCuSeqlenK, void *workspace,
size_t *workspace_size, cudnnDataType_t tensorType,
cudaStream_t stream, cudnnHandle_t handle) {
void fused_attn_max_512_fwd_impl(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, bool is_training,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK, void *devPtrV,
void *devPtrS, void *devPtrO, void *devPtrBias, void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size,
cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) {
try {
constexpr int64_t seed = 0; // TODO(rewang): replace this with device seed/offset
NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
FADescriptor descriptor{b, h,
......@@ -646,10 +652,13 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
static thread_local CacheType fmha_fprop_cache;
bool enable_dropout = (dropout_probability != 0.0f);
// softmax auxiliary is only used in the training mode
bool enable_dropout = is_training && (dropout_probability != 0.0f);
NVTE_CHECK(!enable_dropout,
"dropout probability > 0 in fused_attn_max_512 has not been implemented.");
// two conditions make the softmax auxiliary tensor virtual:
// 1. inference mode (not is_training)
// 2. dropout enabled: the auxiliary becomes the dropout output
bool softmax_output_virtual = !is_training || enable_dropout;
// Get plan from cache if cache is available, otherwise create one
auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
......@@ -667,8 +676,10 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
createScale(b, h, s_q, s_kv, d, layout, tensorType, ops);
// if bias is used, the S buffer must be memset to correctly compute dbias
// WAR: causal_mask without bias also needs the S buffer memset
// inference mode doesn't need the S auxiliary
auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) ||
(mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK);
(mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && is_training;
auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops);
NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS,
......@@ -683,14 +694,12 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0.");
// TODO(rewang): check whether devPtrS can be removed
bool softmax_output_virtual = enable_dropout; // || devPtrS == nullptr;
auto softmax_output =
createSoftmaxForward(b, h, s_q, s_kv, d, layout, enable_dropout,
softmax_output_virtual, tensorType, ops, mask_output);
if (dropout_probability != 0.0f) {
auto dropout_output = createDropout(b, h, s_q, s_kv, d, seed, dropout_probability,
if (enable_dropout) {
auto dropout_output = createDropout(b, h, s_q, s_kv, d, dropout_probability,
tensorType, ops, softmax_output);
createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, dropout_output);
} else {
......@@ -741,9 +750,10 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devCuSeqlenQ),
static_cast<const int32_t *>(devCuSeqlenK), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenK));
b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
static_cast<const int32_t *>(devPtrCuSeqlenKV),
static_cast<int32_t *>(devActualSeqlenQ), static_cast<int32_t *>(devActualSeqlenK));
NVTE_CHECK_CUDA(cudaGetLastError());
// change this if you have access to float_min
float negInfinity = -1.0E+10;
......@@ -758,16 +768,17 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
data_ptrs.insert(std::pair<uint64_t, void *>(K_SEQLEN_ID, devActualSeqlenK));
data_ptrs.insert(std::pair<uint64_t, void *>(MASK_VAL_ID, &negInfinity));
__half half_cast_scaling_factor{scaling_factor};
__nv_bfloat16 bfloat_cast_scaling_factor{scaling_factor};
if (tensorType == CUDNN_DATA_FLOAT) {
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &scaling_factor));
} else if (tensorType == CUDNN_DATA_HALF) {
__half cast_scaling_factor{scaling_factor};
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &cast_scaling_factor));
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &half_cast_scaling_factor));
} else if (tensorType == CUDNN_DATA_BFLOAT16) {
__nv_bfloat16 cast_scaling_factor{scaling_factor};
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &cast_scaling_factor));
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &bfloat_cast_scaling_factor));
} else {
std::cerr << "Not supported tensorType." << std::endl;
NVTE_ERROR("Unsupported tensor type.");
}
data_ptrs.insert(std::pair<uint64_t, void *>(O_ID, devPtrO));
......@@ -776,12 +787,30 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
data_ptrs.insert(std::pair<uint64_t, void *>(B_ID, devPtrBias));
}
if (devPtrS != nullptr) {
// if dropout is enabled, S holds the result after dropout
// otherwise, S holds the result after softmax
if (enable_dropout || !softmax_output_virtual) {
data_ptrs.insert(std::pair<uint64_t, void *>(S_ID, devPtrS));
}
__half half_cast_scale_dropout{scale_dropout};
__nv_bfloat16 bfloat16_cast_scale_dropout{scale_dropout};
if (enable_dropout) {
data_ptrs.insert(std::pair<uint64_t, void *>(D_CONST_ID, &scale_dropout));
// TODO(rewang): make a util func
if (tensorType == CUDNN_DATA_FLOAT) {
data_ptrs.insert(std::pair<uint64_t, void *>(DROPOUT_CONST_ID, &scale_dropout));
} else if (tensorType == CUDNN_DATA_HALF) {
data_ptrs.insert(
std::pair<uint64_t, void *>(DROPOUT_CONST_ID, &half_cast_scale_dropout));
} else if (tensorType == CUDNN_DATA_BFLOAT16) {
data_ptrs.insert(
std::pair<uint64_t, void *>(DROPOUT_CONST_ID, &bfloat16_cast_scale_dropout));
} else {
NVTE_ERROR("Unsupported tensor type.");
}
data_ptrs.insert(std::pair<uint64_t, void *>(DROPOUT_SEED_ID, devPtrDropoutSeed));
data_ptrs.insert(std::pair<uint64_t, void *>(DROPOUT_OFFSET_ID, devPtrDropoutOffset));
}
auto variantPack = cudnn_frontend::VariantPackBuilder()
......@@ -802,7 +831,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
NVTE_Bias_Type bias_type, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrS, void *devPtrdQ, void *devPtrdK,
void *devPtrdV, void *devPtrdO, void *devPtrdS, void *devPtrdBias,
void *devCuSeqlenQ, void *devCuSeqlenK, void *workspace,
void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void *workspace,
size_t *workspace_size, cudnnDataType_t tensorType,
cudaStream_t stream, cudnnHandle_t handle) {
try {
......@@ -915,7 +944,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
ops.push_back(std::move(reshape_op));
// scale dropout
auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, DROPOUT_CONST_ID, scale_dim,
scale_stride, false, true); // is by value
auto pAfterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 301, p_transpose_dim,
p_transpose_stride, true, false);
......@@ -1160,9 +1189,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devCuSeqlenQ),
static_cast<const int32_t *>(devCuSeqlenK), static_cast<int32_t *>(devActualSeqlenQ),
static_cast<int32_t *>(devActualSeqlenK));
b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
static_cast<const int32_t *>(devPtrCuSeqlenKV),
static_cast<int32_t *>(devActualSeqlenQ), static_cast<int32_t *>(devActualSeqlenK));
NVTE_CHECK_CUDA(cudaGetLastError());
std::set<std::pair<uint64_t, void *>> data_ptrs;
// add all the data pointers to be used in the variant pack
......@@ -1183,13 +1213,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
data_ptrs.insert(std::pair<uint64_t, void *>(dBias_ID, devPtrdBias));
}
NVTE_CHECK(dropout_probability == 0.f,
"dropout probability > 0 in fused_attn_max_512 has not been implemented.");
float zeroVal = 0.0f;
float dropoutScale = 1.0f / (1.0f - dropout_probability);
data_ptrs.insert(std::pair<uint64_t, void *>(D_CONST_ID, &dropoutScale));
data_ptrs.insert(std::pair<uint64_t, void *>(DROPOUT_CONST_ID, &dropoutScale));
data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &scaling_factor));
data_ptrs.insert(std::pair<uint64_t, void *>(MASK_VAL_ID, &zeroVal));
......@@ -1216,8 +1243,6 @@ void fused_attn_max_512_fwd_qkvpacked(
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
// Only is_training is verified
NVTE_CHECK(is_training, "is_training=False is not implemented in fused_attn_max_512.");
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED.");
......@@ -1246,23 +1271,22 @@ void fused_attn_max_512_fwd_qkvpacked(
devPtrS = output_S->data.dptr;
}
void *devCuSeqlen = cu_seqlens->data.dptr;
void *devPtrCuSeqlen = cu_seqlens->data.dptr;
// TODO(rewang): dropout seed
// void* devPtrDropoutSeed = reinterpret_cast<void *>(
// reinterpret_cast<uint64_t*>(rng_state->data.dptr));
// void* devPtrDropoutOffset = reinterpret_cast<void *>(
// reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType rng_state_type = rng_state->data.dtype;
NVTE_CHECK(rng_state_type == DType::kInt64);
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
// TODO(rewang): replace CPU seed
fused_attn_max_512_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ,
devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, devCuSeqlen,
devCuSeqlen, workspace->data.dptr, &workspace_size,
get_cudnn_dtype(QKV_type), stream, handle);
fused_attn_max_512_fwd_impl(
batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset,
workspace->data.dptr, &workspace_size, get_cudnn_dtype(QKV_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1288,8 +1312,6 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
// Only is_training is verified
NVTE_CHECK(is_training, "is_training=False is not implemented in fused_attn_max_512.");
NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED,
"qkv_layout must be NVTE_QKV_Layout::NVTE_KV_INTERLEAVED.");
NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
......@@ -1328,20 +1350,19 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
void *devQCuSeqlen = q_cu_seqlens->data.dptr;
void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
// TODO(rewang): dropout seed
// void* devPtrDropoutSeed = reinterpret_cast<void *>(
// reinterpret_cast<uint64_t*>(rng_state->data.dptr));
// void* devPtrDropoutOffset = reinterpret_cast<void *>(
// reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType rng_state_type = rng_state->data.dtype;
NVTE_CHECK(rng_state_type == DType::kInt64);
void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset =
static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
size_t workspace_size = 0;
// TODO(rewang): replace CPU seed
fused_attn_max_512_fwd_impl(batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ,
devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, devQCuSeqlen,
devKVCuSeqlen, workspace->data.dptr, &workspace_size,
get_cudnn_dtype(q_type), stream, handle);
fused_attn_max_512_fwd_impl(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
&workspace_size, get_cudnn_dtype(q_type), stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......
......@@ -256,6 +256,10 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b,
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kInt32:
return CUDNN_DATA_INT32;
case DType::kInt64:
return CUDNN_DATA_INT64;
case DType::kFloat16:
return CUDNN_DATA_HALF;
case DType::kFloat32:
......
......@@ -106,7 +106,7 @@ enum NVTE_Mask_Type {
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
......@@ -149,7 +149,7 @@ void nvte_fused_attn_fwd_qkvpacked(
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
......@@ -200,7 +200,7 @@ void nvte_fused_attn_bwd_qkvpacked(
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
......@@ -247,7 +247,7 @@ void nvte_fused_attn_fwd_kvpacked(
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
......
......@@ -6,7 +6,7 @@ pybind11_add_module(
transformer_engine_jax
${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/modules.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cu
)
target_link_libraries(transformer_engine_jax PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt transformer_engine)
......
......@@ -8,6 +8,8 @@ from dataclasses import dataclass
from typing import Tuple
from functools import partial, reduce
import operator
import warnings
import numpy as np
from jaxlib.hlo_helpers import custom_call
import jax.numpy as jnp
......@@ -1679,7 +1681,7 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
grad_outputs, softmax_outputs,
scale_factor)
return out # out is iterable already
return out # out is iterable already
_scaled_softmax_bwd_p = register_primitive(ScaledSoftmaxBwdPrimitive)
......@@ -1828,7 +1830,7 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
grad_outputs, softmax_outputs,
scale_factor)
return out # out is iterable already
return out # out is iterable already
_scaled_masked_softmax_bwd_p = register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
......@@ -1962,7 +1964,7 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, ctx, grad_outputs, softmax_outputs,
scale_factor)
return out # out is iterable already
return out # out is iterable already
_scaled_upper_triang_masked_softmax_bwd_p = \
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
......@@ -1979,6 +1981,27 @@ def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_ou
scale_factor=scale_factor)
def _check_seed(seed, dropout_probability, is_training):
# Jax can't bind None, create a dummy tensor for None
if seed is None:
dropout_enabled = dropout_probability > 0 and is_training
assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
seed = jnp.zeros(2, dtype=jnp.uint32)
if seed.dtype != jnp.uint32:
warnings.warn(
f"Requested {seed.dtype=} is not available, and will be "
f"casted to dtype uint32. "
f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.")
seed = seed.astype(jnp.uint32)
assert seed.dtype == jnp.uint32
# Only the first 2 u32 elements are taken
assert seed.size >= 2
return seed
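
For reference, the default threefry implementation already yields a uint32 key of shape (2,), so no cast or warning is triggered; key data from other PRNG implementations is cast to uint32 with a warning, and None is only accepted when dropout is effectively disabled. A short sketch of the expected behaviour, assuming _check_seed from the hunk above is in scope:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)    # threefry: uint32 array of shape (2,)
seed = _check_seed(key, dropout_probability=0.1, is_training=True)
assert seed.dtype == jnp.uint32 and seed.size >= 2

# With dropout disabled, None is allowed and replaced by a dummy all-zero seed
dummy = _check_seed(None, dropout_probability=0.0, is_training=True)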
class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
"""
Self Fused Attention Max Seqlen 512 Forward Primitive
......@@ -1991,7 +2014,7 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
qkv,
bias,
cu_seqlen, # pylint: disable=unused-argument
rng_state, # pylint: disable=unused-argument
seed, # pylint: disable=unused-argument
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
......@@ -2020,8 +2043,8 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
)
@staticmethod
def lowering(ctx, qkv, bias, cu_seqlen, rng_state, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
"""
Self fused attention max seqlen 512 fwd lowering rules
"""
......@@ -2036,8 +2059,8 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
ir_rng_state_type = ir.RankedTensorType(rng_state.type)
ir_rng_state_shape = ir_rng_state_type.shape
ir_seed_type = ir.RankedTensorType(seed.type)
ir_seed_shape = ir_seed_type.shape
batch, max_seqlen, nqkv, num_head, head_dim = ir_qkv_shape
assert nqkv == 3
......@@ -2049,8 +2072,8 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(output_shape, ir_qkv_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_qkv_type.element_type)
]
operands = [qkv, bias, cu_seqlen, rng_state]
operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_rng_state_shape]
operands = [qkv, bias, cu_seqlen, seed]
operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_seed_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
......@@ -2069,23 +2092,22 @@ _self_fused_attn_max_512_fwd_p = register_primitive(SelfFusedAttnMax512FwdPrimit
def self_fused_attn_max_512_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool):
"""
Wrapper for TE self fused attention max seqlen 512 fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
# Jax can't bind None, create a dummy tensor for None
if rng_state is None:
rng_state = jnp.zeros(2, dtype=jnp.int32)
seed = _check_seed(seed, dropout_probability, is_training)
if bias is None:
assert attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
bias = jnp.zeros(0, dtype=qkv.dtype)
return _self_fused_attn_max_512_fwd_p.bind(qkv,
bias,
cu_seqlen,
rng_state,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......@@ -2161,6 +2183,9 @@ class SelfFusedAttnMax512BwdPrimitive(BasePrimitive):
operand_shapes = [ir_qkv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_cu_seqlen_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
# the dropout elements are encoded in the forward auxiliary tensor
# so seed is not needed in backward
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
......@@ -2208,7 +2233,7 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
kv,
q_cu_seqlen,
kv_cu_seqlen,
rng_state, # pylint: disable=unused-argument
seed, # pylint: disable=unused-argument
*,
attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument
......@@ -2243,8 +2268,8 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
)
@staticmethod
def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state, *, attn_bias_type,
attn_mask_type, scaling_factor, dropout_probability, is_training):
def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
"""
Cross fused attention max seqlen 512 fwd lowering rules
"""
......@@ -2260,8 +2285,8 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
ir_rng_state_type = ir.RankedTensorType(rng_state.type)
ir_rng_state_shape = ir_rng_state_type.shape
ir_seed_type = ir.RankedTensorType(seed.type)
ir_seed_shape = ir_seed_type.shape
batch, q_max_seqlen, num_head, head_dim = ir_q_shape
kv_max_seqlen = ir_kv_shape[1]
......@@ -2273,9 +2298,9 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(output_shape, ir_q_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_q_type.element_type)
]
operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state]
operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed]
operand_shapes = [
ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_rng_state_shape
ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_seed_shape
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
......@@ -2296,7 +2321,7 @@ _cross_fused_attn_max_512_fwd_p = register_primitive(CrossFusedAttnMax512FwdPrim
def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, rng_state: jnp.ndarray,
kv_cu_seqlen: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
scaling_factor: float, dropout_probability: float,
is_training: bool):
......@@ -2304,14 +2329,13 @@ def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: j
Wrapper for TE cross fused attention max seqlen 512 fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
# Jax can't bind None, create a dummy tensor for None
if rng_state is None:
rng_state = jnp.zeros(2, dtype=jnp.int32)
seed = _check_seed(seed, dropout_probability, is_training)
return _cross_fused_attn_max_512_fwd_p.bind(q,
kv,
q_cu_seqlen,
kv_cu_seqlen,
rng_state,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......@@ -2391,6 +2415,9 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
# the dropout elements are encoded in the forward auxiliary tensor
# so seed is not needed in backward
opaque = transformer_engine_jax.pack_fused_attn_descriptor(
batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
......
......@@ -749,7 +749,7 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
void *qkv = buffers[0];
void *bias = buffers[1];
void *cu_seqlens = buffers[2];
void *rng_state = buffers[3];
void *seed = buffers[3];
// output
void *output = buffers[4];
......@@ -778,30 +778,37 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
auto cu_seqlens_tensor =
TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{1}, DType::kInt64);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
TensorWrapper query_workspace_tensor;
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
descriptor.mask_type, query_workspace_tensor.data(), stream);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
descriptor.is_training, descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
query_workspace_tensor.data(), stream);
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
size_t workspace_size =
// fused attn workspace + workspace for rng_state
auto plan_workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto rng_workspace_size = 2 * sizeof(int64_t);
auto total_workspace_size = plan_workspace_size + rng_workspace_size;
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
......@@ -907,7 +914,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
void *kv = buffers[1];
void *q_cu_seqlens = buffers[2];
void *kv_cu_seqlens = buffers[3];
void *rng_state = buffers[4];
void *seed = buffers[4];
// output
void *output = buffers[5];
......@@ -939,7 +946,8 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{1}, DType::kInt64);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
......@@ -949,7 +957,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
descriptor.scaling_factor, descriptor.dropout_probability,
NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
query_workspace_tensor.data(), stream);
......@@ -957,13 +965,19 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
output_s->data.dptr = softmax_aux;
size_t workspace_size =
// fused attn workspace + workspace for rng_state
auto plan_workspace_size =
query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto rng_workspace_size = 2 * sizeof(int64_t);
auto total_workspace_size = plan_workspace_size + rng_workspace_size;
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
......
......@@ -32,5 +32,23 @@ int GetDeviceComputeCapability(int gpu_id) {
return gpu_arch;
}
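// Writes the 2-element rng_state buffer consumed by the cuDNN RNG node: [0] = seed, [1] = offset (both int64).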
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
int64_t offset) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid > 0) return;
rng_state_dst[0] = seed[0];
rng_state_dst[1] = offset;
}
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, cudaStream_t stream) {
constexpr int threads_per_cta = 128;
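// Each forward call advances the shared offset by ceil(q_max_seqlen * kv_max_seqlen / threads_per_cta),
// so successive dropout invocations draw fresh random numbers from the same seed.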
const size_t increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
reinterpret_cast<const int64_t *>(seed), offset);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace jax
} // namespace transformer_engine
......@@ -21,6 +21,9 @@ namespace jax {
int GetCudaRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, cudaStream_t stream);
class cublasLtMetaManager {
public:
static cublasLtMetaManager &Instance() {
......@@ -93,6 +96,27 @@ class cudaDevicePropertiesManager {
cudaDeviceProp prop_;
};
class FusedAttnOffsetManager {
public:
static FusedAttnOffsetManager &Instance() {
static thread_local FusedAttnOffsetManager instance;
return instance;
}
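// Returns the current offset and then advances it by increment, so each fused-attention forward call receives a distinct RNG offset.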
size_t GetAndUpdateOffset(size_t increment) {
size_t ret = offset_;
offset_ += increment;
return ret;
}
FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
void operator=(FusedAttnOffsetManager const &) = delete;
private:
FusedAttnOffsetManager() {}
size_t offset_ = 0;
};
} // namespace jax
} // namespace transformer_engine
......
......@@ -11,6 +11,7 @@ import os
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import warnings
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
......@@ -182,9 +183,8 @@ def core_attention(query: Array,
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
dropout_shape = list(attn_weights.shape)
dropout_shape[-2] = 1
# TODO(rewang): add attention dropout broadcast dimension arguments for users
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
keep = jnp.broadcast_to(keep, attn_weights.shape)
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier
......@@ -384,7 +384,7 @@ class MultiHeadAttention(nn.Module):
fused_attn_supported_seqlen = [128, 256, 384, 512]
enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
and is_fused_attn_kernel_available() and (self.head_dim == 64) and enable_fused_attn
......@@ -397,9 +397,6 @@ class MultiHeadAttention(nn.Module):
f"but got {self.transpose_batch_sequence}, "
if not self.fuse_qkv:
reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
if self.dropout_rate != 0:
# TODO(rewang): add dropout support
reason += f"no dropout is required but got dropout_rate={self.dropout_rate}, "
if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
reason += f"dtype in [BF16, FP16] is required " \
f"but got dtype={canonicalize_dtype}, "
......@@ -583,6 +580,12 @@ class MultiHeadAttention(nn.Module):
assert mask is not None and mask.ndim == 4 # (b, 1, s_q, s_kv)
assert not self.transpose_batch_sequence
seed = None
if dropout_rng is not None:
seed = jax.random.split(dropout_rng, len(jax.devices()))
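# one sub-key per device so each data-parallel shard draws an independent dropout pattern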
# ensure the old key is never used again
del dropout_rng
# TODO(rewang): make it configurable for pre_scale_bias
attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
......@@ -607,7 +610,7 @@ class MultiHeadAttention(nn.Module):
x = self_fused_attn(qkv_proj,
bias,
mask,
dropout_rng,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
......@@ -626,7 +629,7 @@ class MultiHeadAttention(nn.Module):
x = cross_fused_attn(query,
kv_proj,
mask,
dropout_rng,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
......
......@@ -46,7 +46,7 @@ class AttnMaskType(Enum):
def self_fused_attn(qkv: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
rng_state: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
......@@ -63,7 +63,7 @@ def self_fused_attn(qkv: jnp.ndarray,
output = _self_fused_attn_max_512(qkv,
bias,
mask,
rng_state,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......@@ -73,13 +73,13 @@ def self_fused_attn(qkv: jnp.ndarray,
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [qkv, bias, mask, rng_state]
inputs = [qkv, bias, mask, seed]
batch, seqlen, _, num_head, head_dim = qkv.shape
output_shape = [batch, seqlen, num_head, head_dim]
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
dp_dims=([0, None, 0, None], [0]),
tp_dims=([3, 1, None, None], [2]),
dp_dims=([0, None, 0, 0], [0]),
tp_dims=([3, 1, None, 0], [2]),
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
......@@ -104,13 +104,13 @@ def self_fused_attn(qkv: jnp.ndarray,
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: AttnBiasType,
seed: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
output, _ = _self_fused_attn_max_512_fwd(qkv,
bias,
mask,
rng_state,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......@@ -119,7 +119,7 @@ def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndar
return output
def _self_fused_attn_max_512_fwd(qkv, bias, mask, rng_state, attn_bias_type, attn_mask_type,
def _self_fused_attn_max_512_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
......@@ -129,7 +129,7 @@ def _self_fused_attn_max_512_fwd(qkv, bias, mask, rng_state, attn_bias_type, att
output, softmax_aux = self_fused_attn_max_512_fwd(qkv,
bias,
cu_seqlen,
rng_state,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
......@@ -163,7 +163,7 @@ _self_fused_attn_max_512.defvjp(_self_fused_attn_max_512_fwd, _self_fused_attn_m
def cross_fused_attn(q: jnp.ndarray,
kv: jnp.ndarray,
mask: jnp.ndarray,
rng_state: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
......@@ -180,7 +180,7 @@ def cross_fused_attn(q: jnp.ndarray,
output = _cross_fused_attn_max_512(q,
kv,
mask,
rng_state,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......@@ -190,7 +190,7 @@ def cross_fused_attn(q: jnp.ndarray,
dp_axis_name = "batch"
tp_axis_name = "model"
inputs = [q, kv, mask, rng_state]
inputs = [q, kv, mask, seed]
output_shape = q.shape
sharding_meta = get_fused_attn_sharding_meta(
sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
......@@ -219,15 +219,14 @@ def cross_fused_attn(q: jnp.ndarray,
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType, scaling_factor: float,
dropout_probability: float, is_training: bool):
def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
scaling_factor: float, dropout_probability: float, is_training: bool):
output, _ = _cross_fused_attn_max_512_fwd(q,
kv,
mask,
rng_state,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......@@ -236,8 +235,8 @@ def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray
return output
def _cross_fused_attn_max_512_fwd(q, kv, mask, rng_state, attn_bias_type, attn_mask_type,
scaling_factor, dropout_probability, is_training):
def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
dropout_probability, is_training):
q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
q_cu_seqlen = jnp.cumsum(q_seqlen)
......@@ -251,7 +250,7 @@ def _cross_fused_attn_max_512_fwd(q, kv, mask, rng_state, attn_bias_type, attn_m
kv,
q_cu_seqlen,
kv_cu_seqlen,
rng_state,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
......