Unverified commit 94de051f authored by cyanguwa, committed by GitHub

[C++/PyTorch] Add alibi_slopes support (#608)



* test alibi between fa and fu
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move alibi slopes and bias to global to avoid repeating calculation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix alibi slopes/bias generation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix _is_flash_attention_supported to allow alibi type
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable padding mask when alibi is used for fused attn arbi backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add support for custom [n_heads] alibi_slopes in flash, fused, unfused attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove alibi_type=none tests as they are unnecessary
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend to 1.0.2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change bias/dbias shape to allow b,1/1,h/b,h in arbi backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak tests for arbi post_scale_bias [1,h,s,s] or alibi_slopes [n_heads]
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change bias/dbias shape in max512 backend - incomplete
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove max512 changes from last commit and disable max512 (and arbi temporarily) for [b, h, s, s]; pending cuDNN backend support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up and tweak backend selection logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace || with () in docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix bias shape for max512 backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* combine slopes/bias generation to one function get_alibi() and fix alibi tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix PR557 bugs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* encapsulate global alibi tensors into a dict cache
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* reduce alibi slopes test size
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to cudnn-frontend 1.0.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use dBias shape to define bias_b/bias_h because jax materializes dBias rather than Bias in bwd abstract
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent da30634a
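The user-facing change is a new `alibi_slopes` argument alongside `core_attention_bias_type="alibi"` on `DotProductAttention` (and `MultiheadAttention`/`TransformerLayer`), accepting FP32 CUDA slopes of shape [num_heads] or [batch_size, num_heads]. A minimal usage sketch, not part of this diff; shapes and hyper-parameters are illustrative only:

```python
# Sketch: exercising the new alibi_slopes argument of DotProductAttention.
import torch
from transformer_engine.pytorch import DotProductAttention

batch, heads, seq, head_dim = 2, 16, 128, 64
dpa = DotProductAttention(heads, head_dim)

# default qkv_format is 'sbhd'
q = torch.randn(seq, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Custom per-head slopes: FP32 CUDA tensor of shape [num_heads] or [batch, num_heads].
# Pass alibi_slopes=None to fall back to the vanilla geometric slopes.
slopes = torch.randn(heads, dtype=torch.float32, device="cuda").abs()

out = dpa(q, k, v, core_attention_bias_type="alibi", alibi_slopes=slopes)
```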
@@ -56,6 +56,8 @@ torch.cuda.manual_seed(seed)
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()
+_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
 def reset_rng_states() -> None:
     """Revert back to initial RNG state"""
     torch.set_rng_state(_cpu_rng_state)
@@ -81,6 +83,7 @@ class ModelConfig:
         dropout_p: float,
         attn_mask_type: str,
         attn_bias_type: str,
+        alibi_type: str = "none",
         num_layers: int = 1,
     ):
         self.batch_size = batch_size
@@ -94,6 +97,7 @@ class ModelConfig:
         self.dropout_p = dropout_p
         self.attn_mask_type = attn_mask_type
         self.attn_bias_type = attn_bias_type
+        self.alibi_type = alibi_type
         self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
         self.num_layers = num_layers
@@ -167,7 +171,7 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
     """Check if FlashAttention supports a model configuration"""
     if get_device_compute_capability() < (8, 0):
         return False
-    if config.attn_bias_type != "no_bias":
+    if config.attn_bias_type not in ["no_bias", "alibi"]:
         return False
     if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
         return False
@@ -283,18 +287,26 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
     )
     if unfused_attn_supported and fused_attn_supported:
+        if _NVTE_DEBUG:
+            print("[test_dot_product_attention]: unfused attn vs fused attn")
         torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
         for i,_ in enumerate(unfused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
     if unfused_attn_supported and flash_attn_supported:
+        if _NVTE_DEBUG:
+            print("[test_dot_product_attention]: unfused attn vs flash attn")
         torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
         for i,_ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
     if fused_attn_supported and flash_attn_supported:
+        if _NVTE_DEBUG:
+            print("[test_dot_product_attention]: fused attn vs flash attn")
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
         for i,_ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
     if fused_attn_supported and len(fused_attn_backend) == 2:
+        if _NVTE_DEBUG:
+            print("[test_dot_product_attention]: fused attn backend 0 vs 1")
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
         for i,_ in enumerate(fused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
@@ -382,6 +394,21 @@ def test_dpa_sliding_window(dtype, model_configs, model):
     """Test DotProductAttention module with sliding window attention"""
     test_dot_product_attention(dtype, model_configs, model, False, True, None, True)
+model_configs_alibi_slopes = {
+    #   test:             b,  h, hg,   d,   sq,  skv,   p,      mask,    bias, alibi_type
+    "alibi_1_0": ModelConfig(2, 16, 16, 64,  128,  128, 0.0, "causal", "alibi", alibi_type="vanilla"),
+    "alibi_1_1": ModelConfig(1, 16, 16, 64,  128,  256, 0.0, "causal", "alibi", alibi_type="vanilla"),
+    "alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom"),
+    "alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom"),
+}
+@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
+@pytest.mark.parametrize("dtype", param_types_lean)
+@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
+@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
+def test_dpa_alibi_slopes(dtype, model_configs, model):
+    """Test DotProductAttention module with ALiBi slopes"""
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
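The new test configs distinguish "vanilla" ALiBi (slopes derived from the head count inside get_alibi()) from "custom" slopes supplied by the caller. A small sketch, mirroring what the test runner below does for the "custom" configs:

```python
# Sketch only: building custom ALiBi slopes as the "custom" configs above expect.
# Slopes must be an FP32 CUDA tensor of shape [num_heads] (or [batch_size, num_heads]);
# for alibi_type == "vanilla" no tensor is passed (alibi_slopes stays None).
import torch

num_heads = 24
alibi_slopes = torch.randn(num_heads).abs().to(dtype=torch.float32, device="cuda")
```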
 qkv_layouts = [
     'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
     'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
@@ -477,9 +504,17 @@ def _run_dot_product_attention(
             attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
     if swa:
         window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
+    elif "causal" in config.attn_mask_type:
+        window_size, attention_mask = (-1, 0), None
     else:
         window_size, attention_mask = None, None
+    alibi_slopes = None
+    if config.attn_bias_type == "alibi":
+        if config.alibi_type == "custom":
+            alibi_slopes = torch.randn(
+                config.num_heads).abs().to(dtype=torch.float32, device="cuda")
     # Create input tensors
     dim_to_num = {
         'b' : config.batch_size,
@@ -570,6 +605,7 @@ def _run_dot_product_attention(
             checkpoint_core_attention=ckpt_attn,
             core_attention_bias_type=config.attn_bias_type,
             core_attention_bias=bias,
+            alibi_slopes=alibi_slopes,
             fast_zero_fill=True)
     out.backward(out_grad)
@@ -583,6 +619,8 @@ model_configs_te_layer = {
     "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
     "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
     "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+    "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
+    "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
 }
 @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@@ -654,12 +692,18 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
     )
     if unfused_attn_supported and fused_attn_supported:
+        if _NVTE_DEBUG:
+            print("[test_transformer_layer]: unfused attn vs fused attn")
         torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
         torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
     if unfused_attn_supported and flash_attn_supported:
+        if _NVTE_DEBUG:
+            print("[test_transformer_layer]: unfused attn vs flash attn")
         torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
         torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
     if fused_attn_supported and flash_attn_supported:
+        if _NVTE_DEBUG:
+            print("[test_transformer_layer]: fused attn vs flash attn")
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
         torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
@@ -758,28 +802,10 @@ def _run_transformer_layer(
         rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
     # Create bias
-    if config.attn_bias_type == 'no_bias':
-        bias = None
+    bias = None
     if config.attn_bias_type == 'post_scale_bias':
         bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
                            dtype=dtype, device="cuda")
-    elif config.attn_bias_type == 'alibi':
-        if os.environ['NVTE_FUSED_ATTN_BACKEND'] == '0':
-            config.attn_bias_type = 'post_scale_bias'
-            n = 2 ** math.floor(math.log2(config.num_heads))
-            m_0 = 2.0 ** (-8.0 / n)
-            m = torch.pow(m_0, torch.arange(1, 1 + n))
-            a = torch.ones(config.max_seqlen_q, config.max_seqlen_kv)
-            b = torch.triu(a,diagonal=1)
-            c = b.cumsum(dim=-1)
-            d = c - torch.transpose(c, 0, 1)
-            bias = d.expand(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv)
-            for i in range(config.num_heads):
-                bias[0,i,:,:] = m[i] * bias[0,i,:,:]
-            bias = bias.to(dtype=dtype, device="cuda")
-    else:
-        bias = None
     # Create RoPE
     rotary_pos_emb = None
@@ -825,6 +851,12 @@ def _run_transformer_layer(
             .to(dtype=dtype, device="cuda")
         )
+    # Create ALiBi slopes
+    alibi_slopes = None
+    if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
+        alibi_slopes = torch.randn(
+            config.num_heads).abs().to(dtype=torch.float32, device="cuda")
     # Run a forward and backward pass
     out = block(inp,
                 attention_mask=attention_mask,
@@ -832,7 +864,8 @@ def _run_transformer_layer(
                 checkpoint_core_attention=False,
                 rotary_pos_emb=rotary_pos_emb,
                 core_attention_bias_type=config.attn_bias_type,
-                core_attention_bias=bias)
+                core_attention_bias=bias,
+                alibi_slopes=alibi_slopes)
     loss = out.sum()
     loss.backward()
......
@@ -135,6 +135,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
             || (bias_type == NVTE_Bias_Type::NVTE_ALIBI
                 && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK
+                && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK
                 && sm_arch_ == 90)
             || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
                 && sm_arch_ == 90))))
......
@@ -49,6 +49,7 @@ namespace transformer_engine {
 namespace fused_attn {
 void fused_attn_arbitrary_seqlen_fwd_impl(
                 int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d,
+                int64_t bias_b, int64_t bias_h,
                 bool is_training, float scaling_factor, float dropout_probability,
                 NVTE_QKV_Layout layout,
                 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
@@ -154,8 +155,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     if (is_bias) {
         bias = mha_graph->tensor(fe::graph::Tensor_attributes()
                         .set_name("bias")
-                        .set_dim({1, h, s_q, s_kv})
-                        .set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+                        .set_dim({bias_b, bias_h, s_q, s_kv})
+                        .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
         sdpa_options.set_bias(bias);
     }
@@ -293,6 +294,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
 void fused_attn_arbitrary_seqlen_bwd_impl(
                 int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d,
+                int64_t bias_b, int64_t bias_h,
                 float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
                 NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                 void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose,
@@ -417,12 +419,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     if (is_bias) {
         bias = mha_graph->tensor(fe::graph::Tensor_attributes()
                         .set_name("bias")
-                        .set_dim({1, h, s_q, s_kv})
-                        .set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+                        .set_dim({bias_b, bias_h, s_q, s_kv})
+                        .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
         dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
                         .set_name("dBias")
-                        .set_dim({1, h, s_q, s_kv})
-                        .set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
+                        .set_dim({bias_b, bias_h, s_q, s_kv})
+                        .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
         sdpa_backward_options.set_bias(bias);
         sdpa_backward_options.set_dbias(dBias);
     }
@@ -590,7 +592,14 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
   void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
   void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
-  void *devPtrBias = input_Bias->data.dptr;
+  void *devPtrBias = nullptr;
+  size_t bias_b = 0;
+  size_t bias_h = 0;
+  if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
+    devPtrBias = input_Bias->data.dptr;
+    bias_b = input_Bias->data.shape[0];
+    bias_h = input_Bias->data.shape[1];
+  }
   void *devPtrO = output_O->data.dptr;
   void *devPtrS = nullptr;
   void *devPtrCuSeqlens = cu_seqlens->data.dptr;
@@ -608,7 +617,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     output_rng_state->data.dtype = DType::kInt64;
     Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
     output_bias->data.dptr = nullptr;
-    output_bias->data.shape = {1, num_attn_heads, max_seqlen, max_seqlen};
+    output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
     output_bias->data.dtype = QKV_type;
   } else {
     Aux_CTX_Tensors->size = 2;
@@ -644,7 +653,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
   size_t workspace_size = 0;
   fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_attn_heads,
-                max_seqlen, max_seqlen, head_dim,
+                max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
                 is_training, attn_scale, p_dropout, qkv_layout,
                 bias_type, mask_type,
                 devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
@@ -698,10 +707,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
   void* devPtrO = input_O->data.dptr;
   void *devPtrdO = input_dO->data.dptr;
   void *devPtrBias = nullptr;
+  void *devPtrdBias = nullptr;
+  size_t bias_b = 0;
+  size_t bias_h = 0;
   if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
     devPtrBias = input_Bias->data.dptr;
+    devPtrdBias = output_dBias->data.dptr;
+    bias_b = output_dBias->data.shape[0];
+    bias_h = output_dBias->data.shape[1];
   }
-  void *devPtrdBias = output_dBias->data.dptr;
   void *devPtrdQKV = output_dQKV->data.dptr;
   void *devPtrdQ = devPtrdQKV;
@@ -720,7 +734,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
   size_t workspace_size = 0;
   fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads,
-                max_seqlen, max_seqlen, head_dim,
+                max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
                 attn_scale, p_dropout, qkv_layout,
                 bias_type, mask_type,
                 devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
@@ -767,7 +781,14 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
   void *devPtrK = devPtrKV;
   void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
-  void *devPtrBias = input_Bias->data.dptr;
+  void *devPtrBias = nullptr;
+  size_t bias_b = 0;
+  size_t bias_h = 0;
+  if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
+    devPtrBias = input_Bias->data.dptr;
+    bias_b = input_Bias->data.shape[0];
+    bias_h = input_Bias->data.shape[1];
+  }
   void *devPtrO = output_O->data.dptr;
   void *devPtrS = nullptr;
@@ -787,7 +808,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     output_rng_state->data.dtype = DType::kInt64;
     Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
     output_bias->data.dptr = nullptr;
-    output_bias->data.shape = {1, num_attn_heads, max_seqlen_q, max_seqlen_kv};
+    output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
     output_bias->data.dtype = QKV_type;
   } else {
     Aux_CTX_Tensors->size = 2;
@@ -823,8 +844,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
   size_t workspace_size = 0;
   fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups,
-                max_seqlen_q, max_seqlen_kv,
-                head_dim, is_training, attn_scale, p_dropout, qkv_layout,
+                max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
+                is_training, attn_scale, p_dropout, qkv_layout,
                 bias_type, mask_type,
                 devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
                 devPtrDropoutSeed, devPtrDropoutOffset,
@@ -879,8 +900,14 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
   void* devPtrO = input_O->data.dptr;
   void *devPtrdO = input_dO->data.dptr;
   void *devPtrBias = nullptr;
+  void *devPtrdBias = nullptr;
+  size_t bias_b = 0;
+  size_t bias_h = 0;
   if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
     devPtrBias = input_Bias->data.dptr;
+    devPtrdBias = output_dBias->data.dptr;
+    bias_b = output_dBias->data.shape[0];
+    bias_h = output_dBias->data.shape[1];
   }
   void *devPtrdQ = output_dQ->data.dptr;
@@ -890,7 +917,6 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
   void *devPtrSoftmaxStats = nullptr;
   devPtrSoftmaxStats = output_S->data.dptr;
-  void *devPtrdBias = output_dBias->data.dptr;
   void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
   void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
@@ -902,8 +928,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
   size_t workspace_size = 0;
   fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups,
-                max_seqlen_q, max_seqlen_kv,
-                head_dim, attn_scale, p_dropout, qkv_layout,
+                max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
+                attn_scale, p_dropout, qkv_layout,
                 bias_type, mask_type,
                 devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
                 devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
@@ -944,7 +970,14 @@ void fused_attn_arbitrary_seqlen_fwd(
   void *devPtrV = input_V->data.dptr;
   void *devPtrO = output_O->data.dptr;
   void *devPtrS = nullptr;
-  void *devPtrBias = input_Bias->data.dptr;
+  void *devPtrBias = nullptr;
+  size_t bias_b = 0;
+  size_t bias_h = 0;
+  if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
+    devPtrBias = input_Bias->data.dptr;
+    bias_b = input_Bias->data.shape[0];
+    bias_h = input_Bias->data.shape[1];
+  }
   void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
   void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
@@ -962,7 +995,7 @@ void fused_attn_arbitrary_seqlen_fwd(
     output_rng_state->data.dtype = DType::kInt64;
     Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
     output_bias->data.dptr = nullptr;
-    output_bias->data.shape = {1, num_attn_heads, max_seqlen_q, max_seqlen_kv};
+    output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
     output_bias->data.dtype = QKV_type;
   } else {
     Aux_CTX_Tensors->size = 2;
@@ -998,8 +1031,8 @@ void fused_attn_arbitrary_seqlen_fwd(
   size_t workspace_size = 0;
   fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups,
-                max_seqlen_q, max_seqlen_kv,
-                head_dim, is_training, attn_scale, p_dropout, qkv_layout,
+                max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
+                is_training, attn_scale, p_dropout, qkv_layout,
                 bias_type, mask_type,
                 devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
                 devPtrDropoutSeed, devPtrDropoutOffset,
@@ -1045,8 +1078,14 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
   void* devPtrO = input_O->data.dptr;
   void *devPtrdO = input_dO->data.dptr;
   void *devPtrBias = nullptr;
+  void *devPtrdBias = nullptr;
+  size_t bias_b = 0;
+  size_t bias_h = 0;
   if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
     devPtrBias = input_Bias->data.dptr;
+    devPtrdBias = output_dBias->data.dptr;
+    bias_b = output_dBias->data.shape[0];
+    bias_h = output_dBias->data.shape[1];
   }
   void *devPtrdQ = output_dQ->data.dptr;
@@ -1054,7 +1093,6 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
   void *devPtrdV = output_dV->data.dptr;
   void *devPtrSoftmaxStats = nullptr;
   devPtrSoftmaxStats = output_S->data.dptr;
-  void *devPtrdBias = output_dBias->data.dptr;
   void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
   void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
@@ -1066,9 +1104,8 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
   size_t workspace_size = 0;
   fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups,
-                max_seqlen_q, max_seqlen_kv,
-                head_dim, attn_scale, p_dropout, qkv_layout,
-                bias_type, mask_type,
+                max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
+                attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
                 devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
                 devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
                 devPtrDropoutSeed, devPtrDropoutOffset,
......
@@ -71,6 +71,15 @@ if _flash_attn_version >= _flash_attn_version_required:
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+_alibi_cache = {
+    "_num_heads": None,
+    "_alibi_slopes": None,
+    "_max_seqlen_q": None,
+    "_max_seqlen_kv": None,
+    "_alibi_bias": None,
+    "_alibi_slopes_require_update": False,
+    "_alibi_bias_require_update": False,
+}
 __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
@@ -126,32 +135,70 @@ def get_alibi(
     num_heads: int,
     max_seqlen_q: int,
     max_seqlen_kv: int,
-) -> torch.Tensor:
-    """
-    Generate ALiBi bias in the shape of [1, num_heads, max_seqlen_q, max_seqlen_kv].
-    """
-    n = 2 ** math.floor(math.log2(num_heads))
-    m_0 = 2.0 ** (-8.0 / n)
-    m = torch.pow(m_0, torch.arange(1, 1 + n))
-    if n < num_heads:
-        m_hat_0 = 2.0 ** (-4.0 / n)
-        m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
-        m = torch.cat([m, m_hat])
-    a = torch.ones(max_seqlen_q, max_seqlen_kv)
-    b = torch.triu(a,diagonal=1)
-    c = b.cumsum(dim=-1)
-    bb = torch.tril(a,diagonal=-1)
-    cc = bb.cumsum(dim=0)
-    d = c - cc
-    bias = d.repeat(1, num_heads, 1, 1)
-    for i in range(num_heads):
-        bias[0,i,:,:] = m[i] * bias[0,i,:,:]
-    bias = bias.to(dtype=torch.float32, device="cuda")
-    return bias
+    alibi_slopes: Optional[torch.Tensor] = None,
+    bias_dtype: Optional[torch.dtype] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Parameters
+    ----------
+    num_heads: int
+        Number of heads.
+    max_seqlen_q: int
+        Maximum sequence length for queries.
+    max_seqlen_kv: int
+        Maximum sequence length for keys and values.
+    alibi_slopes: Optional[torch.Tensor], default = `None`
+        Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
+    bias_dtype: Optional[torch.dtype], default = `None`
+        Dtype of the generated ALiBi bias. If None, use torch.float32.
+
+    Returns
+    ----------
+    alibi_slopes: torch.Tensor
+        ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
+    alibi_bias: torch.Tensor
+        ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
+        then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
+        `alibi_slopes` is in [batch_size, num_heads], then the bias is in
+        [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
+    """
+    global _alibi_cache
+    if _alibi_cache["_alibi_slopes_require_update"]:
+        if alibi_slopes is not None:
+            _alibi_cache["_alibi_slopes"] = alibi_slopes
+        else:
+            n = 2 ** math.floor(math.log2(num_heads))
+            m_0 = 2.0 ** (-8.0 / n)
+            m = torch.pow(m_0, torch.arange(1, 1 + n))
+            if n < num_heads:
+                m_hat_0 = 2.0 ** (-4.0 / n)
+                m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
+                m = torch.cat([m, m_hat])
+            _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
+        _alibi_cache["_num_heads"] = num_heads
+        _alibi_cache["_alibi_slopes_require_update"] = False
+    if _alibi_cache["_alibi_bias_require_update"]:
+        assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!"
+        if _alibi_cache["_alibi_slopes"].dim() == 1:
+            slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
+        if _alibi_cache["_alibi_slopes"].dim() == 2:
+            slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
+        bias = torch.arange(
+            1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
+        bias = bias - torch.arange(
+            1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(1, 1, max_seqlen_q, 1)
+        bias = bias.abs().mul(-1)
+        bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
+        _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
+        bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
+        _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
+        _alibi_cache["_alibi_bias_require_update"] = False
+    return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
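For reference, the "vanilla" slopes generated above follow the standard ALiBi recipe: with num_heads a power of two, head i (1-indexed) gets slope 2^(-8i/num_heads), and a second geometric series fills in the remaining heads otherwise. A standalone sketch that mirrors the vanilla branch of get_alibi():

```python
# Standalone sketch of the vanilla ALiBi slope generation used by get_alibi().
import math
import torch

def vanilla_alibi_slopes(num_heads: int) -> torch.Tensor:
    n = 2 ** math.floor(math.log2(num_heads))   # largest power of two <= num_heads
    m = torch.pow(2.0 ** (-8.0 / n), torch.arange(1, 1 + n))
    if n < num_heads:
        # interleaved extra slopes for the non-power-of-two remainder
        m_hat = torch.pow(2.0 ** (-4.0 / n), torch.arange(1, 1 + 2 * (num_heads - n), 2))
        m = torch.cat([m, m_hat])
    return m.float()

# e.g. num_heads = 8 gives slopes 1/2, 1/4, ..., 1/256
print(vanilla_alibi_slopes(8))
```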
 def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
     """
@@ -1281,6 +1328,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
         attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
+        alibi_slopes: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Unfused attention fprop"""
@@ -1350,8 +1398,6 @@ class UnfusedDotProductAttention(torch.nn.Module):
         elif core_attention_bias_type == "pre_scale_bias":
             assert core_attention_bias is not None, "core_attention_bias should not be None!"
-            assert (core_attention_bias.shape == torch.Size(1, *output_size[1:])
-                ), "core_attention_bias must be in [1, h, sq, skv] shape!"
             matmul_result = torch.bmm(
                 query_layer.transpose(0, 1),  # [b * np, sq, hn]
                 key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
@@ -1364,10 +1410,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
         elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
             if core_attention_bias_type == "post_scale_bias":
                 assert core_attention_bias is not None, "core_attention_bias should not be None!"
-                assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
-                    ), "core_attention_bias must be in [1, h, sq, skv] shape!"
             if core_attention_bias_type == "alibi":
-                core_attention_bias = get_alibi(output_size[1], output_size[2], output_size[3])
+                _, core_attention_bias = get_alibi(
+                    output_size[1], output_size[2], output_size[3], alibi_slopes=alibi_slopes)
             matmul_result = torch.baddbmm(
                 matmul_result,
                 query_layer.transpose(0, 1),  # [b * np, sq, hn]
@@ -2342,6 +2387,7 @@ class DotProductAttention(torch.nn.Module):
         self.tp_group = tp_group
         self.get_rng_state_tracker = get_rng_state_tracker
         self.num_attention_heads = num_attention_heads
+        self.layer_number = 1 if layer_number is None else layer_number
         self.cp_group = cp_group
         self.cp_global_ranks = cp_global_ranks
         self.cp_stream = cp_stream
@@ -2472,10 +2518,10 @@ class DotProductAttention(torch.nn.Module):
         max_seqlen_kv: Optional[int] = None,
         attn_mask_type: Optional[str] = None,
         window_size: Optional[Tuple[int, int]] = None,
-        alibi_slopes: Optional[torch.Tensor] = None,
         checkpoint_core_attention: bool = False,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
+        alibi_slopes: Optional[torch.Tensor] = None,
         fast_zero_fill: bool = True,
     ) -> torch.Tensor:
         """
@@ -2553,11 +2599,7 @@ class DotProductAttention(torch.nn.Module):
             `arbitrary`}, default = `None`. Type of attention mask passed into
             softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
         window_size: Optional[Tuple[int, int]], default = `None`
-            sliding window size for local attention.
-        alibi_slopes: Optional[torch.Tensor], default = `None`
-            An fp32 bias of shape (nheads,) or (batch_size, nheads)
-            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
-            is added to the attention score of query i and key j.
+            Sliding window size for local attention.
         checkpoint_core_attention : bool, default = `False`
             If true, forward activations for attention are recomputed
             during the backward pass in order to save memory that would
@@ -2568,6 +2610,10 @@ class DotProductAttention(torch.nn.Module):
         core_attention_bias: Optional[torch.Tensor], default = `None`
             Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
             It should be 'None' for 'no_bias' and 'alibi' bias types.
+        alibi_slopes: Optional[torch.Tensor], default = `None`
+            ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
+            It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
+            to the attention score of query i and key j.
         fast_zero_fill: bool, default = `True`
             Whether to use the fast path to set output tensors to 0 or not.
         """
@@ -2652,6 +2698,11 @@ class DotProductAttention(torch.nn.Module):
         # The following section filters out some backends based on
         # certain asserts before executing the forward pass.
+        # Filter: ONNX export.
+        if is_in_onnx_export_mode():
+            use_flash_attention = False
+            use_fused_attention = False
         # Filter: Input type.
         if (query_layer.dtype not in [torch.bfloat16, torch.float16]
             or key_layer.dtype not in [torch.bfloat16, torch.float16]
@@ -2680,10 +2731,6 @@ class DotProductAttention(torch.nn.Module):
             )
             use_flash_attention = False
-        # Filter: bias.
-        if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
-            use_flash_attention = False
         context_parallel = (self.cp_group is not None and \
             get_distributed_world_size(self.cp_group) != 1)
@@ -2694,11 +2741,6 @@ class DotProductAttention(torch.nn.Module):
             if (not _flash_attn_2_3_plus) or context_parallel:
                 use_flash_attention = False
-        # Filter: ONNX export.
-        if is_in_onnx_export_mode():
-            use_flash_attention = False
-            use_fused_attention = False
         # Filter: Attention mask type.
         #   attn_mask_type(s)    |     supported backends
         # ------------------------------------------------
@@ -2714,12 +2756,47 @@ class DotProductAttention(torch.nn.Module):
         if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
             use_unfused_attention = False
+        # Filter: bias.
+        global _alibi_cache
+        if alibi_slopes is not None:
+            assert (core_attention_bias_type == "alibi"
+                ), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
+            if self.layer_number == 1:
+                _alibi_cache["_alibi_slopes_require_update"] = True
+                _alibi_cache["_alibi_bias_require_update"] = True
+        if core_attention_bias_type == "alibi":
+            assert (core_attention_bias is None
+                ), "core_attention_bias must be None when core_attention_bias_type is alibi!"
+            if (_alibi_cache["_num_heads"] != query_layer.shape[-2]
+                or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
+                or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
+                or _alibi_cache["_alibi_slopes"] is None):
+                _alibi_cache["_alibi_slopes_require_update"] = True
+                _alibi_cache["_alibi_bias_require_update"] = True
+        if core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias is not None:
+            use_flash_attention = False
+        fu_core_attention_bias_type = core_attention_bias_type
+        fu_core_attention_bias = core_attention_bias
+        if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None:
+            fu_core_attention_bias_type = "post_scale_bias"
+            _, fu_core_attention_bias = get_alibi(
+                query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes,
+                bias_dtype=query_layer.dtype)
+            if (fu_core_attention_bias.shape[0] != 1
+                or fu_core_attention_bias.shape[1] != query_layer.shape[-2]):
+                # remove this line when cuDNN adds bwd support for [b, 1, s, s] and [b, h, s, s]
+                use_fused_attention = False
+                # max512 backend will only support [1, h, s, s]
+                os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
         if use_fused_attention:
             fused_attention_backend = tex.get_fused_attn_backend(
                 TE_DType[query_layer.dtype],
                 TE_DType[key_layer.dtype],
                 QKVLayout[qkv_layout],
-                AttnBiasType[core_attention_bias_type],
+                AttnBiasType[fu_core_attention_bias_type],
                 AttnMaskType[attn_mask_type],
                 self.attention_dropout,
                 query_layer.shape[-2], # num_attn_heads
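When fused attention is selected with custom slopes, the filter above lowers ALiBi to an explicit bias tensor and relabels it "post_scale_bias", since the cuDNN backends consume a bias matrix rather than per-head slopes. A minimal, illustrative sketch of that conversion (hypothetical helper name, not the module's code):

```python
# Sketch of the slopes -> post_scale_bias lowering performed for the fused path.
import torch

def slopes_to_post_scale_bias(slopes, max_seqlen_q, max_seqlen_kv, dtype):
    # distance matrix |i - j + (seqlen_kv - seqlen_q)|, broadcast over heads
    rel = (torch.arange(1 - max_seqlen_kv, 1, device=slopes.device).view(1, 1, 1, -1)
           - torch.arange(1 - max_seqlen_q, 1, device=slopes.device).view(1, 1, -1, 1))
    bias = -rel.abs() * slopes.view(1, -1, 1, 1)
    return bias.contiguous().to(dtype)

slopes = torch.randn(16, dtype=torch.float32, device="cuda").abs()
post_scale_bias = slopes_to_post_scale_bias(slopes, 128, 128, torch.bfloat16)  # [1, 16, 128, 128]
```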
@@ -2736,13 +2813,6 @@ class DotProductAttention(torch.nn.Module):
                 (not context_parallel or \
                 fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))
-        # Filter: Alibi slopes
-        if alibi_slopes is not None:
-            use_fused_attention = False
-            assert (
-                use_flash_attention
-            ), "Alibi slopes bias is only supported in the FlashAttention backend."
         # Filter: determinism.
         #   backend                      | deterministic
         # ---------------------------------------------------------
@@ -2771,6 +2841,9 @@ class DotProductAttention(torch.nn.Module):
         if use_flash_attention:
             if _NVTE_DEBUG:
                 print("[DotProductAttention]: using flash-attn",_flash_attn_version)
+            if core_attention_bias_type == "alibi":
+                alibi_slopes, _ = get_alibi(
+                    query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes)
             return self.flash_attention(query_layer,
                                         key_layer,
                                         value_layer,
@@ -2803,8 +2876,8 @@ class DotProductAttention(torch.nn.Module):
                     attn_mask_type=attn_mask_type,
                     attention_mask=attention_mask,
                     fused_attention_backend=fused_attention_backend,
-                    core_attention_bias_type=core_attention_bias_type,
-                    core_attention_bias=core_attention_bias,
+                    core_attention_bias_type=fu_core_attention_bias_type,
+                    core_attention_bias=fu_core_attention_bias,
                     fast_zero_fill=fast_zero_fill,
                     cp_group=self.cp_group,
                     cp_global_ranks=self.cp_global_ranks,
@@ -2821,8 +2894,8 @@ class DotProductAttention(torch.nn.Module):
                     attn_mask_type=attn_mask_type,
                     attention_mask=attention_mask,
                     fused_attention_backend=fused_attention_backend,
-                    core_attention_bias_type=core_attention_bias_type,
-                    core_attention_bias=core_attention_bias,
+                    core_attention_bias_type=fu_core_attention_bias_type,
+                    core_attention_bias=fu_core_attention_bias,
                     fast_zero_fill=fast_zero_fill,
                     cp_group=self.cp_group,
                     cp_global_ranks=self.cp_global_ranks,
@@ -2855,7 +2928,8 @@ class DotProductAttention(torch.nn.Module):
                     attn_mask_type = attn_mask_type,
                     attention_mask = attention_mask,
                     core_attention_bias_type = core_attention_bias_type,
-                    core_attention_bias = core_attention_bias)
+                    core_attention_bias = core_attention_bias,
+                    alibi_slopes = alibi_slopes)
             return self.unfused_attention(query_layer,
                                           key_layer,
                                           value_layer,
@@ -2865,7 +2939,8 @@ class DotProductAttention(torch.nn.Module):
                     attn_mask_type = attn_mask_type,
                     attention_mask = attention_mask,
                     core_attention_bias_type = core_attention_bias_type,
-                    core_attention_bias = core_attention_bias)
+                    core_attention_bias = core_attention_bias,
+                    alibi_slopes = alibi_slopes)
         raise Exception("No dot product attention support for the provided inputs!")
@@ -3279,6 +3354,7 @@ class MultiheadAttention(torch.nn.Module):
         rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
+        alibi_slopes: Optional[torch.Tensor] = None,
         fast_zero_fill: bool = True,
     ) -> Tuple[Union[torch.Tensor, None], ...]:
         """
@@ -3334,6 +3410,10 @@ class MultiheadAttention(torch.nn.Module):
         core_attention_bias: Optional[torch.Tensor], default = `None`
             Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
             It should be 'None' for 'no_bias' and 'alibi' bias types.
+        alibi_slopes: Optional[torch.Tensor], default = `None`
+            ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
+            It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
+            to the attention score of query i and key j.
         fast_zero_fill: bool, default = `True`
             Whether to set output tensors to 0 or not before use.
         """
@@ -3561,6 +3641,7 @@ class MultiheadAttention(torch.nn.Module):
                 checkpoint_core_attention=checkpoint_core_attention,
                 core_attention_bias_type=core_attention_bias_type,
                 core_attention_bias=core_attention_bias,
+                alibi_slopes=alibi_slopes,
                 fast_zero_fill=fast_zero_fill,
             )
......
...@@ -184,9 +184,6 @@ def fused_attn_fwd_qkvpacked( ...@@ -184,9 +184,6 @@ def fused_attn_fwd_qkvpacked(
if attn_bias_type not in ["no_bias", "alibi"]: if attn_bias_type not in ["no_bias", "alibi"]:
assert (attn_bias is not None assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
h = qkv.size(2) if 'h3d' in qkv_layout else qkv.size(3)
assert (attn_bias.shape == torch.Size([1, h, max_seqlen, max_seqlen])
), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (attn_bias.dtype == qkv.dtype assert (attn_bias.dtype == qkv.dtype
), "attn_bias tensor must be in the same dtype as qkv." ), "attn_bias tensor must be in the same dtype as qkv."
...@@ -479,9 +476,6 @@ def fused_attn_fwd_kvpacked( ...@@ -479,9 +476,6 @@ def fused_attn_fwd_kvpacked(
if attn_bias_type not in ["no_bias", "alibi"]: if attn_bias_type not in ["no_bias", "alibi"]:
assert (attn_bias is not None assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
h = q.size(2)
assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype assert (attn_bias.dtype == q.dtype
), "attn_bias tensor must be in the same dtype as q and kv." ), "attn_bias tensor must be in the same dtype as q and kv."
...@@ -784,9 +778,6 @@ def fused_attn_fwd( ...@@ -784,9 +778,6 @@ def fused_attn_fwd(
if attn_bias_type not in ["no_bias", "alibi"]: if attn_bias_type not in ["no_bias", "alibi"]:
assert (attn_bias is not None assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
h = q.size(2)
assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype
), "attn_bias tensor must be in the same dtype as q and kv."
...
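With the fixed [1, h, max_seqlen, max_seqlen] shape assertions removed above, only the presence and dtype of attn_bias are still checked for explicit-bias types. A minimal sketch of the remaining validation, using a hypothetical `check_attn_bias` helper, under the assumption that the bias shape itself is now left to the backend:

```python
import torch

def check_attn_bias(attn_bias, attn_bias_type: str, ref_dtype: torch.dtype) -> None:
    """Illustrative mirror of the checks kept in the wrappers above."""
    if attn_bias_type in ("no_bias", "alibi"):
        return  # no explicit bias tensor is expected for these types
    assert attn_bias is not None, \
        "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
    assert attn_bias.dtype == ref_dtype, \
        "attn_bias tensor must be in the same dtype as q and kv."
```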
...@@ -286,14 +286,6 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -286,14 +286,6 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
// create output tensor dQKV // create output tensor dQKV
at::Tensor dQKV = torch::empty_like(QKV); at::Tensor dQKV = torch::empty_like(QKV);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(h),
static_cast<int64_t>(max_seqlen),
static_cast<int64_t>(max_seqlen)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
// construct NVTE tensors
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
...@@ -358,6 +350,23 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create dBias the same shape as Bias
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS)
&& (bias_type != NVTE_ALIBI)) {
if (nvte_aux_tensor_pack.size >= 2) {
std::vector<int64_t> bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec());
dBias = torch::empty(bias_shape, options);
te_dBias = makeTransformerEngineTensor(dBias);
} else {
dBias = torch::empty({1, static_cast<int64_t>(h),
static_cast<int64_t>(max_seqlen),
static_cast<int64_t>(max_seqlen)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
}
// create cu_seqlens tensorwrappers
auto cu_seqlens_sizes = cu_seqlens.sizes().vec();
std::vector<size_t> cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()};
...@@ -629,14 +638,6 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
at::Tensor dQ = torch::empty_like(Q);
at::Tensor dKV = torch::empty_like(KV);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(h_q),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
...@@ -721,6 +722,23 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create dBias the same shape as Bias
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS)
&& (bias_type != NVTE_ALIBI)) {
if (nvte_aux_tensor_pack.size >= 2) {
std::vector<int64_t> bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec());
dBias = torch::empty(bias_shape, options);
te_dBias = makeTransformerEngineTensor(dBias);
} else {
dBias = torch::empty({1, static_cast<int64_t>(h_q),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
}
// create workspace
TensorWrapper workspace;
...@@ -990,6 +1008,9 @@ std::vector<at::Tensor> fused_attn_bwd(
std::vector<size_t> k_shape{k_sizes.begin(), k_sizes.end()};
auto v_sizes = V.sizes().vec();
std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dQ;
...@@ -1055,22 +1076,10 @@ std::vector<at::Tensor> fused_attn_bwd(
NVTE_ERROR("QKV layout not supported!");
}
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(Q.size(-2)),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero
&& ((h_q * d) % block_size == 0)
&& ((h_kv * d) % block_size == 0)
...@@ -1165,6 +1174,23 @@ std::vector<at::Tensor> fused_attn_bwd(
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create dBias the same shape as Bias
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS)
&& (bias_type != NVTE_ALIBI)) {
if (nvte_aux_tensor_pack.size >= 2) {
std::vector<int64_t> bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec());
dBias = torch::empty(bias_shape, options);
te_dBias = makeTransformerEngineTensor(dBias);
} else {
dBias = torch::empty({1, static_cast<int64_t>(h_q),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
}
// create workspace
TensorWrapper workspace;
...
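In the three backward paths above, dBias is now allocated with the same shape as the Bias tensor carried in Aux_CTX_Tensors (its last entry, when the pack holds at least two tensors), falling back to the previous [1, h, max_seqlen_q, max_seqlen_kv] layout otherwise. A rough Python sketch of that shape-selection logic, with a hypothetical `dbias_shape` helper, for illustration only:

```python
import torch

def dbias_shape(aux_ctx_tensors, h: int, max_seqlen_q: int, max_seqlen_kv: int):
    """Illustrative mirror of the C++ logic above: take the saved bias shape
    from the last auxiliary tensor when one is present, otherwise fall back
    to the classic [1, h, sq, skv] layout."""
    if len(aux_ctx_tensors) >= 2:
        return tuple(aux_ctx_tensors[-1].shape)  # dBias matches the saved Bias
    return (1, h, max_seqlen_q, max_seqlen_kv)
```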
...@@ -525,6 +525,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -525,6 +525,7 @@ class TransformerLayer(torch.nn.Module):
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None, core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""
...@@ -583,6 +584,10 @@ class TransformerLayer(torch.nn.Module):
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
alibi_slopes: Optional[torch.Tensor], default = `None`
ALiBi slopes in FP32, with shape [nheads] or [batch_size, nheads].
A bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) is added
to the attention score of query i and key j.
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None
...@@ -633,6 +638,7 @@ class TransformerLayer(torch.nn.Module):
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
)
...@@ -658,6 +664,7 @@ class TransformerLayer(torch.nn.Module):
checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
)
if self.apply_residual_connection_post_layernorm:
...
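Since TransformerLayer.forward now forwards alibi_slopes down to MultiheadAttention, a caller can pass per-head slopes together with core_attention_bias_type="alibi". A hedged usage sketch follows; the constructor arguments (hidden_size, ffn_hidden_size, num_attention_heads) and the [seq, batch, hidden] input layout are assumptions about the surrounding API, not part of this diff.

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical sizes; (hidden_size, ffn_hidden_size, num_attention_heads) is an
# assumed constructor signature for illustration.
layer = te.TransformerLayer(1024, 4096, 16).cuda()

x = torch.randn(128, 2, 1024, device="cuda")                 # assumed [sq, b, hidden] layout
slopes = torch.rand(16, device="cuda", dtype=torch.float32)  # one FP32 slope per head

out = layer(
    x,
    core_attention_bias_type="alibi",  # no explicit core_attention_bias tensor needed
    alibi_slopes=slopes,               # [nheads] or [batch_size, nheads], FP32
)
```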