Unverified Commit 94de051f authored by cyanguwa, committed by GitHub

[C++/PyTorch] Add alibi_slopes support (#608)



* test alibi between fa and fu
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move alibi slopes and bias to global to avoid repeating calculation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix alibi slopes/bias generation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix _is_flash_attention_supported to allow alibi type
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* disable padding mask when alibi is used for fused attn arbi backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add support for custom [n_heads] alibi_slopes in flash, fused, unfused attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove alibi_type=none tests as they are unnecessary
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend to 1.0.2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change bias/dbias shape to allow b,1/1,h/b,h in arbi backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* tweak tests for arbi post_scale_bias [1,h,s,s] or alibi_slopes [n_heads]
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change bias/dbias shape in max512 backend - incomplete
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove max512 changes from last commit and disable max512 (and arbi temporarily) for [b, h, s, s]; pending cuDNN backend support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* clean up and tweak backend selection logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace || with () in docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix bias shape for max512 backend
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* combine slopes/bias generation to one function get_alibi() and fix alibi tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix PR557 bugs
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* encapsulate global alibi tensors into a dict cache
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* reduce alibi slopes test size
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to cudnn-frontend 1.0.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use dBias shape to define bias_b/bias_h because jax materializes dBias rather than Bias in bwd abstract
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent da30634a
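
For reference, a minimal usage sketch of the new `alibi_slopes` argument. This is hedged: sizes are illustrative, and the `DotProductAttention` constructor arguments are abbreviated to the ones relevant here; the two call forms follow the API added in this commit.

```python
import torch
import transformer_engine.pytorch as te

b, h, sq, skv, d = 2, 16, 128, 128, 64
dpa = te.DotProductAttention(num_attention_heads=h, kv_channels=d,
                             attn_mask_type="causal")

# Default qkv_format is "sbhd": [seqlen, batch, heads, head_dim].
q = torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda")
k = torch.randn(skv, b, h, d, dtype=torch.bfloat16, device="cuda")
v = torch.randn(skv, b, h, d, dtype=torch.bfloat16, device="cuda")

# Vanilla ALiBi: slopes are generated internally from the head count.
out = dpa(q, k, v, core_attention_bias_type="alibi")

# Custom ALiBi: FP32 slopes of shape [h] (or [b, h]); a bias of
# -slope * |i + skv - sq - j| is added to the attention scores.
slopes = torch.randn(h, dtype=torch.float32, device="cuda").abs()
out = dpa(q, k, v, core_attention_bias_type="alibi", alibi_slopes=slopes)
```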
......@@ -56,6 +56,8 @@ torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
......@@ -81,6 +83,7 @@ class ModelConfig:
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
alibi_type: str = "none",
num_layers: int = 1,
):
self.batch_size = batch_size
......@@ -94,6 +97,7 @@ class ModelConfig:
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
......@@ -167,7 +171,7 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
return False
if config.attn_bias_type != "no_bias":
if config.attn_bias_type not in ["no_bias", "alibi"]:
return False
if config.num_heads != config.num_gqa_groups and not _is_flash_attention_2_available():
return False
......@@ -283,18 +287,26 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
)
if unfused_attn_supported and fused_attn_supported:
if _NVTE_DEBUG:
print("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if unfused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG:
print("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG:
print("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backend) == 2:
if _NVTE_DEBUG:
print("[test_dot_product_attention]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i,_ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
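
The backend-0 vs backend-1 comparison works by forcing the fused backend through the `NVTE_FUSED_ATTN_BACKEND` environment variable, as `_run_transformer_layer` also does further down in this diff. A minimal toggle pattern (a sketch; the "0" = max-512, "1" = arbitrary-seqlen mapping is inferred from the usage in this diff):

```python
import os

def run_with_fused_backend(backend: str, fn, *args, **kwargs):
    # NVTE_FUSED_ATTN_BACKEND: "0" selects the max-512 backend,
    # "1" the arbitrary-seqlen backend (per the usage in this diff).
    old = os.environ.get("NVTE_FUSED_ATTN_BACKEND")
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend
    try:
        return fn(*args, **kwargs)
    finally:
        if old is None:
            os.environ.pop("NVTE_FUSED_ATTN_BACKEND", None)
        else:
            os.environ["NVTE_FUSED_ATTN_BACKEND"] = old
```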
......@@ -382,6 +394,21 @@ def test_dpa_sliding_window(dtype, model_configs, model):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True)
model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
"alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"),
"alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"),
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)
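
The `vanilla` configs above rely on the default slope recipe (geometric sequences, as in Press et al.), while the `custom` configs pass randomly drawn slopes. A standalone sketch of the default recipe, matching the construction inside `get_alibi()`:

```python
import math
import torch

def vanilla_alibi_slopes(num_heads: int) -> torch.Tensor:
    """Geometric ALiBi slopes; handles head counts that are not powers of two."""
    n = 2 ** math.floor(math.log2(num_heads))
    m = torch.pow(2.0 ** (-8.0 / n), torch.arange(1, 1 + n))
    if n < num_heads:
        # Fill the remaining heads from the next finer geometric sequence.
        m_hat = torch.pow(2.0 ** (-4.0 / n),
                          torch.arange(1, 1 + 2 * (num_heads - n), 2))
        m = torch.cat([m, m_hat])
    return m

print(vanilla_alibi_slopes(16))  # 2^-0.5, 2^-1, ..., 2^-8
```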
qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
......@@ -477,9 +504,17 @@ def _run_dot_product_attention(
attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
if swa:
window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
elif "causal" in config.attn_mask_type:
window_size, attention_mask = (-1, 0), None
else:
window_size, attention_mask = None, None
alibi_slopes = None
if config.attn_bias_type == "alibi":
if config.alibi_type == "custom":
alibi_slopes = torch.randn(
config.num_heads).abs().to(dtype=torch.float32, device="cuda")
# Create input tensors
dim_to_num = {
'b' : config.batch_size,
......@@ -570,6 +605,7 @@ def _run_dot_product_attention(
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True)
out.backward(out_grad)
......@@ -583,6 +619,8 @@ model_configs_te_layer = {
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}
@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
......@@ -654,12 +692,18 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
)
if unfused_attn_supported and fused_attn_supported:
if _NVTE_DEBUG:
print("[test_transformer_layer]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
if unfused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG:
print("[test_transformer_layer]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
if fused_attn_supported and flash_attn_supported:
if _NVTE_DEBUG:
print("[test_transformer_layer]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
......@@ -758,28 +802,10 @@ def _run_transformer_layer(
rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
# Create bias
if config.attn_bias_type == 'no_bias':
bias = None
bias = None
if config.attn_bias_type == 'post_scale_bias':
bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
dtype=dtype, device="cuda")
elif config.attn_bias_type == 'alibi':
if os.environ['NVTE_FUSED_ATTN_BACKEND'] == '0':
config.attn_bias_type = 'post_scale_bias'
n = 2 ** math.floor(math.log2(config.num_heads))
m_0 = 2.0 ** (-8.0 / n)
m = torch.pow(m_0, torch.arange(1, 1 + n))
a = torch.ones(config.max_seqlen_q, config.max_seqlen_kv)
b = torch.triu(a,diagonal=1)
c = b.cumsum(dim=-1)
d = c - torch.transpose(c, 0, 1)
bias = d.expand(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv)
for i in range(config.num_heads):
bias[0,i,:,:] = m[i] * bias[0,i,:,:]
bias = bias.to(dtype=dtype, device="cuda")
else:
bias = None
# Create RoPE
rotary_pos_emb = None
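
The triu/cumsum construction removed above and the new `get_alibi()` agree wherever the causal mask leaves scores unmasked: both reduce to `slope * (j - i)` on the lower triangle. A quick standalone equivalence check (per-head slope omitted for brevity):

```python
import torch

s = 8
a = torch.ones(s, s)
c = torch.triu(a, diagonal=1).cumsum(dim=-1)
old = c - c.transpose(0, 1)  # removed construction: old[i, j] == j - i

idx = torch.arange(s)
new = -(idx.view(1, s) - idx.view(s, 1)).abs()  # get_alibi form: -|i - j|

causal = torch.tril(torch.ones(s, s, dtype=torch.bool))
assert torch.equal(old[causal], new[causal].float())
```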
......@@ -825,6 +851,12 @@ def _run_transformer_layer(
.to(dtype=dtype, device="cuda")
)
# Create ALiBi slopes
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
alibi_slopes = torch.randn(
config.num_heads).abs().to(dtype=torch.float32, device="cuda")
# Run a forward and backward pass
out = block(inp,
attention_mask=attention_mask,
......@@ -832,7 +864,8 @@ def _run_transformer_layer(
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias)
core_attention_bias=bias,
alibi_slopes=alibi_slopes)
loss = out.sum()
loss.backward()
......
......@@ -135,6 +135,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
|| (bias_type == NVTE_Bias_Type::NVTE_ALIBI
&& attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK
&& attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK
&& sm_arch_ == 90)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
&& sm_arch_ == 90))))
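
Restated as a predicate (a paraphrase of the condition above, not an exported API), the arbitrary-seqlen backend's bias eligibility after this change is:

```python
def arbitrary_seqlen_bias_supported(bias_type: str, mask_type: str,
                                    sm_arch: int) -> bool:
    # Mirrors the C++ condition: ALiBi requires a mask other than
    # no_mask/padding_mask plus sm90; post_scale_bias also requires sm90.
    return (bias_type == "no_bias"
            or (bias_type == "alibi"
                and mask_type not in ("no_mask", "padding_mask")
                and sm_arch == 90)
            or (bias_type == "post_scale_bias" and sm_arch == 90))
```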
......
......@@ -49,6 +49,7 @@ namespace transformer_engine {
namespace fused_attn {
void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d,
int64_t bias_b, int64_t bias_h,
bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
......@@ -154,8 +155,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_bias) {
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({1, h, s_q, s_kv})
.set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
sdpa_options.set_bias(bias);
}
......@@ -293,6 +294,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d,
int64_t bias_b, int64_t bias_h,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
void* devPtrQ, void* devPtrKTranspose, void* devPtrVTranspose,
......@@ -417,12 +419,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
if (is_bias) {
bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_dim({1, h, s_q, s_kv})
.set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
dBias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dBias")
.set_dim({1, h, s_q, s_kv})
.set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
.set_dim({bias_b, bias_h, s_q, s_kv})
.set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}));
sdpa_backward_options.set_bias(bias);
sdpa_backward_options.set_dbias(dBias);
}
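
The generalized dims/strides admit [1, 1, s, s], [1, h, s, s], [b, 1, s, s], and [b, h, s, s] bias tensors. A small check that the stride formula above is just the contiguous layout for each case (sketch, pure PyTorch):

```python
import torch

b, h, sq, skv = 2, 16, 128, 128
for bias_b, bias_h in [(1, 1), (1, h), (b, 1), (b, h)]:
    bias = torch.empty(bias_b, bias_h, sq, skv)
    # Matches .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1}) above.
    assert bias.stride() == (bias_h * sq * skv, sq * skv, skv, 1)
    # And broadcasts against the [b, h, sq, skv] attention scores.
    _ = torch.zeros(b, h, sq, skv) + bias
```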
......@@ -590,7 +592,14 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
void *devPtrBias = input_Bias->data.dptr;
void *devPtrBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
......@@ -608,7 +617,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {1, num_attn_heads, max_seqlen, max_seqlen};
output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
......@@ -644,7 +653,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_attn_heads,
max_seqlen, max_seqlen, head_dim,
max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
......@@ -698,10 +707,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
void *devPtrdBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
devPtrdBias = output_dBias->data.dptr;
bias_b = output_dBias->data.shape[0];
bias_h = output_dBias->data.shape[1];
}
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrdQKV = output_dQKV->data.dptr;
void *devPtrdQ = devPtrdQKV;
......@@ -720,7 +734,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_attn_heads,
max_seqlen, max_seqlen, head_dim,
max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
......@@ -767,7 +781,14 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
void *devPtrBias = input_Bias->data.dptr;
void *devPtrBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
......@@ -787,7 +808,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {1, num_attn_heads, max_seqlen_q, max_seqlen_kv};
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
......@@ -823,8 +844,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, is_training, attn_scale, p_dropout, qkv_layout,
max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
......@@ -879,8 +900,14 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
void *devPtrdBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
devPtrdBias = output_dBias->data.dptr;
bias_b = output_dBias->data.shape[0];
bias_h = output_dBias->data.shape[1];
}
void *devPtrdQ = output_dQ->data.dptr;
......@@ -890,7 +917,6 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
......@@ -902,8 +928,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, attn_scale, p_dropout, qkv_layout,
max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
......@@ -944,7 +970,14 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrV = input_V->data.dptr;
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
void *devPtrBias = input_Bias->data.dptr;
void *devPtrBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
......@@ -962,7 +995,7 @@ void fused_attn_arbitrary_seqlen_fwd(
output_rng_state->data.dtype = DType::kInt64;
Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {1, num_attn_heads, max_seqlen_q, max_seqlen_kv};
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
......@@ -998,8 +1031,8 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_fwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, is_training, attn_scale, p_dropout, qkv_layout,
max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
devPtrDropoutSeed, devPtrDropoutOffset,
......@@ -1045,8 +1078,14 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
void* devPtrO = input_O->data.dptr;
void *devPtrdO = input_dO->data.dptr;
void *devPtrBias = nullptr;
void *devPtrdBias = nullptr;
size_t bias_b = 0;
size_t bias_h = 0;
if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) {
devPtrBias = input_Bias->data.dptr;
devPtrdBias = output_dBias->data.dptr;
bias_b = output_dBias->data.shape[0];
bias_h = output_dBias->data.shape[1];
}
void *devPtrdQ = output_dQ->data.dptr;
......@@ -1054,7 +1093,6 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
void *devPtrdV = output_dV->data.dptr;
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrdBias = output_dBias->data.dptr;
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
......@@ -1066,9 +1104,8 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t
size_t workspace_size = 0;
fused_attn_arbitrary_seqlen_bwd_impl(batch, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv,
head_dim, attn_scale, p_dropout, qkv_layout,
bias_type, mask_type,
max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset,
......
......@@ -71,6 +71,15 @@ if _flash_attn_version >= _flash_attn_version_required:
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
_alibi_cache = {
"_num_heads": None,
"_alibi_slopes": None,
"_max_seqlen_q": None,
"_max_seqlen_kv": None,
"_alibi_bias": None,
"_alibi_slopes_require_update": False,
"_alibi_bias_require_update": False,
}
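
The cache keeps one set of slopes and one bias for the whole model, so a stack of identical layers computes them once per step instead of once per layer. The refresh conditions, condensed from the checks in `DotProductAttention.forward()` later in this file (a sketch, not the exact code):

```python
def needs_alibi_refresh(cache: dict, num_heads: int,
                        max_seqlen_q: int, max_seqlen_kv: int) -> bool:
    # Recompute slopes/bias when the cached geometry no longer matches,
    # or when no slopes have been cached yet.
    return (cache["_num_heads"] != num_heads
            or cache["_max_seqlen_q"] != max_seqlen_q
            or cache["_max_seqlen_kv"] != max_seqlen_kv
            or cache["_alibi_slopes"] is None)
```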
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
......@@ -126,32 +135,70 @@ def get_alibi(
num_heads: int,
max_seqlen_q: int,
max_seqlen_kv: int,
) -> torch.Tensor:
"""
Generate ALiBi bias in the shape of [1, num_heads, max_seqlen_q, max_seqlen_kv].
alibi_slopes: Optional[torch.Tensor] = None,
bias_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
n = 2 ** math.floor(math.log2(num_heads))
m_0 = 2.0 ** (-8.0 / n)
m = torch.pow(m_0, torch.arange(1, 1 + n))
if n < num_heads:
m_hat_0 = 2.0 ** (-4.0 / n)
m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
m = torch.cat([m, m_hat])
a = torch.ones(max_seqlen_q, max_seqlen_kv)
b = torch.triu(a,diagonal=1)
c = b.cumsum(dim=-1)
bb = torch.tril(a,diagonal=-1)
cc = bb.cumsum(dim=0)
d = c - cc
bias = d.repeat(1, num_heads, 1, 1)
Parameters
----------
num_heads: int
Number of heads.
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
alibi_slopes: Optional[torch.Tensor], default = `None`
Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads].
bias_dtype: Optional[torch.dtype], default = `None`
Dtype of the generated ALiBi bias. If None, use torch.float32.
for i in range(num_heads):
bias[0,i,:,:] = m[i] * bias[0,i,:,:]
Returns
----------
alibi_slopes: torch.Tensor
ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads].
alibi_bias: torch.Tensor
ALiBi bias in FP32 or `bias_dtype`. If `alibi_slopes` is in [num_heads] shape,
then `alibi_bias` is in [1, num_heads, max_seqlen_q, max_seqlen_kv], and if
`alibi_slopes` is in [batch_size, num_heads], then the bias is in
[batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
"""
global _alibi_cache
if _alibi_cache["_alibi_slopes_require_update"]:
if alibi_slopes is not None:
_alibi_cache["_alibi_slopes"] = alibi_slopes
else:
n = 2 ** math.floor(math.log2(num_heads))
m_0 = 2.0 ** (-8.0 / n)
m = torch.pow(m_0, torch.arange(1, 1 + n))
if n < num_heads:
m_hat_0 = 2.0 ** (-4.0 / n)
m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
m = torch.cat([m, m_hat])
_alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda")
_alibi_cache["_num_heads"] = num_heads
_alibi_cache["_alibi_slopes_require_update"] = False
if _alibi_cache["_alibi_bias_require_update"]:
assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!"
if _alibi_cache["_alibi_slopes"].dim() == 1:
slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1])
if _alibi_cache["_alibi_slopes"].dim() == 2:
slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1])
bias = torch.arange(
1 - max_seqlen_kv, 1, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
bias = bias - torch.arange(
1 - max_seqlen_q, 1, dtype=torch.int32, device="cuda").view(1, 1, max_seqlen_q, 1)
bias = bias.abs().mul(-1)
bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape)
_alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv
bias_dtype = torch.float32 if bias_dtype is None else bias_dtype
_alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda")
_alibi_cache["_alibi_bias_require_update"] = False
return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"]
bias = bias.to(dtype=torch.float32, device="cuda")
return bias
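
The arange construction above computes `bias[..., i, j] = -slope * |i + max_seqlen_kv - max_seqlen_q - j|`. A brute-force check against an explicit loop (standalone, CPU, small sizes):

```python
import torch

def alibi_bias_ref(slopes, sq, skv):
    # O(h * sq * skv) reference for the vectorized construction.
    bias = torch.empty(1, slopes.shape[0], sq, skv)
    for n, slope in enumerate(slopes):
        for i in range(sq):
            for j in range(skv):
                bias[0, n, i, j] = -slope * abs(i + skv - sq - j)
    return bias

slopes, sq, skv = torch.tensor([0.5, 0.25]), 3, 5
row = torch.arange(1 - skv, 1).view(1, 1, 1, skv)
col = torch.arange(1 - sq, 1).view(1, 1, sq, 1)
fast = (row - col).abs().mul(-1) * slopes.view(1, -1, 1, 1)
assert torch.allclose(fast, alibi_bias_ref(slopes, sq, skv))
```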
def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
"""
......@@ -1281,6 +1328,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Unfused attention fprop"""
......@@ -1350,8 +1398,6 @@ class UnfusedDotProductAttention(torch.nn.Module):
elif core_attention_bias_type == "pre_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
assert (core_attention_bias.shape == torch.Size(1, *output_size[1:])
), "core_attention_bias must be in [1, h, sq, skv] shape!"
matmul_result = torch.bmm(
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
......@@ -1364,10 +1410,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
elif core_attention_bias_type in ["post_scale_bias", "alibi"]:
if core_attention_bias_type == "post_scale_bias":
assert core_attention_bias is not None, "core_attention_bias should not be None!"
assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
), "core_attention_bias must be in [1, h, sq, skv] shape!"
if core_attention_bias_type == "alibi":
core_attention_bias = get_alibi(output_size[1], output_size[2], output_size[3])
_, core_attention_bias = get_alibi(
output_size[1], output_size[2], output_size[3], alibi_slopes=alibi_slopes)
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
......@@ -2342,6 +2387,7 @@ class DotProductAttention(torch.nn.Module):
self.tp_group = tp_group
self.get_rng_state_tracker = get_rng_state_tracker
self.num_attention_heads = num_attention_heads
self.layer_number = 1 if layer_number is None else layer_number
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
......@@ -2472,10 +2518,10 @@ class DotProductAttention(torch.nn.Module):
max_seqlen_kv: Optional[int] = None,
attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""
......@@ -2553,11 +2599,7 @@ class DotProductAttention(torch.nn.Module):
`arbitrary`}, default = `None`. Type of attention mask passed into
softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
window_size: Optional[Tuple[int, int]], default = `None`
sliding window size for local attention.
alibi_slopes: Optional[torch.Tensor], default = `None`
An fp32 bias of shape (nheads,) or (batch_size, nheads)
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
Sliding window size for local attention.
checkpoint_core_attention : bool, default = `False`
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
......@@ -2568,6 +2610,10 @@ class DotProductAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
It should be 'None' for 'no_bias' and 'alibi' bias types.
alibi_slopes: Optional[torch.Tensor], default = `None`
ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
to the attention score of query i and key j.
fast_zero_fill: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
"""
......@@ -2652,6 +2698,11 @@ class DotProductAttention(torch.nn.Module):
# The following section filters out some backends based on
# certain asserts before executing the forward pass.
# Filter: ONNX export.
if is_in_onnx_export_mode():
use_flash_attention = False
use_fused_attention = False
# Filter: Input type.
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
......@@ -2680,10 +2731,6 @@ class DotProductAttention(torch.nn.Module):
)
use_flash_attention = False
# Filter: bias.
if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
use_flash_attention = False
context_parallel = (self.cp_group is not None and \
get_distributed_world_size(self.cp_group) != 1)
......@@ -2694,11 +2741,6 @@ class DotProductAttention(torch.nn.Module):
if (not _flash_attn_2_3_plus) or context_parallel:
use_flash_attention = False
# Filter: ONNX export.
if is_in_onnx_export_mode():
use_flash_attention = False
use_fused_attention = False
# Filter: Attention mask type.
# attn_mask_type(s) | supported backends
# ------------------------------------------------
......@@ -2714,12 +2756,47 @@ class DotProductAttention(torch.nn.Module):
if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
use_unfused_attention = False
# Filter: bias.
global _alibi_cache
if alibi_slopes is not None:
assert (core_attention_bias_type == "alibi"
), "core_attention_bias_type must be alibi in order to use alibi_slopes!"
if self.layer_number == 1:
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
if core_attention_bias_type == "alibi":
assert (core_attention_bias is None
), "core_attention_bias must be None when core_attention_bias_type is alibi!"
if (_alibi_cache["_num_heads"] != query_layer.shape[-2]
or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
or _alibi_cache["_alibi_slopes"] is None):
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
if core_attention_bias_type not in ["no_bias", "alibi"] or core_attention_bias is not None:
use_flash_attention = False
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if core_attention_bias_type == "alibi" and use_fused_attention and alibi_slopes is not None:
fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = get_alibi(
query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype)
if (fu_core_attention_bias.shape[0] != 1
or fu_core_attention_bias.shape[1] != query_layer.shape[-2]):
# remove this line when cuDNN adds bwd support for [b, 1, s, s] and [b, h, s, s]
use_fused_attention = False
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
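
Note the shape rule this enforces: [nheads] slopes materialize into a [1, h, sq, skv] bias and stay on the fused path, while [batch, nheads] slopes give a [b, h, sq, skv] bias, which currently falls back (cuDNN bwd support pending, per the comment above). A sketch of that rule:

```python
import torch

def materialized_bias_shape(slopes: torch.Tensor, sq: int, skv: int):
    # [h] slopes -> [1, h, sq, skv] (fused-attention friendly);
    # [b, h] slopes -> [b, h, sq, skv] (falls back for now).
    if slopes.dim() == 1:
        return (1, slopes.shape[0], sq, skv)
    return (*slopes.shape, sq, skv)

assert materialized_bias_shape(torch.ones(16), 128, 128) == (1, 16, 128, 128)
assert materialized_bias_shape(torch.ones(2, 16), 128, 128) == (2, 16, 128, 128)
```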
if use_fused_attention:
fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype],
TE_DType[key_layer.dtype],
QKVLayout[qkv_layout],
AttnBiasType[core_attention_bias_type],
AttnBiasType[fu_core_attention_bias_type],
AttnMaskType[attn_mask_type],
self.attention_dropout,
query_layer.shape[-2], # num_attn_heads
......@@ -2736,13 +2813,6 @@ class DotProductAttention(torch.nn.Module):
(not context_parallel or \
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]))
# Filter: Alibi slopes
if alibi_slopes is not None:
use_fused_attention = False
assert (
use_flash_attention
), "Alibi slopes bias is only supported in the FlashAttention backend."
# Filter: determinism.
# backend | deterministic
# ---------------------------------------------------------
......@@ -2771,6 +2841,9 @@ class DotProductAttention(torch.nn.Module):
if use_flash_attention:
if _NVTE_DEBUG:
print("[DotProductAttention]: using flash-attn",_flash_attn_version)
if core_attention_bias_type == "alibi":
alibi_slopes, _ = get_alibi(
query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, alibi_slopes=alibi_slopes)
return self.flash_attention(query_layer,
key_layer,
value_layer,
......@@ -2803,8 +2876,8 @@ class DotProductAttention(torch.nn.Module):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
fast_zero_fill=fast_zero_fill,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
......@@ -2821,8 +2894,8 @@ class DotProductAttention(torch.nn.Module):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
fast_zero_fill=fast_zero_fill,
cp_group=self.cp_group,
cp_global_ranks=self.cp_global_ranks,
......@@ -2855,7 +2928,8 @@ class DotProductAttention(torch.nn.Module):
attn_mask_type = attn_mask_type,
attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias)
core_attention_bias = core_attention_bias,
alibi_slopes = alibi_slopes)
return self.unfused_attention(query_layer,
key_layer,
value_layer,
......@@ -2865,7 +2939,8 @@ class DotProductAttention(torch.nn.Module):
attn_mask_type = attn_mask_type,
attention_mask = attention_mask,
core_attention_bias_type = core_attention_bias_type,
core_attention_bias = core_attention_bias)
core_attention_bias = core_attention_bias,
alibi_slopes = alibi_slopes)
raise Exception("No dot product attention support for the provided inputs!")
......@@ -3279,6 +3354,7 @@ class MultiheadAttention(torch.nn.Module):
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""
......@@ -3334,6 +3410,10 @@ class MultiheadAttention(torch.nn.Module):
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv].
It should be 'None' for 'no_bias' and 'alibi' bias types.
alibi_slopes: Optional[torch.Tensor], default = `None`
ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
to the attention score of query i and key j.
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use.
"""
......@@ -3561,6 +3641,7 @@ class MultiheadAttention(torch.nn.Module):
checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
)
......
......@@ -184,9 +184,6 @@ def fused_attn_fwd_qkvpacked(
if attn_bias_type not in ["no_bias", "alibi"]:
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
h = qkv.size(2) if 'h3d' in qkv_layout else qkv.size(3)
assert (attn_bias.shape == torch.Size([1, h, max_seqlen, max_seqlen])
), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (attn_bias.dtype == qkv.dtype
), "attn_bias tensor must be in the same dtype as qkv."
......@@ -479,9 +476,6 @@ def fused_attn_fwd_kvpacked(
if attn_bias_type not in ["no_bias", "alibi"]:
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
h = q.size(2)
assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype
), "attn_bias tensor must be in the same dtype as q and kv."
......@@ -784,9 +778,6 @@ def fused_attn_fwd(
if attn_bias_type not in ["no_bias", "alibi"]:
assert (attn_bias is not None
), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi."
h = q.size(2)
assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (attn_bias.dtype == q.dtype
), "attn_bias tensor must be in the same dtype as q and kv."
......
......@@ -286,14 +286,6 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
// create output tensor dQKV
at::Tensor dQKV = torch::empty_like(QKV);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(h),
static_cast<int64_t>(max_seqlen),
static_cast<int64_t>(max_seqlen)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
// construct NVTE tensors
TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV;
......@@ -358,6 +350,23 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create dBias the same shape as Bias
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS)
&& (bias_type != NVTE_ALIBI)) {
if (nvte_aux_tensor_pack.size >= 2) {
std::vector<int64_t> bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec());
dBias = torch::empty(bias_shape, options);
te_dBias = makeTransformerEngineTensor(dBias);
} else {
dBias = torch::empty({1, static_cast<int64_t>(h),
static_cast<int64_t>(max_seqlen),
static_cast<int64_t>(max_seqlen)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
}
// create cu_seqlens tensorwrappers
auto cu_seqlens_sizes = cu_seqlens.sizes().vec();
std::vector<size_t> cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()};
......@@ -629,14 +638,6 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
at::Tensor dQ = torch::empty_like(Q);
at::Tensor dKV = torch::empty_like(KV);
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(h_q),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
// construct NVTE tensors
TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV;
......@@ -721,6 +722,23 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create dBias the same shape as Bias
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS)
&& (bias_type != NVTE_ALIBI)) {
if (nvte_aux_tensor_pack.size >= 2) {
std::vector<int64_t> bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec());
dBias = torch::empty(bias_shape, options);
te_dBias = makeTransformerEngineTensor(dBias);
} else {
dBias = torch::empty({1, static_cast<int64_t>(h_q),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
}
// create workspace
TensorWrapper workspace;
......@@ -990,6 +1008,9 @@ std::vector<at::Tensor> fused_attn_bwd(
std::vector<size_t> k_shape{k_sizes.begin(), k_sizes.end()};
auto v_sizes = V.sizes().vec();
std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
at::Tensor dQ;
......@@ -1055,22 +1076,10 @@ std::vector<at::Tensor> fused_attn_bwd(
NVTE_ERROR("QKV layout not supported!");
}
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
dBias = torch::empty({1, static_cast<int64_t>(Q.size(-2)),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h_q = q_shape[q_shape.size() - 2];
auto h_kv = k_shape[k_shape.size() - 2];
auto d = q_shape[q_shape.size() - 1];
if (set_zero
&& ((h_q * d) % block_size == 0)
&& ((h_kv * d) % block_size == 0)
......@@ -1165,6 +1174,23 @@ std::vector<at::Tensor> fused_attn_bwd(
tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type());
}
// create dBias the same shape as Bias
at::Tensor dBias;
TensorWrapper te_dBias;
if ((bias_type != NVTE_NO_BIAS)
&& (bias_type != NVTE_ALIBI)) {
if (nvte_aux_tensor_pack.size >= 2) {
std::vector<int64_t> bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec());
dBias = torch::empty(bias_shape, options);
te_dBias = makeTransformerEngineTensor(dBias);
} else {
dBias = torch::empty({1, static_cast<int64_t>(h_q),
static_cast<int64_t>(max_seqlen_q),
static_cast<int64_t>(max_seqlen_kv)}, options);
te_dBias = makeTransformerEngineTensor(dBias);
}
}
// create workspace
TensorWrapper workspace;
......
......@@ -525,6 +525,7 @@ class TransformerLayer(torch.nn.Module):
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""
......@@ -583,6 +584,10 @@ class TransformerLayer(torch.nn.Module):
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
alibi_slopes: Optional[torch.Tensor], default = `None`
ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
to the attention score of query i and key j.
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None
......@@ -633,6 +638,7 @@ class TransformerLayer(torch.nn.Module):
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
)
......@@ -658,6 +664,7 @@ class TransformerLayer(torch.nn.Module):
checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
)
if self.apply_residual_connection_post_layernorm:
......