"vscode:/vscode.git/clone" did not exist on "7e7591745670d8ad2593b265ca192ca3d7a6112e"
Unverified commit 0816583a, authored by zlsh80826 and committed by GitHub

Support dropout for the fused attention when max seqlen <= 512 (#227)



* Enable fused attention dropout
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Cast the uint32 key/counter to int64
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update dropout support in the fused attention docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revise devPtrCuSeqlen* to align the naming
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Support different JAX PRNG impls
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revert CastAsync since it is not used
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Implement is_training for the 16-bit fused attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add fused attn sanity unit tests with dropout
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance comment readability and the rng_state checker
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change the attention dropout shape to align with other frameworks
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Make the encoder tests deterministic
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change the default seed for the JAX encoder tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Maintain the offset in TE
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance the resource safety
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revert the rng_state type to allow only int64
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Handle the corner case in the elts_per_threads calculation
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Populate the rng state from kernels
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename rng_state to seed in cpp_extensions
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update the attention dropout comment
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 16208b3b
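Reviewer note on the seed plumbing in this change: under JAX's default threefry implementation a PRNG key is a pair of uint32 words, and TE reinterprets those words as a 64-bit seed stored beside a 64-bit offset that it advances between launches. A minimal sketch of the packing (illustrative only, not code from this diff):

```python
import numpy as np
import jax

# Under the default threefry implementation, a JAX PRNG key is two uint32 words.
key = jax.random.PRNGKey(7)    # shape (2,), dtype uint32
assert key.dtype == np.uint32 and key.size == 2

# The two uint32 words can be viewed as a single 64-bit seed; TE keeps a
# separate 64-bit offset next to it and advances the offset per kernel launch.
seed_word = np.asarray(key).view(np.int64)           # 2 x uint32 -> 1 x int64
rng_state = np.array([seed_word[0], 0], np.int64)    # [seed, offset]
print(rng_state)
```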
@@ -377,7 +377,7 @@ def encoder_parser(args):
         default=False,
         help="quickly check a single pass",
     )
-    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
+    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
     parser.add_argument("--use-fp8",
                         action="store_true",
                         default=False,
...
@@ -359,7 +359,7 @@ def encoder_parser(args):
         default=False,
         help="quickly check a single pass",
     )
-    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
+    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
     parser.add_argument("--use-fp8",
                         action="store_true",
                         default=False,
...
@@ -459,7 +459,7 @@ def encoder_parser(args):
         default=False,
         help="quickly check a single pass",
     )
-    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
+    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
    parser.add_argument("--use-fp8",
                         action="store_true",
                         default=False,
...
@@ -294,7 +294,7 @@ def encoder_parser(args):
         default=False,
         help="quickly check a single pass",
     )
-    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
+    parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
     parser.add_argument("--use-fp8",
                         action="store_true",
                         default=False,
...
@@ -9,5 +9,10 @@ pytest -Wignore -v $TE_PATH/tests/jax
 pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
 pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
-pytest -Wignore -v $TE_PATH/examples/jax --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+pytest -Wignore -v $TE_PATH/examples/jax/mnist
+
+# Make the encoder tests run-to-run deterministic for stable CI results
+export XLA_FLAGS="--xla_gpu_deterministic_ops"
+pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
 pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
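The same determinism opt-in works from Python, as long as the flag is set before XLA initializes; a small sketch:

```python
import os

# Must happen before the first jax import; trades some speed for
# run-to-run deterministic GPU kernels.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_deterministic_ops"

import jax    # noqa: E402  (imported after the flag on purpose)
print(jax.devices())
```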
@@ -54,6 +54,7 @@ def jax_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
         value,
         bias=bias,
         mask=mask,
+        deterministic=not kwargs['is_training'],
         dropout_rate=kwargs['dropout_probability'],
         dropout_rng=dropout_rng,
         dtype=qkv.dtype)
@@ -78,6 +79,7 @@ def jax_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
         value,
         bias=None,
         mask=mask,
+        deterministic=not kwargs['is_training'],
         dropout_rate=kwargs['dropout_probability'],
         dropout_rng=dropout_rng,
         dtype=q.dtype)
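For reference, the semantics behind `deterministic=not kwargs['is_training']`: a deterministic call bypasses dropout entirely, which is the inference behavior. A minimal sketch (`apply_dropout` is illustrative, not a helper in this repo):

```python
import jax
import jax.numpy as jnp

def apply_dropout(x, rate, deterministic, rng):
    # deterministic=True (inference) skips the mask; otherwise keep and rescale.
    if deterministic or rate == 0.:
        return x
    keep_prob = 1. - rate
    keep = jax.random.bernoulli(rng, keep_prob, x.shape)
    return jnp.where(keep, x / keep_prob, 0.)

x = jnp.ones((2, 4))
print(apply_dropout(x, 0.1, True, None))                      # unchanged
print(apply_dropout(x, 0.1, False, jax.random.PRNGKey(0)))    # masked and rescaled
```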
@@ -113,7 +115,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
                     reason="Fused attention kernel is not supported.")
 class TestSelfFusedAttnMax512():

-    def set_input(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias):
+    def set_input(self, b, s, h, d, *, attn_bias_type, attn_mask_type, dropout_probability, dtype,
+                  is_training, pad_ratio):
         key = jax.random.PRNGKey(0)
         subkeys = jax.random.split(key, 2)
@@ -125,6 +128,8 @@ class TestSelfFusedAttnMax512():
         min_val, max_val = -1, 1
         self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val)

+        with_bias = attn_bias_type != AttnBiasType.NO_BIAS
+
         self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val,
                                        max_val) if with_bias else None
@@ -133,28 +138,81 @@ class TestSelfFusedAttnMax512():
         self.kv_token = self.q_token

         self.scaling_factor = 1. / math.sqrt(d)
-        self.dropout_probability = 0.
+        self.dropout_probability = dropout_probability
         self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
-        self.attn_bias_type = AttnBiasType.NO_BIAS if self.bias is None else AttnBiasType.POST_SCALE_BIAS
-        # deterministic = not is_training
-        self.deterministic = False
+        self.attn_bias_type = attn_bias_type
+        self.is_training = is_training

     @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
-    @pytest.mark.parametrize('dtype', DTYPES)
+    @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
     @pytest.mark.parametrize('attn_mask_type',
                              [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
+    @pytest.mark.parametrize('dropout_probability', [0., 0.1])
+    @pytest.mark.parametrize('dtype', DTYPES)
+    @pytest.mark.parametrize('is_training', [True, False])
     @pytest.mark.parametrize('pad_ratio', PAD_RATIO)
-    @pytest.mark.parametrize('with_bias', [True, False])
-    def test_forward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias):
+    def test_sanity(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype,
+                    is_training, pad_ratio):
+
+        def grad_func(func, *args, **kwargs):
+            # Keep only the valid result for the gradient
+            # fused_attn_max_512 output has shape (b, s, h, d)
+            valid_ret, _ = jnp.split(func(*args, **kwargs), (self.valid_len,), axis=1)
+            return jnp.mean(valid_ret, dtype=jnp.float32).astype(dtype)
+
         self.set_input(b,
                        s,
                        h,
                        d,
+                       attn_bias_type=attn_bias_type,
+                       attn_mask_type=attn_mask_type,
+                       dropout_probability=dropout_probability,
                        dtype=dtype,
+                       is_training=is_training,
+                       pad_ratio=pad_ratio)
+
+        kwargs = {
+            'attn_bias_type': self.attn_bias_type,
+            'attn_mask_type': attn_mask_type,
+            'scaling_factor': self.scaling_factor,
+            'dropout_probability': self.dropout_probability,
+            'is_training': self.is_training
+        }
+
+        jitted_primitive = jit(
+            value_and_grad(
+                lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
+                    customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs
+                ), (0, 1)))
+
+        primitive_out, (primitive_dqkv,
+                        primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
+                                                            self.kv_token, self.dropout_rng)
+
+    @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
+    @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
+    @pytest.mark.parametrize('attn_mask_type',
+                             [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
+    @pytest.mark.parametrize('dropout_probability', [0., 0.1])
+    @pytest.mark.parametrize('dtype', DTYPES)
+    @pytest.mark.parametrize('is_training', [True, False])
+    @pytest.mark.parametrize('pad_ratio', PAD_RATIO)
+    def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability, dtype,
+                     is_training, pad_ratio):
+        # dropout can't get the bitmatch result
+        if is_training and dropout_probability > 0.:
+            return
+
+        self.set_input(b,
+                       s,
+                       h,
+                       d,
+                       attn_bias_type=attn_bias_type,
                        attn_mask_type=attn_mask_type,
-                       pad_ratio=pad_ratio,
-                       with_bias=with_bias)
+                       dropout_probability=dropout_probability,
+                       dtype=dtype,
+                       is_training=is_training,
+                       pad_ratio=pad_ratio)

         primitive_out = customcall_self_fused_attn(self.qkv,
                                                    self.bias,
@@ -165,7 +223,7 @@ class TestSelfFusedAttnMax512():
                                                    attn_mask_type=attn_mask_type,
                                                    scaling_factor=self.scaling_factor,
                                                    dropout_probability=self.dropout_probability,
-                                                   is_training=not self.deterministic)
+                                                   is_training=self.is_training)

         reference_out = jax_self_fused_attn(self.qkv,
                                             self.bias,
@@ -174,7 +232,8 @@ class TestSelfFusedAttnMax512():
                                             self.dropout_rng,
                                             attn_mask_type=attn_mask_type,
                                             scaling_factor=self.scaling_factor,
-                                            dropout_probability=self.dropout_probability)
+                                            dropout_probability=self.dropout_probability,
+                                            is_training=self.is_training)

         ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1)
         pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1)
@@ -188,20 +247,25 @@ class TestSelfFusedAttnMax512():
                         jnp.zeros_like(pri_invalid, jnp.float32))

     @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
+    @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
     @pytest.mark.parametrize('attn_mask_type',
                              [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
+    @pytest.mark.parametrize('dropout_probability', [0.])    # dropout can't get the bitmatch result
     @pytest.mark.parametrize('dtype', DTYPES)
+    @pytest.mark.parametrize('is_training', [True])    # backward is only used when is_training
     @pytest.mark.parametrize('pad_ratio', PAD_RATIO)
-    @pytest.mark.parametrize('with_bias', [True, False])
-    def test_forward_backward(self, b, s, h, d, dtype, attn_mask_type, pad_ratio, with_bias):
+    def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, dropout_probability,
+                              dtype, is_training, pad_ratio):
         self.set_input(b,
                        s,
                        h,
                        d,
-                       dtype=dtype,
+                       attn_bias_type=attn_bias_type,
                        attn_mask_type=attn_mask_type,
-                       pad_ratio=pad_ratio,
-                       with_bias=with_bias)
+                       dropout_probability=dropout_probability,
+                       dtype=dtype,
+                       is_training=is_training,
+                       pad_ratio=pad_ratio)

         def grad_func(fused_attn_max_512_func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the gradient
@@ -221,7 +285,7 @@ class TestSelfFusedAttnMax512():
             'attn_mask_type': attn_mask_type,
             'scaling_factor': self.scaling_factor,
             'dropout_probability': self.dropout_probability,
-            'is_training': not self.deterministic
+            'is_training': self.is_training
         }

         # Summing in FP16/BF16 may overflow, so use FP32 for the summation
@@ -300,7 +364,8 @@ class TestSelfFusedAttnMax512():
                     reason="Fused attention kernel is not supported.")
 class TestCrossFusedAttnMax512():

-    def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio):
+    def set_input(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype,
+                  is_training, pad_ratio):
         key = jax.random.PRNGKey(0)
         subkeys = jax.random.split(key, 2)
@@ -321,25 +386,32 @@ class TestCrossFusedAttnMax512():
                                              (b, kv_pad_len))),
                                  axis=-1)

         self.scaling_factor = 1. / math.sqrt(d)
-        self.dropout_probability = 0.
-        self.dropout_rng = jax.random.PRNGKey(0)
+        self.dropout_probability = dropout_probability
+        self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
         self.attn_bias_type = AttnBiasType.NO_BIAS
-        # deterministic = not is_training
-        self.deterministic = False
+        self.is_training = is_training

     @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
     @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
+    @pytest.mark.parametrize('dropout_probability', [0., 0.1])
     @pytest.mark.parametrize('dtype', DTYPES)
+    @pytest.mark.parametrize('is_training', [True, False])
     @pytest.mark.parametrize('pad_ratio', PAD_RATIO)
-    def test_forward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio):
+    def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
+                     is_training, pad_ratio):
+        # dropout can't get the bitmatch result
+        if is_training and dropout_probability > 0.:
+            return
+
         self.set_input(b,
                        s_q,
                        s_kv,
                        h,
                        d,
-                       dtype=dtype,
                        attn_mask_type=attn_mask_type,
+                       dropout_probability=dropout_probability,
+                       dtype=dtype,
+                       is_training=is_training,
                        pad_ratio=pad_ratio)

         primitive_out = customcall_cross_fused_attn(self.q,
@@ -351,7 +423,7 @@ class TestCrossFusedAttnMax512():
                                                     attn_mask_type=attn_mask_type,
                                                     scaling_factor=self.scaling_factor,
                                                     dropout_probability=self.dropout_probability,
-                                                    is_training=not self.deterministic)
+                                                    is_training=self.is_training)

         reference_out = jax_cross_fused_attn(self.q,
                                              self.kv,
@@ -360,7 +432,8 @@ class TestCrossFusedAttnMax512():
                                              self.dropout_rng,
                                              attn_mask_type=attn_mask_type,
                                              scaling_factor=self.scaling_factor,
-                                             dropout_probability=self.dropout_probability)
+                                             dropout_probability=self.dropout_probability,
+                                             is_training=self.is_training)

         ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1)
         pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1)
@@ -375,16 +448,21 @@ class TestCrossFusedAttnMax512():

     @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
     @pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
+    @pytest.mark.parametrize('dropout_probability', [0.])    # dropout can't get the bitmatch result
     @pytest.mark.parametrize('dtype', DTYPES)
+    @pytest.mark.parametrize('is_training', [True])    # backward is only used when is_training
     @pytest.mark.parametrize('pad_ratio', PAD_RATIO)
-    def test_forward_backward(self, b, s_q, s_kv, h, d, dtype, attn_mask_type, pad_ratio):
+    def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
+                              is_training, pad_ratio):
         self.set_input(b,
                        s_q,
                        s_kv,
                        h,
                        d,
-                       dtype=dtype,
                        attn_mask_type=attn_mask_type,
+                       dropout_probability=dropout_probability,
+                       dtype=dtype,
+                       is_training=is_training,
                        pad_ratio=pad_ratio)

         def grad_func(fused_attn_max_512_func, *args, **kwargs):
@@ -405,7 +483,7 @@ class TestCrossFusedAttnMax512():
             'attn_mask_type': attn_mask_type,
             'scaling_factor': self.scaling_factor,
             'dropout_probability': self.dropout_probability,
-            'is_training': not self.deterministic
+            'is_training': self.is_training
         }

         # Summing in FP16/BF16 may overflow, so use FP32 for the summation
...
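The pattern these tests repeat: compare only the valid (unpadded) region against the JAX reference, and require exact zeros in the padded tail. Dropout with `is_training=True` is exercised only by the sanity tests, since the cuDNN Philox mask cannot bit-match the mask drawn by `jax.random`. A standalone sketch of the check (`check_padded_output` and its tolerances are assumptions, not test code):

```python
import jax.numpy as jnp
from numpy.testing import assert_allclose

def check_padded_output(reference_out, primitive_out, valid_len, rtol=1e-4, atol=1e-5):
    # Compare only the valid region; the padded tail must be exactly zero.
    ref_valid, _ = jnp.split(reference_out, (valid_len,), axis=1)
    pri_valid, pri_invalid = jnp.split(primitive_out, (valid_len,), axis=1)
    assert_allclose(pri_valid.astype(jnp.float32),
                    ref_valid.astype(jnp.float32), rtol=rtol, atol=atol)
    assert_allclose(pri_invalid.astype(jnp.float32),
                    jnp.zeros_like(pri_invalid, jnp.float32))
```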
@@ -167,9 +167,7 @@ def dot_product_attention(query: Array,
         # T5 broadcasts along the "length" dim, but unclear which one that
         # corresponds to in positional dimensions here, assuming query dim.
         dropout_shape = list(attn_weights.shape)
-        dropout_shape[-2] = 1
         keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
-        keep = jnp.broadcast_to(keep, attn_weights.shape)
         multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
         attn_weights = attn_weights * multiplier
...
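What the hunk above changes: the old reference drew one dropout pattern per (batch, head) and broadcast it along the query dimension (T5-style), while the new reference draws an independent Bernoulli variable per attention weight, matching the fused kernel and other frameworks. A sketch of the difference:

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
attn_weights = jnp.ones((2, 4, 8, 64))    # (batch, heads, q_len, kv_len)
keep_prob = 0.9

# Before: one kv-length row drawn per (batch, head), broadcast over q_len
shape_b = list(attn_weights.shape)
shape_b[-2] = 1
keep_broadcast = jnp.broadcast_to(
    jax.random.bernoulli(rng, keep_prob, shape_b), attn_weights.shape)

# After: an independent draw for every attention weight
keep_full = jax.random.bernoulli(rng, keep_prob, attn_weights.shape)

print(bool((keep_broadcast[0, 0, 0] == keep_broadcast[0, 0, 1]).all()))  # True
print(bool((keep_full[0, 0, 0] == keep_full[0, 0, 1]).all()))            # almost surely False
```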
@@ -22,7 +22,7 @@
 #define O_ID 4
 #define S_ID 5
 #define B_ID 6
-#define D_CONST_ID 7
+#define DROPOUT_CONST_ID 7
 #define S_CONST_ID 8
 #define Q_SEQLEN_ID 9
 #define K_SEQLEN_ID 10
@@ -33,6 +33,8 @@
 #define MASK_VAL_ID 15
 #define dS_ID 16
 #define dBias_ID 17
+#define DROPOUT_SEED_ID 18
+#define DROPOUT_OFFSET_ID 19
 #define VIRTUAL_ID 20
@@ -333,8 +335,7 @@ static cudnn_frontend::Tensor createSoftmaxForward(
     int64_t afterReduction_dim[4] = {b, h, s_q, 1};
     int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1};

-    cudnnDataType_t softmaxOutputType =
-        (enable_dropout || softmax_output_virtual) ? CUDNN_DATA_FLOAT : tensorType;
+    cudnnDataType_t softmaxOutputType = enable_dropout ? CUDNN_DATA_FLOAT : tensorType;
     uint64_t softmaxOutputName = softmax_output_virtual ? VIRTUAL_ID + 154 : S_ID;

     // max (x)
@@ -427,7 +428,7 @@ static cudnn_frontend::Tensor createSoftmaxForward(
 }

 static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
-                                            int64_t d, int64_t seed, double probability,
+                                            int64_t d, double probability,
                                             cudnnDataType_t tensorType,
                                             // NOLINTNEXTLINE(runtime/references)
                                             std::vector<cudnn_frontend::Operation> &ops,
@@ -460,8 +461,9 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i
                                .setReorderType(reorder_type)
                                .build();

     // scale after dropout
-    auto scaleDropoutTensor = tensor_create(tensorType, D_CONST_ID, scale_dim, scale_stride, false,
-                                            true);    // is by value
+    auto scaleDropoutTensor =
+        tensor_create(tensorType, DROPOUT_CONST_ID, scale_dim, scale_stride, false,
+                      true);    // is by value

     // after Scale
     auto afterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 201, afterBMM1_dim,
                                           afterBMM1_stride, true, false);    // is virtual
@@ -472,10 +474,16 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i
                        .setBernoulliDistProbability(1.0 - probability)
                        .build();

+    auto dropoutSeed =
+        tensor_create(CUDNN_DATA_INT64, DROPOUT_SEED_ID, scale_dim, scale_stride, false, false);
+    auto dropoutOffset =
+        tensor_create(CUDNN_DATA_INT64, DROPOUT_OFFSET_ID, scale_dim, scale_stride, false, false);
+
     // Create a rng Node.
     auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR)
                       .setyDesc(dropoutMaskTensor)
-                      .setSeed(seed)
+                      .setSeedDesc(dropoutSeed)
+                      .setOffsetDesc(dropoutOffset)
                       .setRngDesc(rngDesc)
                       .build();
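The RNG node now takes its seed and offset from device tensors instead of a compile-time constant, so the cached execution plan can be reused while (seed, offset) vary per launch. A JAX-flavored sketch of the counter-based contract (`philox_like_mask` is a stand-in; cuDNN uses its own Philox generator):

```python
import jax

def philox_like_mask(seed, offset, shape, keep_prob=0.9):
    # Same (seed, offset) -> same mask; advancing the offset -> fresh mask.
    key = jax.random.fold_in(jax.random.PRNGKey(seed), offset)
    return jax.random.bernoulli(key, keep_prob, shape)

m0 = philox_like_mask(42, 0, (32, 32))
m1 = philox_like_mask(42, 0, (32, 32))
m2 = philox_like_mask(42, 1, (32, 32))
assert (m0 == m1).all()        # reproducible for a fixed (seed, offset)
assert not (m0 == m2).all()    # new offset yields a new mask
```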
@@ -624,16 +632,14 @@ static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_
     return dxTensor;
 }

-void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
-                                 bool is_training, float scaling_factor, float dropout_probability,
-                                 NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type,
-                                 NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK,
-                                 void *devPtrV, void *devPtrS, void *devPtrO, void *devPtrBias,
-                                 void *devCuSeqlenQ, void *devCuSeqlenK, void *workspace,
-                                 size_t *workspace_size, cudnnDataType_t tensorType,
-                                 cudaStream_t stream, cudnnHandle_t handle) {
+void fused_attn_max_512_fwd_impl(
+    int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, bool is_training,
+    float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK, void *devPtrV,
+    void *devPtrS, void *devPtrO, void *devPtrBias, void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV,
+    void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size,
+    cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) {
     try {
-        constexpr int64_t seed = 0;    // TODO(rewang): replace this with device seed/offset
-
         NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));

         FADescriptor descriptor{b, h,
@@ -646,10 +652,13 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
         using CacheType = std::map<FADescriptor, cudnn_frontend::ExecutionPlan>;
         static thread_local CacheType fmha_fprop_cache;

-        bool enable_dropout = (dropout_probability != 0.0f);
+        // the softmax auxiliary is only used in training mode
+        bool enable_dropout = is_training && (dropout_probability != 0.0f);

-        NVTE_CHECK(!enable_dropout,
-                   "dropout probability > 0 in fused_attn_max_512 has not been implemented.");
+        // two conditions make the softmax auxiliary virtual:
+        // 1. inference mode (not is_training)
+        // 2. dropout enabled: the auxiliary becomes the dropout output
+        bool softmax_output_virtual = !is_training || enable_dropout;

         // Get plan from cache if cache is available, otherwise create one
         auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) {
@@ -667,8 +676,10 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
             createScale(b, h, s_q, s_kv, d, layout, tensorType, ops);

             // if bias, we need to memset the S buffer to correctly compute dbias
+            // WAR: causal_mask without bias also needs the S buffer memset
+            // inference mode doesn't need the S auxiliary
             auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) ||
-                          (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK);
+                          (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) && is_training;
             auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops);

             NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS,
@@ -683,14 +694,12 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
             NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0.");

-            // TODO(rewang): check whether devPtrS can be removed
-            bool softmax_output_virtual = enable_dropout;    // || devPtrS == nullptr;
             auto softmax_output =
                 createSoftmaxForward(b, h, s_q, s_kv, d, layout, enable_dropout,
                                      softmax_output_virtual, tensorType, ops, mask_output);

-            if (dropout_probability != 0.0f) {
-                auto dropout_output = createDropout(b, h, s_q, s_kv, d, seed, dropout_probability,
+            if (enable_dropout) {
+                auto dropout_output = createDropout(b, h, s_q, s_kv, d, dropout_probability,
                                                     tensorType, ops, softmax_output);
                 createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, dropout_output);
             } else {
@@ -741,9 +750,10 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
         void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
         void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
         cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
-            b, static_cast<const int32_t *>(devCuSeqlenQ),
-            static_cast<const int32_t *>(devCuSeqlenK), static_cast<int32_t *>(devActualSeqlenQ),
-            static_cast<int32_t *>(devActualSeqlenK));
+            b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
+            static_cast<const int32_t *>(devPtrCuSeqlenKV),
+            static_cast<int32_t *>(devActualSeqlenQ), static_cast<int32_t *>(devActualSeqlenK));
+        NVTE_CHECK_CUDA(cudaGetLastError());

         // change this if you have access to float_min
         float negInfinity = -1.0E+10;
@@ -758,16 +768,17 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
         data_ptrs.insert(std::pair<uint64_t, void *>(K_SEQLEN_ID, devActualSeqlenK));
         data_ptrs.insert(std::pair<uint64_t, void *>(MASK_VAL_ID, &negInfinity));

+        __half half_cast_scaling_factor{scaling_factor};
+        __nv_bfloat16 bfloat_cast_scaling_factor{scaling_factor};
         if (tensorType == CUDNN_DATA_FLOAT) {
             data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &scaling_factor));
         } else if (tensorType == CUDNN_DATA_HALF) {
-            __half cast_scaling_factor{scaling_factor};
-            data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &cast_scaling_factor));
+            data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &half_cast_scaling_factor));
         } else if (tensorType == CUDNN_DATA_BFLOAT16) {
-            __nv_bfloat16 cast_scaling_factor{scaling_factor};
-            data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &cast_scaling_factor));
+            data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &bfloat_cast_scaling_factor));
         } else {
-            std::cerr << "Not supported tensorType." << std::endl;
+            NVTE_ERROR("Unsupported tensor type.");
         }

         data_ptrs.insert(std::pair<uint64_t, void *>(O_ID, devPtrO));
@@ -776,12 +787,30 @@ void fused_attn_max_512_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
             data_ptrs.insert(std::pair<uint64_t, void *>(B_ID, devPtrBias));
         }

-        if (devPtrS != nullptr) {
+        // if enable_dropout, S is the result after dropout;
+        // if not, S is the result after softmax
+        if (enable_dropout || !softmax_output_virtual) {
             data_ptrs.insert(std::pair<uint64_t, void *>(S_ID, devPtrS));
         }

+        __half half_cast_scale_dropout{scale_dropout};
+        __nv_bfloat16 bfloat16_cast_scale_dropout{scale_dropout};
         if (enable_dropout) {
-            data_ptrs.insert(std::pair<uint64_t, void *>(D_CONST_ID, &scale_dropout));
+            // TODO(rewang): make a util func
+            if (tensorType == CUDNN_DATA_FLOAT) {
+                data_ptrs.insert(std::pair<uint64_t, void *>(DROPOUT_CONST_ID, &scale_dropout));
+            } else if (tensorType == CUDNN_DATA_HALF) {
+                data_ptrs.insert(
+                    std::pair<uint64_t, void *>(DROPOUT_CONST_ID, &half_cast_scale_dropout));
+            } else if (tensorType == CUDNN_DATA_BFLOAT16) {
+                data_ptrs.insert(
+                    std::pair<uint64_t, void *>(DROPOUT_CONST_ID, &bfloat16_cast_scale_dropout));
+            } else {
+                NVTE_ERROR("Unsupported tensor type.");
+            }
+            data_ptrs.insert(std::pair<uint64_t, void *>(DROPOUT_SEED_ID, devPtrDropoutSeed));
+            data_ptrs.insert(std::pair<uint64_t, void *>(DROPOUT_OFFSET_ID, devPtrDropoutOffset));
         }

         auto variantPack = cudnn_frontend::VariantPackBuilder()
@@ -802,7 +831,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
                                  NVTE_Bias_Type bias_type, void *devPtrQ, void *devPtrK,
                                  void *devPtrV, void *devPtrS, void *devPtrdQ, void *devPtrdK,
                                  void *devPtrdV, void *devPtrdO, void *devPtrdS, void *devPtrdBias,
-                                 void *devCuSeqlenQ, void *devCuSeqlenK, void *workspace,
+                                 void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void *workspace,
                                  size_t *workspace_size, cudnnDataType_t tensorType,
                                  cudaStream_t stream, cudnnHandle_t handle) {
     try {
@@ -915,7 +944,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
             ops.push_back(std::move(reshape_op));

             // scale dropout
-            auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, D_CONST_ID, scale_dim,
+            auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, DROPOUT_CONST_ID, scale_dim,
                                                     scale_stride, false, true);    // is by value
             auto pAfterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 301, p_transpose_dim,
                                                    p_transpose_stride, true, false);
@@ -1160,9 +1189,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
         void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
         void *devActualSeqlenK = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
         cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
-            b, static_cast<const int32_t *>(devCuSeqlenQ),
-            static_cast<const int32_t *>(devCuSeqlenK), static_cast<int32_t *>(devActualSeqlenQ),
-            static_cast<int32_t *>(devActualSeqlenK));
+            b, static_cast<const int32_t *>(devPtrCuSeqlenQ),
+            static_cast<const int32_t *>(devPtrCuSeqlenKV),
+            static_cast<int32_t *>(devActualSeqlenQ), static_cast<int32_t *>(devActualSeqlenK));
+        NVTE_CHECK_CUDA(cudaGetLastError());

         std::set<std::pair<uint64_t, void *>> data_ptrs;
         // add all the data pointers to be used in the variant pack
@@ -1183,13 +1213,10 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
             data_ptrs.insert(std::pair<uint64_t, void *>(dBias_ID, devPtrdBias));
         }

-        NVTE_CHECK(dropout_probability == 0.f,
-                   "dropout probability > 0 in fused_attn_max_512 has not been implemented.");
-
         float zeroVal = 0.0f;
         float dropoutScale = 1.0f / (1.0f - dropout_probability);

-        data_ptrs.insert(std::pair<uint64_t, void *>(D_CONST_ID, &dropoutScale));
+        data_ptrs.insert(std::pair<uint64_t, void *>(DROPOUT_CONST_ID, &dropoutScale));
         data_ptrs.insert(std::pair<uint64_t, void *>(S_CONST_ID, &scaling_factor));
         data_ptrs.insert(std::pair<uint64_t, void *>(MASK_VAL_ID, &zeroVal));
@@ -1216,8 +1243,6 @@ void fused_attn_max_512_fwd_qkvpacked(
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
     using namespace transformer_engine;

-    // Only is_training is verified
-    NVTE_CHECK(is_training, "is_training=False is not implemented in fused_attn_max_512.");
     NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
                "qkv_layout must be NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED.");
@@ -1246,23 +1271,22 @@ void fused_attn_max_512_fwd_qkvpacked(
         devPtrS = output_S->data.dptr;
     }

-    void *devCuSeqlen = cu_seqlens->data.dptr;
+    void *devPtrCuSeqlen = cu_seqlens->data.dptr;

-    // TODO(rewang): dropout seed
-    // void* devPtrDropoutSeed = reinterpret_cast<void *>(
-    //     reinterpret_cast<uint64_t*>(rng_state->data.dptr));
-    // void* devPtrDropoutOffset = reinterpret_cast<void *>(
-    //     reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
+    const DType rng_state_type = rng_state->data.dtype;
+    NVTE_CHECK(rng_state_type == DType::kInt64);
+    void *devPtrDropoutSeed = rng_state->data.dptr;
+    void *devPtrDropoutOffset =
+        static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);

     const DType QKV_type = input_QKV->data.dtype;
     size_t workspace_size = 0;

-    // TODO(rewang): replace CPU seed
-    fused_attn_max_512_fwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, is_training,
-                                attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ,
-                                devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, devCuSeqlen,
-                                devCuSeqlen, workspace->data.dptr, &workspace_size,
-                                get_cudnn_dtype(QKV_type), stream, handle);
+    fused_attn_max_512_fwd_impl(
+        batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
+        qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
+        devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset,
+        workspace->data.dptr, &workspace_size, get_cudnn_dtype(QKV_type), stream, handle);

     if (workspace_size > 0) {
         if (workspace->data.dptr == nullptr) {
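A sketch of the rng_state layout this host code assumes: a single int64 buffer of two words, [seed, offset]; the seed pointer is word 0 and the offset pointer is word 0 plus one (values below are placeholders):

```python
import numpy as np

rng_state = np.zeros(2, dtype=np.int64)
rng_state[0] = 0x0123456789ABCDEF    # seed, e.g. packed from two uint32 key words
rng_state[1] = 0                     # offset, advanced by TE between launches

dev_ptr_dropout_seed = rng_state[0:1]      # mirrors devPtrDropoutSeed
dev_ptr_dropout_offset = rng_state[1:2]    # mirrors devPtrDropoutOffset (+1 word)
print(dev_ptr_dropout_seed, dev_ptr_dropout_offset)
```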
@@ -1288,8 +1312,6 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
     using namespace transformer_engine;

-    // Only is_training is verified
-    NVTE_CHECK(is_training, "is_training=False is not implemented in fused_attn_max_512.");
     NVTE_CHECK(qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED,
                "qkv_layout must be NVTE_QKV_Layout::NVTE_KV_INTERLEAVED.");
     NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
@@ -1328,20 +1350,19 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
     void *devQCuSeqlen = q_cu_seqlens->data.dptr;
     void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;

-    // TODO(rewang): dropout seed
-    // void* devPtrDropoutSeed = reinterpret_cast<void *>(
-    //     reinterpret_cast<uint64_t*>(rng_state->data.dptr));
-    // void* devPtrDropoutOffset = reinterpret_cast<void *>(
-    //     reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
+    const DType rng_state_type = rng_state->data.dtype;
+    NVTE_CHECK(rng_state_type == DType::kInt64);
+    void *devPtrDropoutSeed = rng_state->data.dptr;
+    void *devPtrDropoutOffset =
+        static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);

     size_t workspace_size = 0;

-    // TODO(rewang): replace CPU seed
-    fused_attn_max_512_fwd_impl(batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training,
-                                attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ,
-                                devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, devQCuSeqlen,
-                                devKVCuSeqlen, workspace->data.dptr, &workspace_size,
-                                get_cudnn_dtype(q_type), stream, handle);
+    fused_attn_max_512_fwd_impl(
+        batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
+        qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
+        devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
+        &workspace_size, get_cudnn_dtype(q_type), stream, handle);

     if (workspace_size > 0) {
         if (workspace->data.dptr == nullptr) {
...
@@ -256,6 +256,10 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b,
 cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
     using namespace transformer_engine;
     switch (t) {
+        case DType::kInt32:
+            return CUDNN_DATA_INT32;
+        case DType::kInt64:
+            return CUDNN_DATA_INT64;
         case DType::kFloat16:
             return CUDNN_DATA_HALF;
         case DType::kFloat32:
...
@@ -106,7 +106,7 @@ enum NVTE_Mask_Type {
  \verbatim
  | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
  | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
- | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+ | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
  \endverbatim
 *
 * \param[in]  QKV    The QKV tensor in packed format,
@@ -149,7 +149,7 @@ void nvte_fused_attn_fwd_qkvpacked(
  \verbatim
  | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
  | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
- | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+ | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
  \endverbatim
 *
 * \param[in]  QKV    The QKV tensor in packed format,
@@ -200,7 +200,7 @@ void nvte_fused_attn_bwd_qkvpacked(
 * Support Matrix:
  \verbatim
  | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
- | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+ | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
  \endverbatim
 *
 * \param[in]  Q    The Q tensor, [total_seqs_q, num_heads, head_dim].
@@ -247,7 +247,7 @@ void nvte_fused_attn_fwd_kvpacked(
 * Support Matrix:
  \verbatim
  | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
- | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |
+ | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
  \endverbatim
 *
 * \param[in]  Q    The Q tensor, [total_seqs_q, num_heads, head_dim].
...
@@ -6,7 +6,7 @@ pybind11_add_module(
     transformer_engine_jax
     ${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/csrc/modules.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cu
 )

 target_link_libraries(transformer_engine_jax PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt transformer_engine)
...
@@ -8,6 +8,8 @@ from dataclasses import dataclass
 from typing import Tuple
 from functools import partial, reduce
 import operator
+import warnings
+
 import numpy as np

 from jaxlib.hlo_helpers import custom_call
 import jax.numpy as jnp
@@ -1679,7 +1681,7 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
                                                    grad_outputs, softmax_outputs,
                                                    scale_factor)
         return out    # out is iterable already

 _scaled_softmax_bwd_p = register_primitive(ScaledSoftmaxBwdPrimitive)
@@ -1828,7 +1830,7 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
                                                    grad_outputs, softmax_outputs,
                                                    scale_factor)
         return out    # out is iterable already

 _scaled_masked_softmax_bwd_p = register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
@@ -1962,7 +1964,7 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
             ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, ctx, grad_outputs, softmax_outputs,
             scale_factor)
         return out    # out is iterable already

 _scaled_upper_triang_masked_softmax_bwd_p = \
     register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
@@ -1979,6 +1981,27 @@ def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_ou
                                                            scale_factor=scale_factor)


+def _check_seed(seed, dropout_probability, is_training):
+    # JAX can't bind None, so create a dummy tensor for None
+    if seed is None:
+        dropout_enabled = dropout_probability > 0 and is_training
+        assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
+        seed = jnp.zeros(2, dtype=jnp.uint32)
+
+    if seed.dtype != jnp.uint32:
+        warnings.warn(
+            f"Requested {seed.dtype=} is not available, and will be "
+            f"cast to dtype uint32. "
+            f"Please use the threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.")
+        seed = seed.astype(jnp.uint32)
+
+    assert seed.dtype == jnp.uint32
+    # Only the first 2 uint32 elements are taken
+    assert seed.size >= 2
+
+    return seed
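A quick illustration of `_check_seed`'s contract (assuming it is in scope; the rbg-style key below is hand-built for the example):

```python
import jax
import jax.numpy as jnp

threefry_key = jax.random.PRNGKey(0)        # uint32, size 2: passes through as-is
print(_check_seed(threefry_key, 0.1, True))

rbg_style_key = jnp.zeros(4, jnp.uint32)        # rbg/unsafe_rbg keys carry 4 words;
print(_check_seed(rbg_style_key, 0.1, True))    # only the first 2 are consumed

print(_check_seed(None, 0.0, True))         # dropout disabled: dummy zeros are fine
```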
 class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
     """
     Self Fused Attention Max Seqlen 512 Forward Primitive
...@@ -1991,7 +2014,7 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive): ...@@ -1991,7 +2014,7 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
qkv, qkv,
bias, bias,
cu_seqlen, # pylint: disable=unused-argument cu_seqlen, # pylint: disable=unused-argument
rng_state, # pylint: disable=unused-argument seed, # pylint: disable=unused-argument
*, *,
attn_bias_type, # pylint: disable=unused-argument attn_bias_type, # pylint: disable=unused-argument
attn_mask_type, # pylint: disable=unused-argument attn_mask_type, # pylint: disable=unused-argument
...@@ -2020,8 +2043,8 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive): ...@@ -2020,8 +2043,8 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
) )
@staticmethod @staticmethod
def lowering(ctx, qkv, bias, cu_seqlen, rng_state, *, attn_bias_type, attn_mask_type, def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
scaling_factor, dropout_probability, is_training): dropout_probability, is_training):
""" """
Self fused attention max seqlen 512 fwd lowering rules Self fused attention max seqlen 512 fwd lowering rules
""" """
...@@ -2036,8 +2059,8 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive): ...@@ -2036,8 +2059,8 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type) ir_cu_seqlen_type = ir.RankedTensorType(cu_seqlen.type)
ir_cu_seqlen_shape = ir_cu_seqlen_type.shape ir_cu_seqlen_shape = ir_cu_seqlen_type.shape
ir_rng_state_type = ir.RankedTensorType(rng_state.type) ir_seed_type = ir.RankedTensorType(seed.type)
ir_rng_state_shape = ir_rng_state_type.shape ir_seed_shape = ir_seed_type.shape
batch, max_seqlen, nqkv, num_head, head_dim = ir_qkv_shape batch, max_seqlen, nqkv, num_head, head_dim = ir_qkv_shape
assert nqkv == 3 assert nqkv == 3
...@@ -2049,8 +2072,8 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive): ...@@ -2049,8 +2072,8 @@ class SelfFusedAttnMax512FwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(output_shape, ir_qkv_type.element_type), ir.RankedTensorType.get(output_shape, ir_qkv_type.element_type),
ir.RankedTensorType.get(softmax_aux_shape, ir_qkv_type.element_type) ir.RankedTensorType.get(softmax_aux_shape, ir_qkv_type.element_type)
] ]
operands = [qkv, bias, cu_seqlen, rng_state] operands = [qkv, bias, cu_seqlen, seed]
operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_rng_state_shape] operand_shapes = [ir_qkv_shape, ir_bias_shape, ir_cu_seqlen_shape, ir_seed_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_fused_attn_descriptor( opaque = transformer_engine_jax.pack_fused_attn_descriptor(
...@@ -2069,23 +2092,22 @@ _self_fused_attn_max_512_fwd_p = register_primitive(SelfFusedAttnMax512FwdPrimit ...@@ -2069,23 +2092,22 @@ _self_fused_attn_max_512_fwd_p = register_primitive(SelfFusedAttnMax512FwdPrimit
def self_fused_attn_max_512_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray, def self_fused_attn_max_512_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, cu_seqlen: jnp.ndarray,
rng_state: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, scaling_factor: float, attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
dropout_probability: float, is_training: bool): dropout_probability: float, is_training: bool):
""" """
Wrapper for TE self fused attention max seqlen 512 fwd Wrapper for TE self fused attention max seqlen 512 fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
""" """
# Jax can't bind None, create a dummy tensor for None seed = _check_seed(seed, dropout_probability, is_training)
if rng_state is None:
rng_state = jnp.zeros(2, dtype=jnp.int32)
if bias is None: if bias is None:
assert attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS assert attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
bias = jnp.zeros(0, dtype=qkv.dtype) bias = jnp.zeros(0, dtype=qkv.dtype)
return _self_fused_attn_max_512_fwd_p.bind(qkv, return _self_fused_attn_max_512_fwd_p.bind(qkv,
bias, bias,
cu_seqlen, cu_seqlen,
rng_state, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
@@ -2161,6 +2183,9 @@ class SelfFusedAttnMax512BwdPrimitive(BasePrimitive):
         operand_shapes = [ir_qkv_shape, ir_softmax_aux_shape, ir_doutput_shape, ir_cu_seqlen_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
+        # the dropout elements are encoded in the forward auxiliary tensor
+        # so seed is not needed in backward
         opaque = transformer_engine_jax.pack_fused_attn_descriptor(
             batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
             attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)
@@ -2208,7 +2233,7 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
                  kv,
                  q_cu_seqlen,
                  kv_cu_seqlen,
-                 rng_state,    # pylint: disable=unused-argument
+                 seed,    # pylint: disable=unused-argument
                  *,
                  attn_bias_type,    # pylint: disable=unused-argument
                  attn_mask_type,    # pylint: disable=unused-argument
@@ -2243,8 +2268,8 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
     )

     @staticmethod
-    def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state, *, attn_bias_type,
-                 attn_mask_type, scaling_factor, dropout_probability, is_training):
+    def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type,
+                 scaling_factor, dropout_probability, is_training):
         """
         Cross fused attention max seqlen 512 fwd lowering rules
         """
@@ -2260,8 +2285,8 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
         ir_q_cu_seqlen_shape = ir.RankedTensorType(q_cu_seqlen.type).shape
         ir_kv_cu_seqlen_shape = ir.RankedTensorType(kv_cu_seqlen.type).shape
-        ir_rng_state_type = ir.RankedTensorType(rng_state.type)
-        ir_rng_state_shape = ir_rng_state_type.shape
+        ir_seed_type = ir.RankedTensorType(seed.type)
+        ir_seed_shape = ir_seed_type.shape

         batch, q_max_seqlen, num_head, head_dim = ir_q_shape
         kv_max_seqlen = ir_kv_shape[1]
@@ -2273,9 +2298,9 @@ class CrossFusedAttnMax512FwdPrimitive(BasePrimitive):
             ir.RankedTensorType.get(output_shape, ir_q_type.element_type),
             ir.RankedTensorType.get(softmax_aux_shape, ir_q_type.element_type)
         ]
-        operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, rng_state]
+        operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed]
         operand_shapes = [
-            ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_rng_state_shape
+            ir_q_shape, ir_kv_shape, ir_q_cu_seqlen_shape, ir_kv_cu_seqlen_shape, ir_seed_shape
         ]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
@@ -2296,7 +2321,7 @@ _cross_fused_attn_max_512_fwd_p = register_primitive(CrossFusedAttnMax512FwdPrim
 def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: jnp.ndarray,
-                                 kv_cu_seqlen: jnp.ndarray, rng_state: jnp.ndarray,
+                                 kv_cu_seqlen: jnp.ndarray, seed: jnp.ndarray,
                                  attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
                                  scaling_factor: float, dropout_probability: float,
                                  is_training: bool):
@@ -2304,14 +2329,13 @@ def cross_fused_attn_max_512_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_cu_seqlen: j
     """
     Wrapper for TE cross fused attention max seqlen 512 fwd
     Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
     """
-    # Jax can't bind None, create a dummy tensor for None
-    if rng_state is None:
-        rng_state = jnp.zeros(2, dtype=jnp.int32)
+    seed = _check_seed(seed, dropout_probability, is_training)
     return _cross_fused_attn_max_512_fwd_p.bind(q,
                                                 kv,
                                                 q_cu_seqlen,
                                                 kv_cu_seqlen,
-                                                rng_state,
+                                                seed,
                                                 attn_bias_type=attn_bias_type,
                                                 attn_mask_type=attn_mask_type,
                                                 scaling_factor=scaling_factor,
@@ -2391,6 +2415,9 @@ class CrossFusedAttnMax512BwdPrimitive(BasePrimitive):
         ]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
+        # the dropout elements are encoded in the forward auxiliary tensor
+        # so seed is not needed in backward
         opaque = transformer_engine_jax.pack_fused_attn_descriptor(
             batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
             scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
......
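Taken together, the Python wrappers above now accept a seed tensor instead of a preformed rng_state. A minimal sketch of a direct call; the shapes, the mask-type enum value, and the int64 placeholder seed are illustrative assumptions, not a verbatim usage from the repository:

    import jax.numpy as jnp

    batch, seqlen, num_head, head_dim = 2, 512, 16, 64
    qkv = jnp.zeros((batch, seqlen, 3, num_head, head_dim), dtype=jnp.bfloat16)
    cu_seqlen = jnp.arange(0, (batch + 1) * seqlen, seqlen, dtype=jnp.int32)
    seed = jnp.zeros(2, dtype=jnp.int64)    # placeholder; real keys come from jax.random

    out, softmax_aux = self_fused_attn_max_512_fwd(
        qkv, None, cu_seqlen, seed,
        attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
        attn_mask_type=NVTE_Mask_Type.NVTE_PADDING_MASK,
        scaling_factor=1.0 / head_dim**0.5,
        dropout_probability=0.1,
        is_training=True)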
@@ -749,7 +749,7 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
     void *qkv = buffers[0];
     void *bias = buffers[1];
     void *cu_seqlens = buffers[2];
-    void *rng_state = buffers[3];
+    void *seed = buffers[3];

     // output
     void *output = buffers[4];
@@ -778,30 +778,37 @@ void SelfFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
     auto cu_seqlens_tensor =
         TensorWrapper(cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
-    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{1}, DType::kInt64);
+    auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

     NVTETensorPack aux_output_tensors;
     nvte_tensor_pack_create(&aux_output_tensors);

     TensorWrapper query_workspace_tensor;
-    nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
-                                  o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
-                                  rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
-                                  descriptor.scaling_factor, descriptor.dropout_probability,
-                                  NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type,
-                                  descriptor.mask_type, query_workspace_tensor.data(), stream);
+    nvte_fused_attn_fwd_qkvpacked(
+        qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
+        &aux_output_tensors, cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
+        descriptor.is_training, descriptor.scaling_factor, descriptor.dropout_probability,
+        NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
+        query_workspace_tensor.data(), stream);

     auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
     output_s->data.dptr = softmax_aux;

-    size_t workspace_size =
+    // fused attn workspace + workspace for rng_state
+    auto plan_workspace_size =
         query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
-    auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
+    auto rng_workspace_size = 2 * sizeof(int64_t);
+    auto total_workspace_size = plan_workspace_size + rng_workspace_size;
+    auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);

     auto workspace_tensor =
         TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
+    auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
+    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
+    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);

     nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                   o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(),
                                   rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
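Two details worth calling out in the hunk above: the first nvte_fused_attn_fwd_qkvpacked call only queries the workspace size (hence the null-backed dummy rng tensor), and the real rng_state is carved out of the tail of that same workspace buffer. A small Python model of the buffer layout; this is not the C++ implementation, and the sizes are illustrative:

    import numpy as np

    def carve_workspace(plan_workspace_size: int, seed: int, offset: int):
        rng_workspace_size = 2 * np.dtype(np.int64).itemsize    # (seed, offset) pair
        buf = np.zeros(plan_workspace_size + rng_workspace_size, dtype=np.uint8)
        # rng_state lives immediately after the plan workspace, mirroring
        # `workspace + plan_workspace_size` in the C++ above.
        rng_state = buf[plan_workspace_size:].view(np.int64)
        rng_state[0], rng_state[1] = seed, offset    # what populate_rng_state_kernel writes
        return buf, rng_state

    _, rng_state = carve_workspace(plan_workspace_size=1024, seed=42, offset=0)
    assert rng_state.tolist() == [42, 0]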
@@ -907,7 +914,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
     void *kv = buffers[1];
     void *q_cu_seqlens = buffers[2];
     void *kv_cu_seqlens = buffers[3];
-    void *rng_state = buffers[4];
+    void *seed = buffers[4];

     // output
     void *output = buffers[5];
@@ -939,7 +946,8 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
         TensorWrapper(q_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
     auto kv_cu_seqlens_tensor =
         TensorWrapper(kv_cu_seqlens, std::vector<size_t>{batch + 1}, DType::kInt32);
-    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{1}, DType::kInt64);
+    auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

     NVTETensorPack aux_output_tensors;
     nvte_tensor_pack_create(&aux_output_tensors);
@@ -949,7 +957,7 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
+        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
         descriptor.scaling_factor, descriptor.dropout_probability,
         NVTE_QKV_Layout::NVTE_KV_INTERLEAVED, descriptor.bias_type, descriptor.mask_type,
         query_workspace_tensor.data(), stream);
@@ -957,13 +965,19 @@ void CrossFusedAttnMax512Forward(cudaStream_t stream, void **buffers, const char
     auto *output_s = reinterpret_cast<Tensor *>(aux_output_tensors.tensors[0]);
     output_s->data.dptr = softmax_aux;

-    size_t workspace_size =
+    // fused attn workspace + workspace for rng_state
+    auto plan_workspace_size =
         query_workspace_tensor.shape().data[0] * typeToSize(query_workspace_tensor.dtype());
-    auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
+    auto rng_workspace_size = 2 * sizeof(int64_t);
+    auto total_workspace_size = plan_workspace_size + rng_workspace_size;
+    auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);

     auto workspace_tensor =
         TensorWrapper(workspace, query_workspace_tensor.shape(), query_workspace_tensor.dtype());
+    auto rng_state = static_cast<uint8_t *>(workspace) + plan_workspace_size;
+    auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
+    PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, stream);

     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
......
@@ -32,5 +32,23 @@ int GetDeviceComputeCapability(int gpu_id) {
     return gpu_arch;
 }

+__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
+                                          int64_t offset) {
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (tid > 0) return;
+    rng_state_dst[0] = seed[0];
+    rng_state_dst[1] = offset;
+}
+
+void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
+                           size_t kv_max_seqlen, cudaStream_t stream) {
+    constexpr int threads_per_cta = 128;
+    const size_t increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
+    auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
+    populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
+                                                   reinterpret_cast<const int64_t *>(seed), offset);
+    NVTE_CHECK_CUDA(cudaGetLastError());
+}
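The ceiling division above matters when q_max_seqlen * kv_max_seqlen is not a multiple of threads_per_cta: rounding down would under-count the random numbers each thread consumes and let consecutive calls overlap Philox streams. A quick check of the arithmetic in Python, for illustration only:

    THREADS_PER_CTA = 128

    def philox_increment(q_max_seqlen: int, kv_max_seqlen: int) -> int:
        # Dropout random elements consumed per thread, rounded up so a
        # partial tail still advances the Philox offset far enough.
        elts = q_max_seqlen * kv_max_seqlen
        return (elts + THREADS_PER_CTA - 1) // THREADS_PER_CTA

    print(philox_increment(512, 512))    # 2048: exact multiple
    print(philox_increment(100, 100))    # 79, not 78: rounded up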
 }  // namespace jax
 }  // namespace transformer_engine
@@ -21,6 +21,9 @@ namespace jax {
 int GetCudaRuntimeVersion();
 int GetDeviceComputeCapability(int gpu_id);

+void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
+                           size_t kv_max_seqlen, cudaStream_t stream);
+
 class cublasLtMetaManager {
  public:
     static cublasLtMetaManager &Instance() {
@@ -93,6 +96,27 @@ class cudaDevicePropertiesManager {
     cudaDeviceProp prop_;
 };

+class FusedAttnOffsetManager {
+ public:
+    static FusedAttnOffsetManager &Instance() {
+        static thread_local FusedAttnOffsetManager instance;
+        return instance;
+    }
+
+    size_t GetAndUpdateOffset(size_t increment) {
+        size_t ret = offset_;
+        offset_ += increment;
+        return ret;
+    }
+
+    FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
+    void operator=(FusedAttnOffsetManager const &) = delete;
+
+ private:
+    FusedAttnOffsetManager() {}
+    size_t offset_ = 0;
+};
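`thread_local` in Instance() gives each host thread its own running offset, so GetAndUpdateOffset stays race-free without locks. A Python analogue of the same idea using threading.local; the class and names here are illustrative, not part of the library:

    import threading

    class _OffsetManager(threading.local):
        """Each thread sees its own offset_, like the C++ thread_local singleton."""

        def __init__(self):
            self.offset_ = 0

        def get_and_update_offset(self, increment: int) -> int:
            ret = self.offset_
            self.offset_ += increment
            return ret

    _offsets = _OffsetManager()
    print(_offsets.get_and_update_offset(2048))    # 0
    print(_offsets.get_and_update_offset(2048))    # 2048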
 }  // namespace jax
 }  // namespace transformer_engine
......
@@ -11,6 +11,7 @@ import os
 from typing import Any, Callable, Optional, Sequence, Tuple, Union
 import warnings

+import jax
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
@@ -182,9 +183,8 @@ def core_attention(query: Array,
     if not deterministic and dropout_rate > 0.:
         keep_prob = 1.0 - dropout_rate
         dropout_shape = list(attn_weights.shape)
-        dropout_shape[-2] = 1
+        # TODO(rewang): add attention dropout broadcast dimension arguments for users
         keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
-        keep = jnp.broadcast_to(keep, attn_weights.shape)
         multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
         attn_weights = attn_weights * multiplier
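The change above drops the `dropout_shape[-2] = 1` broadcast trick: the Bernoulli mask is now drawn over the full attention-weight shape, aligning the unfused reference path with the fused kernel and other frameworks. A standalone illustration of the difference, with made-up shapes:

    import jax
    import jax.numpy as jnp

    rng = jax.random.PRNGKey(0)
    attn_weights = jnp.ones((2, 16, 128, 128))    # (batch, num_head, q_len, kv_len)
    keep_prob = 0.9

    # Before: one draw per (batch, head, kv) column, broadcast across query rows.
    keep_bcast = jax.random.bernoulli(rng, keep_prob, (2, 16, 1, 128))
    out_bcast = attn_weights * jnp.broadcast_to(keep_bcast, attn_weights.shape) / keep_prob

    # After: an independent draw for every attention weight.
    keep_full = jax.random.bernoulli(rng, keep_prob, attn_weights.shape)
    out_full = attn_weights * keep_full / keep_prob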
@@ -384,7 +384,7 @@ class MultiHeadAttention(nn.Module):
         fused_attn_supported_seqlen = [128, 256, 384, 512]
         enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
         use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
-            self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
+            canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
             q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
             and is_fused_attn_kernel_available() and (self.head_dim == 64) and enable_fused_attn
@@ -397,9 +397,6 @@ class MultiHeadAttention(nn.Module):
                           f"but got {self.transpose_batch_sequence}, "
             if not self.fuse_qkv:
                 reason += f"fuse_qkv=True is required but got {self.fuse_qkv}, "
-            if self.dropout_rate != 0:
-                # TODO(rewang): add dropout support
-                reason += f"no dropout is required but got dropout_rate={self.dropout_rate}, "
             if canonicalize_dtype not in [jnp.bfloat16, jnp.float16]:
                 reason += f"dtype in [BF16, FP16] is required " \
                           f"but got dtype={canonicalize_dtype}, "
@@ -583,6 +580,12 @@ class MultiHeadAttention(nn.Module):
             assert mask is not None and mask.ndim == 4    # (b, 1, s_q, s_kv)
             assert not self.transpose_batch_sequence

+            seed = None
+            if dropout_rng is not None:
+                seed = jax.random.split(dropout_rng, len(jax.devices()))
+                # ensure the old key never used
+                del dropout_rng
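A quick standalone check of what the split above produces. The exact key shape and dtype depend on the configured JAX PRNG implementation (which this PR handles on the C++ side by casting the raw key/counter words to int64), so treat this as illustrative:

    import jax

    rng = jax.random.PRNGKey(0)
    seeds = jax.random.split(rng, len(jax.devices()))
    # Under the default threefry impl each key is a pair of uint32 words,
    # giving one (2,)-shaped key per device: shape (n_devices, 2).
    print(seeds.shape, seeds.dtype)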
             # TODO(rewang): make it configurable for pre_scale_bias
             attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
@@ -607,7 +610,7 @@ class MultiHeadAttention(nn.Module):
                 x = self_fused_attn(qkv_proj,
                                     bias,
                                     mask,
-                                    dropout_rng,
+                                    seed,
                                     attn_bias_type=attn_bias_type,
                                     attn_mask_type=attn_mask_type,
                                     scaling_factor=scale_factor,
@@ -626,7 +629,7 @@ class MultiHeadAttention(nn.Module):
                 x = cross_fused_attn(query,
                                      kv_proj,
                                      mask,
-                                     dropout_rng,
+                                     seed,
                                      attn_bias_type=attn_bias_type,
                                      attn_mask_type=attn_mask_type,
                                      scaling_factor=scale_factor,
......
@@ -46,7 +46,7 @@ class AttnMaskType(Enum):
 def self_fused_attn(qkv: jnp.ndarray,
                     bias: jnp.ndarray,
                     mask: jnp.ndarray,
-                    rng_state: jnp.ndarray,
+                    seed: jnp.ndarray,
                     attn_bias_type: AttnBiasType,
                     attn_mask_type: AttnMaskType,
                     scaling_factor: float,
@@ -63,7 +63,7 @@ def self_fused_attn(qkv: jnp.ndarray,
         output = _self_fused_attn_max_512(qkv,
                                           bias,
                                           mask,
-                                          rng_state,
+                                          seed,
                                           attn_bias_type=attn_bias_type,
                                           attn_mask_type=attn_mask_type,
                                           scaling_factor=scaling_factor,
@@ -73,13 +73,13 @@ def self_fused_attn(qkv: jnp.ndarray,
     dp_axis_name = "batch"
     tp_axis_name = "model"

-    inputs = [qkv, bias, mask, rng_state]
+    inputs = [qkv, bias, mask, seed]
     batch, seqlen, _, num_head, head_dim = qkv.shape
     output_shape = [batch, seqlen, num_head, head_dim]
     sharding_meta = get_fused_attn_sharding_meta(
         sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
-        dp_dims=([0, None, 0, None], [0]),
-        tp_dims=([3, 1, None, None], [2]),
+        dp_dims=([0, None, 0, 0], [0]),
+        tp_dims=([3, 1, None, 0], [2]),
         dp_axis_name=dp_axis_name,
         tp_axis_name=tp_axis_name)
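The dims tuples changed because the seed input now carries a leading per-device axis (from jax.random.split), so it is partitioned along axis 0 under both the data-parallel and tensor-parallel meshes instead of being replicated. Each entry maps one input in `inputs` to the axis that gets sharded; a positional walkthrough, for illustration only:

    inputs = ["qkv", "bias", "mask", "seed"]
    dp_dims = [0, None, 0, 0]    # batch axis for qkv/mask; device axis for seed
    tp_dims = [3, 1, None, 0]    # head axis for qkv/bias; device axis for seed

    for name, dp, tp in zip(inputs, dp_dims, tp_dims):
        print(f"{name:5s} -> dp axis: {dp}, tp axis: {tp}")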
@@ -104,13 +104,13 @@ def self_fused_attn(qkv: jnp.ndarray,
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
 def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray,
-                             rng_state: jnp.ndarray, attn_bias_type: AttnBiasType,
+                             seed: jnp.ndarray, attn_bias_type: AttnBiasType,
                              attn_mask_type: AttnMaskType, scaling_factor: float,
                              dropout_probability: float, is_training: bool):
     output, _ = _self_fused_attn_max_512_fwd(qkv,
                                              bias,
                                              mask,
-                                             rng_state,
+                                             seed,
                                              attn_bias_type=attn_bias_type,
                                              attn_mask_type=attn_mask_type,
                                              scaling_factor=scaling_factor,
@@ -119,7 +119,7 @@ def _self_fused_attn_max_512(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndar
     return output

-def _self_fused_attn_max_512_fwd(qkv, bias, mask, rng_state, attn_bias_type, attn_mask_type,
+def _self_fused_attn_max_512_fwd(qkv, bias, mask, seed, attn_bias_type, attn_mask_type,
                                  scaling_factor, dropout_probability, is_training):
     seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
@@ -129,7 +129,7 @@ def _self_fused_attn_max_512_fwd(qkv, bias, mask, rng_state, attn_bias_type, att
     output, softmax_aux = self_fused_attn_max_512_fwd(qkv,
                                                       bias,
                                                       cu_seqlen,
-                                                      rng_state,
+                                                      seed,
                                                       attn_bias_type=attn_bias_type.value,
                                                       attn_mask_type=attn_mask_type.value,
                                                       scaling_factor=scaling_factor,
@@ -163,7 +163,7 @@ _self_fused_attn_max_512.defvjp(_self_fused_attn_max_512_fwd, _self_fused_attn_m
 def cross_fused_attn(q: jnp.ndarray,
                      kv: jnp.ndarray,
                      mask: jnp.ndarray,
-                     rng_state: jnp.ndarray,
+                     seed: jnp.ndarray,
                      attn_bias_type: AttnBiasType,
                      attn_mask_type: AttnMaskType,
                      scaling_factor: float,
@@ -180,7 +180,7 @@ def cross_fused_attn(q: jnp.ndarray,
         output = _cross_fused_attn_max_512(q,
                                            kv,
                                            mask,
-                                           rng_state,
+                                           seed,
                                            attn_bias_type=attn_bias_type,
                                            attn_mask_type=attn_mask_type,
                                            scaling_factor=scaling_factor,
@@ -190,7 +190,7 @@ def cross_fused_attn(q: jnp.ndarray,
     dp_axis_name = "batch"
     tp_axis_name = "model"

-    inputs = [q, kv, mask, rng_state]
+    inputs = [q, kv, mask, seed]
     output_shape = q.shape
     sharding_meta = get_fused_attn_sharding_meta(
         sharding_type, [x.shape if x is not None else None for x in inputs], [output_shape],
@@ -219,15 +219,14 @@ def cross_fused_attn(q: jnp.ndarray,
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
-def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray,
-                              rng_state: jnp.ndarray, attn_bias_type: AttnBiasType,
-                              attn_mask_type: AttnMaskType, scaling_factor: float,
-                              dropout_probability: float, is_training: bool):
+def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray, seed: jnp.ndarray,
+                              attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType,
+                              scaling_factor: float, dropout_probability: float, is_training: bool):
     output, _ = _cross_fused_attn_max_512_fwd(q,
                                               kv,
                                               mask,
-                                              rng_state,
+                                              seed,
                                               attn_bias_type=attn_bias_type,
                                               attn_mask_type=attn_mask_type,
                                               scaling_factor=scaling_factor,
@@ -236,8 +235,8 @@ def _cross_fused_attn_max_512(q: jnp.ndarray, kv: jnp.ndarray, mask: jnp.ndarray
     return output

-def _cross_fused_attn_max_512_fwd(q, kv, mask, rng_state, attn_bias_type, attn_mask_type,
-                                  scaling_factor, dropout_probability, is_training):
+def _cross_fused_attn_max_512_fwd(q, kv, mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
+                                  dropout_probability, is_training):
     q_seqlen = jnp.sum(mask[:, :, :, 0] == 0, axis=(-1, -2), dtype=jnp.int32)
     q_cu_seqlen = jnp.cumsum(q_seqlen)
@@ -251,7 +250,7 @@ def _cross_fused_attn_max_512_fwd(q, kv, mask, rng_state, attn_bias_type, attn_m
                                                         kv,
                                                         q_cu_seqlen,
                                                         kv_cu_seqlen,
-                                                        rng_state,
+                                                        seed,
                                                         attn_bias_type=attn_bias_type.value,
                                                         attn_mask_type=attn_mask_type.value,
                                                         scaling_factor=scaling_factor,
......
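For reference, the custom-VJP plumbing used throughout this file follows the standard JAX pattern: the forward rule returns residuals (here the softmax/rng auxiliary tensors), and the backward rule consumes those residuals, which is why no seed is needed in the backward pass. A self-contained miniature of that pattern with a toy function, not the fused kernel:

    from functools import partial
    import jax

    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def scaled_square(x, scale):
        return scale * x * x

    def scaled_square_fwd(x, scale):
        out = scale * x * x
        return out, (x,)    # residuals saved for the backward pass

    def scaled_square_bwd(scale, res, g):
        (x,) = res
        return (g * 2.0 * scale * x,)    # grad wrt x only; scale is non-diff

    scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)
    print(jax.grad(scaled_square)(3.0, 2.0))    # 12.0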